diff --git a/CLAUDE.md b/.claude/CLAUDE.md similarity index 86% rename from CLAUDE.md rename to .claude/CLAUDE.md index d07bd3f3e1..c1349e8c9d 100644 --- a/CLAUDE.md +++ b/.claude/CLAUDE.md @@ -1,24 +1,29 @@ # Bitwarden Server - Claude Code Configuration +## Project Context Files + +**Read these files before reviewing to ensure that you fully understand the project and contributing guidelines** + +1. @README.md +2. @CONTRIBUTING.md +3. @.github/PULL_REQUEST_TEMPLATE.md + ## Critical Rules -- **NEVER** edit: `/bin/`, `/obj/`, `/.git/`, `/.vs/`, `/packages/` which are generated files - **NEVER** use code regions: If complexity suggests regions, refactor for better readability + - **NEVER** compromise zero-knowledge principles: User vault data must remain encrypted and inaccessible to Bitwarden + - **NEVER** log or expose sensitive data: No PII, passwords, keys, or vault data in logs or error messages + - **ALWAYS** use secure communication channels: Enforce confidentiality, integrity, and authenticity + - **ALWAYS** encrypt sensitive data: All vault data must be encrypted at rest, in transit, and in use + - **ALWAYS** prioritize cryptographic integrity and data protection + - **ALWAYS** add unit tests (with mocking) for any new feature development -## Project Context - -- **Architecture**: Feature and team-based organization -- **Framework**: .NET 8.0, ASP.NET Core -- **Database**: SQL Server primary, EF Core supports PostgreSQL, MySQL/MariaDB, SQLite -- **Testing**: xUnit, NSubstitute -- **Container**: Docker, Docker Compose, Kubernetes/Helm deployable - ## Project Structure - **Source Code**: `/src/` - Services and core infrastructure @@ -42,7 +47,7 @@ - **Database update**: `pwsh dev/migrate.ps1` - **Generate OpenAPI**: `pwsh dev/generate_openapi_files.ps1` -## Code Review Checklist +## Development Workflow - Security impact assessed - xUnit tests added / updated diff --git a/.editorconfig b/.editorconfig index 21d7ac4a3a..fd68808456 100644 --- a/.editorconfig +++ b/.editorconfig @@ -123,3 +123,12 @@ csharp_style_namespace_declarations = file_scoped:warning # Switch expression dotnet_diagnostic.CS8509.severity = error # missing switch case for named enum value dotnet_diagnostic.CS8524.severity = none # missing switch case for unnamed enum value + +# CA2253: Named placeholders should nto be numeric values +dotnet_diagnostic.CA2253.severity = suggestion + +# CA2254: Template should be a static expression +dotnet_diagnostic.CA2254.severity = warning + +# CA1727: Use PascalCase for named placeholders +dotnet_diagnostic.CA1727.severity = suggestion diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 6db4905fec..f0c85d98c1 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -36,6 +36,7 @@ util/Setup/** @bitwarden/dept-bre @bitwarden/team-platform-dev # UIF src/Core/MailTemplates/Mjml @bitwarden/team-ui-foundation # Teams are expected to own sub-directories of this project +src/Core/MailTemplates/Mjml/.mjmlconfig # This change allows teams to add components within their own subdirectories without requiring a code review from UIF. # Auth team **/Auth @bitwarden/team-auth-dev @@ -52,6 +53,11 @@ src/Core/IdentityServer @bitwarden/team-auth-dev # Dirt (Data Insights & Reporting) team **/Dirt @bitwarden/team-data-insights-and-reporting-dev +src/Events @bitwarden/team-data-insights-and-reporting-dev +src/EventsProcessor @bitwarden/team-data-insights-and-reporting-dev +test/Events.IntegrationTest @bitwarden/team-data-insights-and-reporting-dev +test/Events.Test @bitwarden/team-data-insights-and-reporting-dev +test/EventsProcessor.Test @bitwarden/team-data-insights-and-reporting-dev # Vault team **/Vault @bitwarden/team-vault-dev @@ -62,8 +68,6 @@ src/Core/IdentityServer @bitwarden/team-auth-dev bitwarden_license/src/Scim @bitwarden/team-admin-console-dev bitwarden_license/src/test/Scim.IntegrationTest @bitwarden/team-admin-console-dev bitwarden_license/src/test/Scim.ScimTest @bitwarden/team-admin-console-dev -src/Events @bitwarden/team-admin-console-dev -src/EventsProcessor @bitwarden/team-admin-console-dev # Billing team **/*billing* @bitwarden/team-billing-dev @@ -96,6 +100,14 @@ src/Admin/Views/Tools @bitwarden/team-billing-dev # The PushType enum is expected to be editted by anyone without need for Platform review src/Core/Platform/Push/PushType.cs +# SDK +util/RustSdk @bitwarden/team-sdk-sme + # Multiple owners - DO NOT REMOVE (BRE) **/packages.lock.json Directory.Build.props + +# Claude related files +.claude/ @bitwarden/team-ai-sme +.github/workflows/respond.yml @bitwarden/team-ai-sme +.github/workflows/review-code.yml @bitwarden/team-ai-sme diff --git a/.github/ISSUE_TEMPLATE/bw-unified.yml b/.github/ISSUE_TEMPLATE/bw-lite.yml similarity index 88% rename from .github/ISSUE_TEMPLATE/bw-unified.yml rename to .github/ISSUE_TEMPLATE/bw-lite.yml index 240b1faa72..0c43fa5835 100644 --- a/.github/ISSUE_TEMPLATE/bw-unified.yml +++ b/.github/ISSUE_TEMPLATE/bw-lite.yml @@ -1,6 +1,6 @@ -name: Bitwarden Unified Deployment Bug Report +name: Bitwarden lite Deployment Bug Report description: File a bug report -labels: [bug, bw-unified-deploy] +labels: [bug, bw-lite-deploy] body: - type: markdown attributes: @@ -70,15 +70,6 @@ body: mariadb:10 # Postgres Example postgres:14 - - type: textarea - id: epic-label - attributes: - label: Issue-Link - description: Link to our pinned issue, tracking all Bitwarden Unified - value: | - https://github.com/bitwarden/server/issues/2480 - validations: - required: true - type: checkboxes id: issue-tracking-info attributes: diff --git a/.github/renovate.json5 b/.github/renovate.json5 index 5c01832c06..77539ef839 100644 --- a/.github/renovate.json5 +++ b/.github/renovate.json5 @@ -2,6 +2,7 @@ $schema: "https://docs.renovatebot.com/renovate-schema.json", extends: ["github>bitwarden/renovate-config"], // Extends our default configuration for pinned dependencies enabledManagers: [ + "cargo", "dockerfile", "docker-compose", "github-actions", @@ -9,32 +10,7 @@ "nuget", ], packageRules: [ - { - groupName: "dockerfile minor", - matchManagers: ["dockerfile"], - matchUpdateTypes: ["minor"], - }, - { - groupName: "docker-compose minor", - matchManagers: ["docker-compose"], - matchUpdateTypes: ["minor"], - }, - { - groupName: "github-action minor", - matchManagers: ["github-actions"], - matchUpdateTypes: ["minor"], - addLabels: ["hold"], - }, - { - // For any Microsoft.Extensions.* and Microsoft.AspNetCore.* packages, we want to create PRs for patch updates. - // This overrides the default that ignores patch updates for nuget dependencies. - matchPackageNames: [ - "/^Microsoft\\.Extensions\\./", - "/^Microsoft\\.AspNetCore\\./", - ], - matchUpdateTypes: ["patch"], - dependencyDashboardApproval: false, - }, + // ==================== Team Ownership Rules ==================== { matchManagers: ["dockerfile", "docker-compose"], commitMessagePrefix: "[deps] BRE:", @@ -53,11 +29,11 @@ }, { matchPackageNames: [ - "Azure.Extensions.AspNetCore.DataProtection.Blobs", "DuoUniversal", "Fido2.AspNet", "Duende.IdentityServer", "Microsoft.AspNetCore.Authentication.JwtBearer", + "Microsoft.Extensions.Caching.Cosmos", "Microsoft.Extensions.Identity.Stores", "Otp.NET", "Sustainsys.Saml2.AspNetCore2", @@ -80,11 +56,7 @@ "Microsoft.AspNetCore.Mvc.Testing", "Newtonsoft.Json", "NSubstitute", - "Sentry.Serilog", - "Serilog.AspNetCore", - "Serilog.Extensions.Logging", "Serilog.Extensions.Logging.File", - "Serilog.Sinks.SyslogMessages", "Stripe.net", "Swashbuckle.AspNetCore", "Swashbuckle.AspNetCore.SwaggerGen", @@ -95,11 +67,6 @@ commitMessagePrefix: "[deps] Billing:", reviewers: ["team:team-billing-dev"], }, - { - matchPackageNames: ["/^Microsoft\\.EntityFrameworkCore\\./", "/^dotnet-ef/"], - groupName: "EntityFrameworkCore", - description: "Group EntityFrameworkCore to exclude them from the dotnet monorepo preset", - }, { matchPackageNames: [ "Dapper", @@ -131,6 +98,7 @@ "AspNetCoreRateLimit", "AspNetCoreRateLimit.Redis", "Azure.Data.Tables", + "Azure.Extensions.AspNetCore.DataProtection.Blobs", "Azure.Messaging.EventGrid", "Azure.Messaging.ServiceBus", "Azure.Storage.Blobs", @@ -146,7 +114,6 @@ "Microsoft.Extensions.DependencyInjection", "Microsoft.Extensions.Logging", "Microsoft.Extensions.Logging.Console", - "Microsoft.Extensions.Caching.Cosmos", "Microsoft.Extensions.Caching.SqlServer", "Microsoft.Extensions.Caching.StackExchangeRedis", "Quartz", @@ -155,6 +122,12 @@ commitMessagePrefix: "[deps] Platform:", reviewers: ["team:team-platform-dev"], }, + { + matchUpdateTypes: ["lockFileMaintenance"], + description: "Platform owns lock file maintenance", + commitMessagePrefix: "[deps] Platform:", + reviewers: ["team:team-platform-dev"], + }, { matchPackageNames: [ "AutoMapper.Extensions.Microsoft.DependencyInjection", @@ -184,6 +157,73 @@ commitMessagePrefix: "[deps] Vault:", reviewers: ["team:team-vault-dev"], }, + + // ==================== Grouping Rules ==================== + // These come after any specific team assignment rules to ensure + // that grouping is not overridden by subsequent rule definitions. + { + groupName: "cargo minor", + matchManagers: ["cargo"], + matchUpdateTypes: ["minor"], + }, + { + groupName: "dockerfile minor", + matchManagers: ["dockerfile"], + matchUpdateTypes: ["minor"], + }, + { + groupName: "docker-compose minor", + matchManagers: ["docker-compose"], + matchUpdateTypes: ["minor"], + }, + { + groupName: "github-action minor", + matchManagers: ["github-actions"], + matchUpdateTypes: ["minor"], + addLabels: ["hold"], + }, + { + matchPackageNames: ["/^Microsoft\\.EntityFrameworkCore\\./", "/^dotnet-ef/"], + groupName: "EntityFrameworkCore", + description: "Group EntityFrameworkCore to exclude them from the dotnet monorepo preset", + }, + { + matchPackageNames: ["https://github.com/bitwarden/sdk-internal.git"], + groupName: "sdk-internal", + dependencyDashboardApproval: true + }, + + // ==================== Dashboard Rules ==================== + { + // For any Microsoft.Extensions.* and Microsoft.AspNetCore.* packages, we want to create PRs for patch updates. + // This overrides the default that ignores patch updates for nuget dependencies. + matchPackageNames: [ + "/^Microsoft\\.Extensions\\./", + "/^Microsoft\\.AspNetCore\\./", + ], + matchUpdateTypes: ["patch"], + dependencyDashboardApproval: false, + }, + { + // For the Platform-owned dependencies below, we have decided we will only be creating PRs + // for major updates, and sending minor (as well as patch, inherited from base config) to the dashboard. + // This rule comes AFTER grouping rules so that groups are respected while still + // sending minor/patch updates to the dependency dashboard for approval. + matchPackageNames: [ + "AspNetCoreRateLimit", + "AspNetCoreRateLimit.Redis", + "Azure.Data.Tables", + "Azure.Extensions.AspNetCore.DataProtection.Blobs", + "Azure.Messaging.EventGrid", + "Azure.Messaging.ServiceBus", + "Azure.Storage.Blobs", + "Azure.Storage.Queues", + "LaunchDarkly.ServerSdk", + "Quartz", + ], + matchUpdateTypes: ["minor"], + dependencyDashboardApproval: true, + }, ], ignoreDeps: ["dotnet-sdk"], } diff --git a/.github/workflows/_move_edd_db_scripts.yml b/.github/workflows/_move_edd_db_scripts.yml index b38a3e0dff..742e7b897e 100644 --- a/.github/workflows/_move_edd_db_scripts.yml +++ b/.github/workflows/_move_edd_db_scripts.yml @@ -38,21 +38,22 @@ jobs: uses: bitwarden/gh-actions/azure-logout@main - name: Check out branch - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: token: ${{ steps.retrieve-secrets.outputs.github-pat-bitwarden-devops-bot-repo-scope }} + persist-credentials: false - name: Get script prefix id: prefix - run: echo "prefix=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT + run: echo "prefix=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT" - name: Check if any files in DB transition or finalization directories id: check-script-existence run: | if [ -f util/Migrator/DbScripts_transition/* -o -f util/Migrator/DbScripts_finalization/* ]; then - echo "copy_edd_scripts=true" >> $GITHUB_OUTPUT + echo "copy_edd_scripts=true" >> "$GITHUB_OUTPUT" else - echo "copy_edd_scripts=false" >> $GITHUB_OUTPUT + echo "copy_edd_scripts=false" >> "$GITHUB_OUTPUT" fi move-scripts: @@ -67,20 +68,21 @@ jobs: if: ${{ needs.setup.outputs.copy_edd_scripts == 'true' }} steps: - name: Check out repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: fetch-depth: 0 + persist-credentials: true - name: Generate branch name id: branch_name env: PREFIX: ${{ needs.setup.outputs.migration_filename_prefix }} - run: echo "branch_name=move_edd_db_scripts_$PREFIX" >> $GITHUB_OUTPUT + run: echo "branch_name=move_edd_db_scripts_$PREFIX" >> "$GITHUB_OUTPUT" - name: "Create branch" env: BRANCH: ${{ steps.branch_name.outputs.branch_name }} - run: git switch -c $BRANCH + run: git switch -c "$BRANCH" - name: Move scripts and finalization database schema id: move-files @@ -120,7 +122,7 @@ jobs: # sync finalization schema back to dbo, maintaining structure rsync -r "$src_dir/" "$dest_dir/" - rm -rf $src_dir/* + rm -rf "${src_dir}"/* # Replace any finalization references due to the move find ./src/Sql/dbo -name "*.sql" -type f -exec sed -i \ @@ -131,7 +133,7 @@ jobs: moved_files="$moved_files \n $file" done - echo "moved_files=$moved_files" >> $GITHUB_OUTPUT + echo "moved_files=$moved_files" >> "$GITHUB_OUTPUT" - name: Log in to Azure uses: bitwarden/gh-actions/azure-login@main @@ -162,18 +164,20 @@ jobs: - name: Commit and push changes id: commit + env: + BRANCH_NAME: ${{ steps.branch_name.outputs.branch_name }} run: | git config --local user.email "106330231+bitwarden-devops-bot@users.noreply.github.com" git config --local user.name "bitwarden-devops-bot" if [ -n "$(git status --porcelain)" ]; then git add . git commit -m "Move EDD database scripts" -a - git push -u origin ${{ steps.branch_name.outputs.branch_name }} - echo "pr_needed=true" >> $GITHUB_OUTPUT + git push -u origin "${BRANCH_NAME}" + echo "pr_needed=true" >> "$GITHUB_OUTPUT" else echo "No changes to commit!"; - echo "pr_needed=false" >> $GITHUB_OUTPUT - echo "### :mega: No changes to commit! PR was ommited." >> $GITHUB_STEP_SUMMARY + echo "pr_needed=false" >> "$GITHUB_OUTPUT" + echo "### :mega: No changes to commit! PR was ommited." >> "$GITHUB_STEP_SUMMARY" fi - name: Create PR for ${{ steps.branch_name.outputs.branch_name }} @@ -195,7 +199,7 @@ jobs: Files moved: $(echo -e "$MOVED_FILES") ") - echo "pr_url=${PR_URL}" >> $GITHUB_OUTPUT + echo "pr_url=${PR_URL}" >> "$GITHUB_OUTPUT" - name: Notify Slack about creation of PR if: ${{ steps.commit.outputs.pr_needed == 'true' }} diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index fe82f9fbe6..694e9048a7 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -22,22 +22,23 @@ env: jobs: lint: name: Lint - runs-on: ubuntu-24.04 + runs-on: ubuntu-22.04 steps: - name: Check out repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: ref: ${{ github.event.pull_request.head.sha }} + persist-credentials: false - name: Set up .NET - uses: actions/setup-dotnet@67a3573c9a986a3f9c594539f4ab511d57bb3ce9 # v4.3.1 + uses: actions/setup-dotnet@d4c94342e560b34958eacfc5d055d21461ed1c5d # v5.0.0 - name: Verify format run: dotnet format --verify-no-changes build-artifacts: name: Build Docker images - runs-on: ubuntu-24.04 + runs-on: ubuntu-22.04 needs: - lint outputs: @@ -45,6 +46,7 @@ jobs: permissions: security-events: write id-token: write + timeout-minutes: 45 strategy: fail-fast: false matrix: @@ -97,30 +99,31 @@ jobs: id: check-secrets run: | has_secrets=${{ secrets.AZURE_CLIENT_ID != '' }} - echo "has_secrets=$has_secrets" >> $GITHUB_OUTPUT + echo "has_secrets=$has_secrets" >> "$GITHUB_OUTPUT" - name: Check out repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: ref: ${{ github.event.pull_request.head.sha }} + persist-credentials: false - name: Check branch to publish env: PUBLISH_BRANCHES: "main,rc,hotfix-rc" id: publish-branch-check run: | - IFS="," read -a publish_branches <<< $PUBLISH_BRANCHES + IFS="," read -a publish_branches <<< "$PUBLISH_BRANCHES" if [[ " ${publish_branches[*]} " =~ " ${GITHUB_REF:11} " ]]; then - echo "is_publish_branch=true" >> $GITHUB_ENV + echo "is_publish_branch=true" >> "$GITHUB_ENV" else - echo "is_publish_branch=false" >> $GITHUB_ENV + echo "is_publish_branch=false" >> "$GITHUB_ENV" fi - name: Set up .NET - uses: actions/setup-dotnet@67a3573c9a986a3f9c594539f4ab511d57bb3ce9 # v4.3.1 + uses: actions/setup-dotnet@d4c94342e560b34958eacfc5d055d21461ed1c5d # v5.0.0 - name: Set up Node - uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0 + uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # v6.1.0 with: cache: "npm" cache-dependency-path: "**/package-lock.json" @@ -157,7 +160,7 @@ jobs: ls -atlh ../../../ - name: Upload project artifact - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 if: ${{ matrix.dotnet }} with: name: ${{ matrix.project_name }}.zip @@ -166,10 +169,10 @@ jobs: ########## Set up Docker ########## - name: Set up QEMU emulators - uses: docker/setup-qemu-action@29109295f81e9208d7d86ff1c6c12d2833863392 # v3.6.0 + uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3.7.0 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0 + uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0 ########## ACRs ########## - name: Log in to Azure @@ -182,13 +185,6 @@ jobs: - name: Log in to ACR - production subscription run: az acr login -n bitwardenprod - - name: Retrieve GitHub PAT secrets - id: retrieve-secret-pat - uses: bitwarden/gh-actions/get-keyvault-secrets@main - with: - keyvault: "bitwarden-ci" - secrets: "github-pat-bitwarden-devops-bot-repo-scope" - ########## Generate image tag and build Docker image ########## - name: Generate Docker image tag id: tag @@ -209,8 +205,8 @@ jobs: IMAGE_TAG=dev fi - echo "image_tag=$IMAGE_TAG" >> $GITHUB_OUTPUT - echo "### :mega: Docker Image Tag: $IMAGE_TAG" >> $GITHUB_STEP_SUMMARY + echo "image_tag=$IMAGE_TAG" >> "$GITHUB_OUTPUT" + echo "### :mega: Docker Image Tag: $IMAGE_TAG" >> "$GITHUB_STEP_SUMMARY" - name: Set up project name id: setup @@ -218,7 +214,7 @@ jobs: PROJECT_NAME=$(echo "${{ matrix.project_name }}" | awk '{print tolower($0)}') echo "Matrix name: ${{ matrix.project_name }}" echo "PROJECT_NAME: $PROJECT_NAME" - echo "project_name=$PROJECT_NAME" >> $GITHUB_OUTPUT + echo "project_name=$PROJECT_NAME" >> "$GITHUB_OUTPUT" - name: Generate image tags(s) id: image-tags @@ -228,12 +224,12 @@ jobs: SHA: ${{ github.sha }} run: | TAGS="${_AZ_REGISTRY}/${PROJECT_NAME}:${IMAGE_TAG}" - echo "primary_tag=$TAGS" >> $GITHUB_OUTPUT + echo "primary_tag=$TAGS" >> "$GITHUB_OUTPUT" if [[ "${IMAGE_TAG}" == "dev" ]]; then - SHORT_SHA=$(git rev-parse --short ${SHA}) + SHORT_SHA=$(git rev-parse --short "${SHA}") TAGS=$TAGS",${_AZ_REGISTRY}/${PROJECT_NAME}:dev-${SHORT_SHA}" fi - echo "tags=$TAGS" >> $GITHUB_OUTPUT + echo "tags=$TAGS" >> "$GITHUB_OUTPUT" - name: Build Docker image id: build-artifacts @@ -247,12 +243,10 @@ jobs: linux/arm64 push: true tags: ${{ steps.image-tags.outputs.tags }} - secrets: | - "GH_PAT=${{ steps.retrieve-secret-pat.outputs.github-pat-bitwarden-devops-bot-repo-scope }}" - name: Install Cosign if: github.event_name != 'pull_request' && github.ref == 'refs/heads/main' - uses: sigstore/cosign-installer@3454372f43399081ed03b604cb2d021dabca52bb # v3.8.2 + uses: sigstore/cosign-installer@7e8b541eb2e61bf99390e1afd4be13a184e9ebc5 # v3.10.1 - name: Sign image with Cosign if: github.event_name != 'pull_request' && github.ref == 'refs/heads/main' @@ -260,23 +254,24 @@ jobs: DIGEST: ${{ steps.build-artifacts.outputs.digest }} TAGS: ${{ steps.image-tags.outputs.tags }} run: | - IFS="," read -a tags <<< "${TAGS}" - images="" - for tag in "${tags[@]}"; do - images+="${tag}@${DIGEST} " + IFS=',' read -r -a tags_array <<< "${TAGS}" + images=() + for tag in "${tags_array[@]}"; do + images+=("${tag}@${DIGEST}") done - cosign sign --yes ${images} + cosign sign --yes ${images[@]} + echo "images=${images[*]}" >> "$GITHUB_OUTPUT" - name: Scan Docker image id: container-scan - uses: anchore/scan-action@2c901ab7378897c01b8efaa2d0c9bf519cc64b9e # v6.2.0 + uses: anchore/scan-action@3c9a191a0fbab285ca6b8530b5de5a642cba332f # v7.2.2 with: image: ${{ steps.image-tags.outputs.primary_tag }} fail-build: false output-format: sarif - name: Upload Grype results to GitHub - uses: github/codeql-action/upload-sarif@dd746615b3b9d728a6a37ca2045b68ca76d4841a # v3.28.8 + uses: github/codeql-action/upload-sarif@e12f0178983d466f2f6028f5cc7a6d786fd97f4b # v4.31.4 with: sarif_file: ${{ steps.container-scan.outputs.sarif }} sha: ${{ contains(github.event_name, 'pull_request') && github.event.pull_request.head.sha || github.sha }} @@ -287,19 +282,20 @@ jobs: upload: name: Upload - runs-on: ubuntu-24.04 + runs-on: ubuntu-22.04 needs: build-artifacts permissions: id-token: write actions: read steps: - name: Check out repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: ref: ${{ github.event.pull_request.head.sha }} + persist-credentials: false - name: Set up .NET - uses: actions/setup-dotnet@67a3573c9a986a3f9c594539f4ab511d57bb3ce9 # v4.3.1 + uses: actions/setup-dotnet@d4c94342e560b34958eacfc5d055d21461ed1c5d # v5.0.0 - name: Log in to Azure uses: bitwarden/gh-actions/azure-login@main @@ -309,7 +305,7 @@ jobs: client_id: ${{ secrets.AZURE_CLIENT_ID }} - name: Log in to ACR - production subscription - run: az acr login -n $_AZ_REGISTRY --only-show-errors + run: az acr login -n "$_AZ_REGISTRY" --only-show-errors - name: Make Docker stubs if: | @@ -332,26 +328,26 @@ jobs: STUB_OUTPUT=$(pwd)/docker-stub # Run setup - docker run -i --rm --name setup -v $STUB_OUTPUT/US:/bitwarden $SETUP_IMAGE \ + docker run -i --rm --name setup -v "$STUB_OUTPUT/US:/bitwarden" "$SETUP_IMAGE" \ /app/Setup -stub 1 -install 1 -domain bitwarden.example.com -os lin -cloud-region US - docker run -i --rm --name setup -v $STUB_OUTPUT/EU:/bitwarden $SETUP_IMAGE \ + docker run -i --rm --name setup -v "$STUB_OUTPUT/EU:/bitwarden" "$SETUP_IMAGE" \ /app/Setup -stub 1 -install 1 -domain bitwarden.example.com -os lin -cloud-region EU - sudo chown -R $(whoami):$(whoami) $STUB_OUTPUT + sudo chown -R "$(whoami):$(whoami)" "$STUB_OUTPUT" # Remove extra directories and files - rm -rf $STUB_OUTPUT/US/letsencrypt - rm -rf $STUB_OUTPUT/EU/letsencrypt - rm $STUB_OUTPUT/US/env/uid.env $STUB_OUTPUT/US/config.yml - rm $STUB_OUTPUT/EU/env/uid.env $STUB_OUTPUT/EU/config.yml + rm -rf "$STUB_OUTPUT/US/letsencrypt" + rm -rf "$STUB_OUTPUT/EU/letsencrypt" + rm "$STUB_OUTPUT/US/env/uid.env" "$STUB_OUTPUT/US/config.yml" + rm "$STUB_OUTPUT/EU/env/uid.env" "$STUB_OUTPUT/EU/config.yml" # Create uid environment files - touch $STUB_OUTPUT/US/env/uid.env - touch $STUB_OUTPUT/EU/env/uid.env + touch "$STUB_OUTPUT/US/env/uid.env" + touch "$STUB_OUTPUT/EU/env/uid.env" # Zip up the Docker stub files - cd docker-stub/US; zip -r ../../docker-stub-US.zip *; cd ../.. - cd docker-stub/EU; zip -r ../../docker-stub-EU.zip *; cd ../.. + cd docker-stub/US; zip -r ../../docker-stub-US.zip ./*; cd ../.. + cd docker-stub/EU; zip -r ../../docker-stub-EU.zip ./*; cd ../.. - name: Log out from Azure uses: bitwarden/gh-actions/azure-logout@main @@ -360,7 +356,7 @@ jobs: if: | github.event_name != 'pull_request' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/rc' || github.ref == 'refs/heads/hotfix-rc') - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: docker-stub-US.zip path: docker-stub-US.zip @@ -370,7 +366,7 @@ jobs: if: | github.event_name != 'pull_request' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/rc' || github.ref == 'refs/heads/hotfix-rc') - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: docker-stub-EU.zip path: docker-stub-EU.zip @@ -382,21 +378,21 @@ jobs: pwsh ./generate_openapi_files.ps1 - name: Upload Public API Swagger artifact - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: swagger.json path: api.public.json if-no-files-found: error - name: Upload Internal API Swagger artifact - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: internal.json path: api.json if-no-files-found: error - name: Upload Identity Swagger artifact - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: identity.json path: identity.json @@ -404,7 +400,7 @@ jobs: build-mssqlmigratorutility: name: Build MSSQL migrator utility - runs-on: ubuntu-24.04 + runs-on: ubuntu-22.04 needs: - lint defaults: @@ -420,12 +416,13 @@ jobs: - win-x64 steps: - name: Check out repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: ref: ${{ github.event.pull_request.head.sha }} + persist-credentials: false - name: Set up .NET - uses: actions/setup-dotnet@67a3573c9a986a3f9c594539f4ab511d57bb3ce9 # v4.3.1 + uses: actions/setup-dotnet@d4c94342e560b34958eacfc5d055d21461ed1c5d # v5.0.0 - name: Print environment run: | @@ -441,7 +438,7 @@ jobs: - name: Upload project artifact for Windows if: ${{ contains(matrix.target, 'win') == true }} - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: MsSqlMigratorUtility-${{ matrix.target }} path: util/MsSqlMigratorUtility/obj/build-output/publish/MsSqlMigratorUtility.exe @@ -449,7 +446,7 @@ jobs: - name: Upload project artifact if: ${{ contains(matrix.target, 'win') == false }} - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: MsSqlMigratorUtility-${{ matrix.target }} path: util/MsSqlMigratorUtility/obj/build-output/publish/MsSqlMigratorUtility @@ -460,7 +457,7 @@ jobs: if: | github.event_name != 'pull_request' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/rc' || github.ref == 'refs/heads/hotfix-rc') - runs-on: ubuntu-24.04 + runs-on: ubuntu-22.04 needs: - build-artifacts permissions: @@ -473,25 +470,34 @@ jobs: tenant_id: ${{ secrets.AZURE_TENANT_ID }} client_id: ${{ secrets.AZURE_CLIENT_ID }} - - name: Retrieve GitHub PAT secrets - id: retrieve-secret-pat + - name: Get Azure Key Vault secrets + id: get-kv-secrets uses: bitwarden/gh-actions/get-keyvault-secrets@main with: - keyvault: "bitwarden-ci" - secrets: "github-pat-bitwarden-devops-bot-repo-scope" + keyvault: gh-org-bitwarden + secrets: "BW-GHAPP-ID,BW-GHAPP-KEY" - name: Log out from Azure uses: bitwarden/gh-actions/azure-logout@main - - name: Trigger self-host build + - name: Generate GH App token + uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2.2.1 + id: app-token + with: + app-id: ${{ steps.get-kv-secrets.outputs.BW-GHAPP-ID }} + private-key: ${{ steps.get-kv-secrets.outputs.BW-GHAPP-KEY }} + owner: ${{ github.repository_owner }} + repositories: self-host + + - name: Trigger Bitwarden lite build uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: - github-token: ${{ steps.retrieve-secret-pat.outputs.github-pat-bitwarden-devops-bot-repo-scope }} + github-token: ${{ steps.app-token.outputs.token }} script: | await github.rest.actions.createWorkflowDispatch({ owner: 'bitwarden', repo: 'self-host', - workflow_id: 'build-unified.yml', + workflow_id: 'build-bitwarden-lite.yml', ref: 'main', inputs: { server_branch: process.env.GITHUB_REF @@ -514,20 +520,29 @@ jobs: tenant_id: ${{ secrets.AZURE_TENANT_ID }} client_id: ${{ secrets.AZURE_CLIENT_ID }} - - name: Retrieve GitHub PAT secrets - id: retrieve-secret-pat + - name: Get Azure Key Vault secrets + id: get-kv-secrets uses: bitwarden/gh-actions/get-keyvault-secrets@main with: - keyvault: "bitwarden-ci" - secrets: "github-pat-bitwarden-devops-bot-repo-scope" + keyvault: gh-org-bitwarden + secrets: "BW-GHAPP-ID,BW-GHAPP-KEY" - name: Log out from Azure uses: bitwarden/gh-actions/azure-logout@main + - name: Generate GH App token + uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2.2.1 + id: app-token + with: + app-id: ${{ steps.get-kv-secrets.outputs.BW-GHAPP-ID }} + private-key: ${{ steps.get-kv-secrets.outputs.BW-GHAPP-KEY }} + owner: ${{ github.repository_owner }} + repositories: devops + - name: Trigger k8s deploy uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: - github-token: ${{ steps.retrieve-secret-pat.outputs.github-pat-bitwarden-devops-bot-repo-scope }} + github-token: ${{ steps.app-token.outputs.token }} script: | await github.rest.actions.createWorkflowDispatch({ owner: 'bitwarden', diff --git a/.github/workflows/cleanup-after-pr.yml b/.github/workflows/cleanup-after-pr.yml index e39bf8ea3a..4e59f1fa96 100644 --- a/.github/workflows/cleanup-after-pr.yml +++ b/.github/workflows/cleanup-after-pr.yml @@ -22,7 +22,7 @@ jobs: client_id: ${{ secrets.AZURE_CLIENT_ID }} - name: Log in to Azure ACR - run: az acr login -n $_AZ_REGISTRY --only-show-errors + run: az acr login -n "$_AZ_REGISTRY" --only-show-errors ########## Remove Docker images ########## - name: Remove the Docker image from ACR @@ -45,20 +45,20 @@ jobs: - Setup - Sso run: | - for SERVICE in $(echo "${{ env.SERVICES }}" | yq e ".services[]" - ) + for SERVICE in $(echo "${SERVICES}" | yq e ".services[]" - ) do - SERVICE_NAME=$(echo $SERVICE | awk '{print tolower($0)}') + SERVICE_NAME=$(echo "$SERVICE" | awk '{print tolower($0)}') IMAGE_TAG=$(echo "${REF}" | sed "s#/#-#g") # slash safe branch name echo "[*] Checking if remote exists: $_AZ_REGISTRY/$SERVICE_NAME:$IMAGE_TAG" TAG_EXISTS=$( - az acr repository show-tags --name $_AZ_REGISTRY --repository $SERVICE_NAME \ - | jq --arg $TAG "$IMAGE_TAG" -e '. | any(. == "$TAG")' + az acr repository show-tags --name "$_AZ_REGISTRY" --repository "$SERVICE_NAME" \ + | jq --arg TAG "$IMAGE_TAG" -e '. | any(. == $TAG)' ) if [[ "$TAG_EXISTS" == "true" ]]; then echo "[*] Tag exists. Removing tag" - az acr repository delete --name $_AZ_REGISTRY --image $SERVICE_NAME:$IMAGE_TAG --yes + az acr repository delete --name "$_AZ_REGISTRY" --image "$SERVICE_NAME:$IMAGE_TAG" --yes else echo "[*] Tag does not exist. No action needed" fi diff --git a/.github/workflows/cleanup-rc-branch.yml b/.github/workflows/cleanup-rc-branch.yml index 5c74284423..ae482ef4e6 100644 --- a/.github/workflows/cleanup-rc-branch.yml +++ b/.github/workflows/cleanup-rc-branch.yml @@ -31,10 +31,12 @@ jobs: uses: bitwarden/gh-actions/azure-logout@main - name: Checkout main - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: ref: main token: ${{ steps.retrieve-bot-secrets.outputs.github-pat-bitwarden-devops-bot-repo-scope }} + persist-credentials: false + fetch-depth: 0 - name: Check if a RC branch exists id: branch-check @@ -43,11 +45,11 @@ jobs: rc_branch_check=$(git ls-remote --heads origin rc | wc -l) if [[ "${hotfix_rc_branch_check}" -gt 0 ]]; then - echo "hotfix-rc branch exists." | tee -a $GITHUB_STEP_SUMMARY - echo "name=hotfix-rc" >> $GITHUB_OUTPUT + echo "hotfix-rc branch exists." | tee -a "$GITHUB_STEP_SUMMARY" + echo "name=hotfix-rc" >> "$GITHUB_OUTPUT" elif [[ "${rc_branch_check}" -gt 0 ]]; then - echo "rc branch exists." | tee -a $GITHUB_STEP_SUMMARY - echo "name=rc" >> $GITHUB_OUTPUT + echo "rc branch exists." | tee -a "$GITHUB_STEP_SUMMARY" + echo "name=rc" >> "$GITHUB_OUTPUT" fi - name: Delete RC branch @@ -55,6 +57,6 @@ jobs: BRANCH_NAME: ${{ steps.branch-check.outputs.name }} run: | if ! [[ -z "$BRANCH_NAME" ]]; then - git push --quiet origin --delete $BRANCH_NAME - echo "Deleted $BRANCH_NAME branch." | tee -a $GITHUB_STEP_SUMMARY + git push --quiet origin --delete "$BRANCH_NAME" + echo "Deleted $BRANCH_NAME branch." | tee -a "$GITHUB_STEP_SUMMARY" fi diff --git a/.github/workflows/code-references.yml b/.github/workflows/code-references.yml index 75e0c43306..cb7ca9e200 100644 --- a/.github/workflows/code-references.yml +++ b/.github/workflows/code-references.yml @@ -19,9 +19,9 @@ jobs: id: check-secret-access run: | if [ "${{ secrets.AZURE_CLIENT_ID }}" != '' ]; then - echo "available=true" >> $GITHUB_OUTPUT; + echo "available=true" >> "$GITHUB_OUTPUT"; else - echo "available=false" >> $GITHUB_OUTPUT; + echo "available=false" >> "$GITHUB_OUTPUT"; fi refs: @@ -36,7 +36,9 @@ jobs: steps: - name: Check out repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + persist-credentials: false - name: Log in to Azure uses: bitwarden/gh-actions/azure-login@main @@ -57,7 +59,7 @@ jobs: - name: Collect id: collect - uses: launchdarkly/find-code-references@e3e9da201b87ada54eb4c550c14fb783385c5c8a # v2.13.0 + uses: launchdarkly/find-code-references@89a7d362d1d4b3725fe0fe0ccd0dc69e3bdcba58 # v2.14.0 with: accessToken: ${{ steps.get-kv-secrets.outputs.LD-ACCESS-TOKEN }} projKey: default @@ -65,14 +67,14 @@ jobs: - name: Add label if: steps.collect.outputs.any-changed == 'true' - run: gh pr edit $PR_NUMBER --add-label feature-flag + run: gh pr edit "$PR_NUMBER" --add-label feature-flag env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} PR_NUMBER: ${{ github.event.pull_request.number }} - name: Remove label if: steps.collect.outputs.any-changed == 'false' - run: gh pr edit $PR_NUMBER --remove-label feature-flag + run: gh pr edit "$PR_NUMBER" --remove-label feature-flag env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} PR_NUMBER: ${{ github.event.pull_request.number }} diff --git a/.github/workflows/enforce-labels.yml b/.github/workflows/enforce-labels.yml index 353127c751..1759b29787 100644 --- a/.github/workflows/enforce-labels.yml +++ b/.github/workflows/enforce-labels.yml @@ -17,5 +17,5 @@ jobs: - name: Check for label run: | echo "PRs with the hold, needs-qa or ephemeral-environment labels cannot be merged" - echo "### :x: PRs with the hold, needs-qa or ephemeral-environment labels cannot be merged" >> $GITHUB_STEP_SUMMARY + echo "### :x: PRs with the hold, needs-qa or ephemeral-environment labels cannot be merged" >> "$GITHUB_STEP_SUMMARY" exit 1 diff --git a/.github/workflows/ephemeral-environment.yml b/.github/workflows/ephemeral-environment.yml index d85fcf2fd4..456ca573cc 100644 --- a/.github/workflows/ephemeral-environment.yml +++ b/.github/workflows/ephemeral-environment.yml @@ -16,5 +16,5 @@ jobs: with: project: server pull_request_number: ${{ github.event.number }} - sync_environment: true + sync_environment: false secrets: inherit diff --git a/.github/workflows/load-test.yml b/.github/workflows/load-test.yml index 9bc6da89e7..10bfe50d10 100644 --- a/.github/workflows/load-test.yml +++ b/.github/workflows/load-test.yml @@ -63,13 +63,15 @@ jobs: # Datadog agent for collecting OTEL metrics from k6 - name: Start Datadog agent + env: + DD_API_KEY: ${{ steps.get-kv-secrets.outputs.DD-API-KEY }} run: | docker run --detach \ --name datadog-agent \ -p 4317:4317 \ -p 5555:5555 \ -e DD_SITE=us3.datadoghq.com \ - -e DD_API_KEY=${{ steps.get-kv-secrets.outputs.DD-API-KEY }} \ + -e DD_API_KEY="${DD_API_KEY}" \ -e DD_DOGSTATSD_NON_LOCAL_TRAFFIC=1 \ -e DD_OTLP_CONFIG_RECEIVER_PROTOCOLS_GRPC_ENDPOINT=0.0.0.0:4317 \ -e DD_HEALTH_PORT=5555 \ @@ -85,7 +87,7 @@ jobs: datadog/agent:7-full@sha256:7ea933dec3b8baa8c19683b1c3f6f801dbf3291f748d9ed59234accdaac4e479 - name: Check out repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: persist-credentials: false @@ -93,7 +95,7 @@ jobs: uses: grafana/setup-k6-action@ffe7d7290dfa715e48c2ccc924d068444c94bde2 # v1.1.0 - name: Run k6 tests - uses: grafana/run-k6-action@c6b79182b9b666aa4f630f4a6be9158ead62536e # v1.2.0 + uses: grafana/run-k6-action@a15e2072ede004e8d46141e33d7f7dad8ad08d9d # v1.3.1 continue-on-error: false env: K6_OTEL_METRIC_PREFIX: k6_ diff --git a/.github/workflows/protect-files.yml b/.github/workflows/protect-files.yml index 546b8344a6..4b137eb221 100644 --- a/.github/workflows/protect-files.yml +++ b/.github/workflows/protect-files.yml @@ -31,9 +31,10 @@ jobs: label: "DB-migrations-changed" steps: - name: Check out repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: fetch-depth: 2 + persist-credentials: false - name: Check for file changes id: check-changes @@ -43,9 +44,9 @@ jobs: for file in $MODIFIED_FILES do if [[ $file == *"${{ matrix.path }}"* ]]; then - echo "changes_detected=true" >> $GITHUB_OUTPUT + echo "changes_detected=true" >> "$GITHUB_OUTPUT" break - else echo "changes_detected=false" >> $GITHUB_OUTPUT + else echo "changes_detected=false" >> "$GITHUB_OUTPUT" fi done diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 444c2289d1..7983bef2bc 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -36,21 +36,23 @@ jobs: steps: - name: Version output id: version-output + env: + INPUT_VERSION: ${{ inputs.version }} run: | - if [[ "${{ inputs.version }}" == "latest" || "${{ inputs.version }}" == "" ]]; then + if [[ "${INPUT_VERSION}" == "latest" || "${INPUT_VERSION}" == "" ]]; then VERSION=$(curl "https://api.github.com/repos/bitwarden/server/releases" | jq -c '.[] | select(.tag_name) | .tag_name' | head -1 | grep -ohE '20[0-9]{2}\.([1-9]|1[0-2])\.[0-9]+') echo "Latest Released Version: $VERSION" - echo "version=$VERSION" >> $GITHUB_OUTPUT + echo "version=$VERSION" >> "$GITHUB_OUTPUT" else - echo "Release Version: ${{ inputs.version }}" - echo "version=${{ inputs.version }}" >> $GITHUB_OUTPUT + echo "Release Version: ${INPUT_VERSION}" + echo "version=${INPUT_VERSION}" >> "$GITHUB_OUTPUT" fi - name: Get branch name id: branch run: | - BRANCH_NAME=$(basename ${{ github.ref }}) - echo "branch-name=$BRANCH_NAME" >> $GITHUB_OUTPUT + BRANCH_NAME=$(basename "${GITHUB_REF}") + echo "branch-name=$BRANCH_NAME" >> "$GITHUB_OUTPUT" - name: Create GitHub deployment uses: chrnorm/deployment-action@55729fcebec3d284f60f5bcabbd8376437d696b1 # v2.0.7 @@ -89,7 +91,6 @@ jobs: - project_name: Nginx - project_name: Notifications - project_name: Scim - - project_name: Server - project_name: Setup - project_name: Sso steps: @@ -104,7 +105,10 @@ jobs: echo "Github Release Option: $RELEASE_OPTION" - name: Check out repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + fetch-depth: 0 + persist-credentials: false - name: Set up project name id: setup @@ -112,7 +116,7 @@ jobs: PROJECT_NAME=$(echo "${{ matrix.project_name }}" | awk '{print tolower($0)}') echo "Matrix name: ${{ matrix.project_name }}" echo "PROJECT_NAME: $PROJECT_NAME" - echo "project_name=$PROJECT_NAME" >> $GITHUB_OUTPUT + echo "project_name=$PROJECT_NAME" >> "$GITHUB_OUTPUT" ########## ACR PROD ########## - name: Log in to Azure @@ -123,16 +127,16 @@ jobs: client_id: ${{ secrets.AZURE_CLIENT_ID }} - name: Log in to Azure ACR - run: az acr login -n $_AZ_REGISTRY --only-show-errors + run: az acr login -n "$_AZ_REGISTRY" --only-show-errors - name: Pull latest project image env: PROJECT_NAME: ${{ steps.setup.outputs.project_name }} run: | if [[ "${{ inputs.publish_type }}" == "Dry Run" ]]; then - docker pull $_AZ_REGISTRY/$PROJECT_NAME:latest + docker pull "$_AZ_REGISTRY/$PROJECT_NAME:latest" else - docker pull $_AZ_REGISTRY/$PROJECT_NAME:$_BRANCH_NAME + docker pull "$_AZ_REGISTRY/$PROJECT_NAME:$_BRANCH_NAME" fi - name: Tag version and latest @@ -140,10 +144,10 @@ jobs: PROJECT_NAME: ${{ steps.setup.outputs.project_name }} run: | if [[ "${{ inputs.publish_type }}" == "Dry Run" ]]; then - docker tag $_AZ_REGISTRY/$PROJECT_NAME:latest $_AZ_REGISTRY/$PROJECT_NAME:dryrun + docker tag "$_AZ_REGISTRY/$PROJECT_NAME:latest" "$_AZ_REGISTRY/$PROJECT_NAME:dryrun" else - docker tag $_AZ_REGISTRY/$PROJECT_NAME:$_BRANCH_NAME $_AZ_REGISTRY/$PROJECT_NAME:$_RELEASE_VERSION - docker tag $_AZ_REGISTRY/$PROJECT_NAME:$_BRANCH_NAME $_AZ_REGISTRY/$PROJECT_NAME:latest + docker tag "$_AZ_REGISTRY/$PROJECT_NAME:$_BRANCH_NAME" "$_AZ_REGISTRY/$PROJECT_NAME:$_RELEASE_VERSION" + docker tag "$_AZ_REGISTRY/$PROJECT_NAME:$_BRANCH_NAME" "$_AZ_REGISTRY/$PROJECT_NAME:latest" fi - name: Push version and latest image @@ -151,10 +155,10 @@ jobs: PROJECT_NAME: ${{ steps.setup.outputs.project_name }} run: | if [[ "${{ inputs.publish_type }}" == "Dry Run" ]]; then - docker push $_AZ_REGISTRY/$PROJECT_NAME:dryrun + docker push "$_AZ_REGISTRY/$PROJECT_NAME:dryrun" else - docker push $_AZ_REGISTRY/$PROJECT_NAME:$_RELEASE_VERSION - docker push $_AZ_REGISTRY/$PROJECT_NAME:latest + docker push "$_AZ_REGISTRY/$PROJECT_NAME:$_RELEASE_VERSION" + docker push "$_AZ_REGISTRY/$PROJECT_NAME:latest" fi - name: Log out of Docker diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8bb19b4da1..a3c4fb1ffd 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -39,7 +39,10 @@ jobs: fi - name: Check out repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + fetch-depth: 0 + persist-credentials: false - name: Check release version id: version @@ -52,8 +55,8 @@ jobs: - name: Get branch name id: branch run: | - BRANCH_NAME=$(basename ${{ github.ref }}) - echo "branch-name=$BRANCH_NAME" >> $GITHUB_OUTPUT + BRANCH_NAME=$(basename "${GITHUB_REF}") + echo "branch-name=$BRANCH_NAME" >> "$GITHUB_OUTPUT" release: name: Create GitHub release @@ -86,7 +89,7 @@ jobs: - name: Create release if: ${{ inputs.release_type != 'Dry Run' }} - uses: ncipollo/release-action@440c8c1cb0ed28b9f43e4d1d670870f059653174 # v1.16.0 + uses: ncipollo/release-action@b7eabc95ff50cbeeedec83973935c8f306dfcd0b # v1.20.0 with: artifacts: "docker-stub-US.zip, docker-stub-EU.zip, diff --git a/.github/workflows/repository-management.yml b/.github/workflows/repository-management.yml index 67e1d8a926..c98faed340 100644 --- a/.github/workflows/repository-management.yml +++ b/.github/workflows/repository-management.yml @@ -22,9 +22,7 @@ on: required: false type: string -permissions: - pull-requests: write - contents: write +permissions: {} jobs: setup: @@ -32,6 +30,7 @@ jobs: runs-on: ubuntu-24.04 outputs: branch: ${{ steps.set-branch.outputs.branch }} + permissions: {} steps: - name: Set branch id: set-branch @@ -46,7 +45,7 @@ jobs: BRANCH="hotfix-rc" fi - echo "branch=$BRANCH" >> $GITHUB_OUTPUT + echo "branch=$BRANCH" >> "$GITHUB_OUTPUT" bump_version: name: Bump Version @@ -84,17 +83,19 @@ jobs: version: ${{ inputs.version_number_override }} - name: Generate GH App token - uses: actions/create-github-app-token@a8d616148505b5069dccd32f177bb87d7f39123b # v2.1.1 + uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2.2.1 id: app-token with: app-id: ${{ steps.get-kv-secrets.outputs.BW-GHAPP-ID }} private-key: ${{ steps.get-kv-secrets.outputs.BW-GHAPP-KEY }} + permission-contents: write - name: Check out branch - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: ref: main token: ${{ steps.app-token.outputs.token }} + persist-credentials: true - name: Configure Git run: | @@ -110,7 +111,7 @@ jobs: id: current-version run: | CURRENT_VERSION=$(xmllint -xpath "/Project/PropertyGroup/Version/text()" Directory.Build.props) - echo "version=$CURRENT_VERSION" >> $GITHUB_OUTPUT + echo "version=$CURRENT_VERSION" >> "$GITHUB_OUTPUT" - name: Verify input version if: ${{ inputs.version_number_override != '' }} @@ -120,16 +121,15 @@ jobs: run: | # Error if version has not changed. if [[ "$NEW_VERSION" == "$CURRENT_VERSION" ]]; then - echo "Specified override version is the same as the current version." >> $GITHUB_STEP_SUMMARY + echo "Specified override version is the same as the current version." >> "$GITHUB_STEP_SUMMARY" exit 1 fi # Check if version is newer. - printf '%s\n' "${CURRENT_VERSION}" "${NEW_VERSION}" | sort -C -V - if [ $? -eq 0 ]; then + if printf '%s\n' "${CURRENT_VERSION}" "${NEW_VERSION}" | sort -C -V; then echo "Version is newer than the current version." else - echo "Version is older than the current version." >> $GITHUB_STEP_SUMMARY + echo "Version is older than the current version." >> "$GITHUB_STEP_SUMMARY" exit 1 fi @@ -160,15 +160,20 @@ jobs: id: set-final-version-output env: VERSION: ${{ inputs.version_number_override }} + BUMP_VERSION_OVERRIDE_OUTCOME: ${{ steps.bump-version-override.outcome }} + BUMP_VERSION_AUTOMATIC_OUTCOME: ${{ steps.bump-version-automatic.outcome }} + CALCULATE_NEXT_VERSION: ${{ steps.calculate-next-version.outputs.version }} run: | - if [[ "${{ steps.bump-version-override.outcome }}" = "success" ]]; then - echo "version=$VERSION" >> $GITHUB_OUTPUT - elif [[ "${{ steps.bump-version-automatic.outcome }}" = "success" ]]; then - echo "version=${{ steps.calculate-next-version.outputs.version }}" >> $GITHUB_OUTPUT + if [[ "${BUMP_VERSION_OVERRIDE_OUTCOME}" = "success" ]]; then + echo "version=$VERSION" >> "$GITHUB_OUTPUT" + elif [[ "${BUMP_VERSION_AUTOMATIC_OUTCOME}" = "success" ]]; then + echo "version=${CALCULATE_NEXT_VERSION}" >> "$GITHUB_OUTPUT" fi - name: Commit files - run: git commit -m "Bumped version to ${{ steps.set-final-version-output.outputs.version }}" -a + env: + FINAL_VERSION: ${{ steps.set-final-version-output.outputs.version }} + run: git commit -m "Bumped version to $FINAL_VERSION" -a - name: Push changes run: git push @@ -202,24 +207,27 @@ jobs: uses: bitwarden/gh-actions/azure-logout@main - name: Generate GH App token - uses: actions/create-github-app-token@a8d616148505b5069dccd32f177bb87d7f39123b # v2.1.1 + uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2.2.1 id: app-token with: app-id: ${{ steps.get-kv-secrets.outputs.BW-GHAPP-ID }} private-key: ${{ steps.get-kv-secrets.outputs.BW-GHAPP-KEY }} + permission-contents: write - name: Check out target ref - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: ref: ${{ inputs.target_ref }} token: ${{ steps.app-token.outputs.token }} + persist-credentials: true + fetch-depth: 0 - name: Check if ${{ needs.setup.outputs.branch }} branch exists env: BRANCH_NAME: ${{ needs.setup.outputs.branch }} run: | - if [[ $(git ls-remote --heads origin $BRANCH_NAME) ]]; then - echo "$BRANCH_NAME already exists! Please delete $BRANCH_NAME before running again." >> $GITHUB_STEP_SUMMARY + if [[ $(git ls-remote --heads origin "$BRANCH_NAME") ]]; then + echo "$BRANCH_NAME already exists! Please delete $BRANCH_NAME before running again." >> "$GITHUB_STEP_SUMMARY" exit 1 fi @@ -227,16 +235,11 @@ jobs: env: BRANCH_NAME: ${{ needs.setup.outputs.branch }} run: | - git switch --quiet --create $BRANCH_NAME - git push --quiet --set-upstream origin $BRANCH_NAME + git switch --quiet --create "$BRANCH_NAME" + git push --quiet --set-upstream origin "$BRANCH_NAME" move_edd_db_scripts: name: Move EDD database scripts needs: cut_branch - permissions: - actions: read - contents: write - id-token: write - pull-requests: write + permissions: {} uses: ./.github/workflows/_move_edd_db_scripts.yml - secrets: inherit diff --git a/.github/workflows/respond.yml b/.github/workflows/respond.yml new file mode 100644 index 0000000000..d940ceee75 --- /dev/null +++ b/.github/workflows/respond.yml @@ -0,0 +1,28 @@ +name: Respond + +on: + issue_comment: + types: [created] + pull_request_review_comment: + types: [created] + issues: + types: [opened, assigned] + pull_request_review: + types: [submitted] + +permissions: {} + +jobs: + respond: + name: Respond + uses: bitwarden/gh-actions/.github/workflows/_respond.yml@main + secrets: + AZURE_SUBSCRIPTION_ID: ${{ secrets.AZURE_SUBSCRIPTION_ID }} + AZURE_TENANT_ID: ${{ secrets.AZURE_TENANT_ID }} + AZURE_CLIENT_ID: ${{ secrets.AZURE_CLIENT_ID }} + permissions: + actions: read + contents: write + id-token: write + issues: write + pull-requests: write diff --git a/.github/workflows/review-code.yml b/.github/workflows/review-code.yml index b49f5cec8f..908664209d 100644 --- a/.github/workflows/review-code.yml +++ b/.github/workflows/review-code.yml @@ -1,109 +1,21 @@ -name: Review code +name: Code Review on: pull_request: - types: [opened, synchronize, reopened] + types: [opened, labeled] permissions: {} jobs: review: name: Review - runs-on: ubuntu-24.04 + uses: bitwarden/gh-actions/.github/workflows/_review-code.yml@main + secrets: + AZURE_SUBSCRIPTION_ID: ${{ secrets.AZURE_SUBSCRIPTION_ID }} + AZURE_TENANT_ID: ${{ secrets.AZURE_TENANT_ID }} + AZURE_CLIENT_ID: ${{ secrets.AZURE_CLIENT_ID }} permissions: + actions: read contents: read id-token: write pull-requests: write - - steps: - - name: Check out repo - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - with: - fetch-depth: 0 - persist-credentials: false - - - name: Check for Vault team changes - id: check_changes - run: | - # Ensure we have the base branch - git fetch origin ${{ github.base_ref }} - - echo "Comparing changes between origin/${{ github.base_ref }} and HEAD" - CHANGED_FILES=$(git diff --name-only origin/${{ github.base_ref }}...HEAD) - - if [ -z "$CHANGED_FILES" ]; then - echo "Zero files changed" - echo "vault_team_changes=false" >> $GITHUB_OUTPUT - exit 0 - fi - - # Handle variations in spacing and multiple teams - VAULT_PATTERNS=$(grep -E "@bitwarden/team-vault-dev(\s|$)" .github/CODEOWNERS 2>/dev/null | awk '{print $1}') - - if [ -z "$VAULT_PATTERNS" ]; then - echo "⚠️ No patterns found for @bitwarden/team-vault-dev in CODEOWNERS" - echo "vault_team_changes=false" >> $GITHUB_OUTPUT - exit 0 - fi - - vault_team_changes=false - for pattern in $VAULT_PATTERNS; do - echo "Checking pattern: $pattern" - - # Handle **/directory patterns - if [[ "$pattern" == "**/"* ]]; then - # Remove the **/ prefix - dir_pattern="${pattern#\*\*/}" - # Check if any file contains this directory in its path - if echo "$CHANGED_FILES" | grep -qE "(^|/)${dir_pattern}(/|$)"; then - vault_team_changes=true - echo "✅ Found files matching pattern: $pattern" - echo "$CHANGED_FILES" | grep -E "(^|/)${dir_pattern}(/|$)" | sed 's/^/ - /' - break - fi - else - # Handle other patterns (shouldn't happen based on your CODEOWNERS) - if echo "$CHANGED_FILES" | grep -q "$pattern"; then - vault_team_changes=true - echo "✅ Found files matching pattern: $pattern" - echo "$CHANGED_FILES" | grep "$pattern" | sed 's/^/ - /' - break - fi - fi - done - - echo "vault_team_changes=$vault_team_changes" >> $GITHUB_OUTPUT - - if [ "$vault_team_changes" = "true" ]; then - echo "" - echo "✅ Vault team changes detected - proceeding with review" - else - echo "" - echo "❌ No Vault team changes detected - skipping review" - fi - - - name: Review with Claude Code - if: steps.check_changes.outputs.vault_team_changes == 'true' - uses: anthropics/claude-code-action@a5528eec7426a4f0c9c1ac96018daa53ebd05bc4 # v1.0.7 - with: - anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} - track_progress: true - prompt: | - REPO: ${{ github.repository }} - PR NUMBER: ${{ github.event.pull_request.number }} - TITLE: ${{ github.event.pull_request.title }} - BODY: ${{ github.event.pull_request.body }} - AUTHOR: ${{ github.event.pull_request.user.login }} - - Please review this pull request with a focus on: - - Code quality and best practices - - Potential bugs or issues - - Security implications - - Performance considerations - - Note: The PR branch is already checked out in the current working directory. - - Provide detailed feedback using inline comments for specific issues. - - claude_args: | - --allowedTools "mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*)" diff --git a/.github/workflows/stale-bot.yml b/.github/workflows/stale-bot.yml index 83d492645e..c683400a60 100644 --- a/.github/workflows/stale-bot.yml +++ b/.github/workflows/stale-bot.yml @@ -15,7 +15,7 @@ jobs: pull-requests: write steps: - name: Check - uses: actions/stale@5bef64f19d7facfb25b37b414482c7164d639639 # v9.1.0 + uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1 with: stale-issue-label: "needs-reply" stale-pr-label: "needs-changes" diff --git a/.github/workflows/test-database.yml b/.github/workflows/test-database.yml index 6bbc33299f..0fbdb5d069 100644 --- a/.github/workflows/test-database.yml +++ b/.github/workflows/test-database.yml @@ -44,10 +44,12 @@ jobs: checks: write steps: - name: Check out repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + persist-credentials: false - name: Set up .NET - uses: actions/setup-dotnet@67a3573c9a986a3f9c594539f4ab511d57bb3ce9 # v4.3.1 + uses: actions/setup-dotnet@d4c94342e560b34958eacfc5d055d21461ed1c5d # v5.0.0 - name: Restore tools run: dotnet tool restore @@ -60,7 +62,7 @@ jobs: docker compose --profile mssql --profile postgres --profile mysql up -d shell: pwsh - - name: Add MariaDB for unified + - name: Add MariaDB for Bitwarden lite # Use a different port than MySQL run: | docker run --detach --name mariadb --env MARIADB_ROOT_PASSWORD=mariadb-password -p 4306:3306 mariadb:10 @@ -131,7 +133,7 @@ jobs: # Default Sqlite BW_TEST_DATABASES__3__TYPE: "Sqlite" BW_TEST_DATABASES__3__CONNECTIONSTRING: "Data Source=${{ runner.temp }}/test.db" - # Unified MariaDB + # Bitwarden lite MariaDB BW_TEST_DATABASES__4__TYPE: "MySql" BW_TEST_DATABASES__4__CONNECTIONSTRING: "server=localhost;port=4306;uid=root;pwd=mariadb-password;database=vault_dev;Allow User Variables=true" run: dotnet test --logger "trx;LogFileName=infrastructure-test-results.trx" /p:CoverletOutputFormatter="cobertura" --collect:"XPlat Code Coverage" @@ -139,31 +141,31 @@ jobs: - name: Print MySQL Logs if: failure() - run: 'docker logs $(docker ps --quiet --filter "name=mysql")' + run: 'docker logs "$(docker ps --quiet --filter "name=mysql")"' - name: Print MariaDB Logs if: failure() - run: 'docker logs $(docker ps --quiet --filter "name=mariadb")' + run: 'docker logs "$(docker ps --quiet --filter "name=mariadb")"' - name: Print Postgres Logs if: failure() - run: 'docker logs $(docker ps --quiet --filter "name=postgres")' + run: 'docker logs "$(docker ps --quiet --filter "name=postgres")"' - name: Print MSSQL Logs if: failure() - run: 'docker logs $(docker ps --quiet --filter "name=mssql")' + run: 'docker logs "$(docker ps --quiet --filter "name=mssql")"' - name: Report test results - uses: dorny/test-reporter@890a17cecf52a379fc869ab770a71657660be727 # v2.1.0 + uses: dorny/test-reporter@fe45e9537387dac839af0d33ba56eed8e24189e8 # v2.3.0 if: ${{ github.event.pull_request.head.repo.full_name == github.repository && !cancelled() }} with: name: Test Results - path: "**/*-test-results.trx" + path: "./**/*-test-results.trx" reporter: dotnet-trx fail-on-error: true - name: Upload to codecov.io - uses: codecov/codecov-action@18283e04ce6e62d37312384ff67231eb8fd56d24 # v5.4.3 + uses: codecov/codecov-action@671740ac38dd9b0130fbe1cec585b89eea48d3de # v5.5.2 - name: Docker Compose down if: always() @@ -176,10 +178,12 @@ jobs: runs-on: ubuntu-22.04 steps: - name: Check out repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + persist-credentials: false - name: Set up .NET - uses: actions/setup-dotnet@67a3573c9a986a3f9c594539f4ab511d57bb3ce9 # v4.3.1 + uses: actions/setup-dotnet@d4c94342e560b34958eacfc5d055d21461ed1c5d # v5.0.0 - name: Print environment run: | @@ -193,7 +197,7 @@ jobs: shell: pwsh - name: Upload DACPAC - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: sql.dacpac path: Sql.dacpac @@ -219,7 +223,7 @@ jobs: shell: pwsh - name: Report validation results - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 with: name: report.xml path: | @@ -258,3 +262,26 @@ jobs: working-directory: "dev" run: docker compose down shell: pwsh + + validate-migration-naming: + name: Validate new migration naming and order + runs-on: ubuntu-22.04 + + steps: + - name: Check out repo + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + fetch-depth: 0 + persist-credentials: false + + - name: Validate new migrations for pull request + if: github.event_name == 'pull_request' + run: | + git fetch origin main:main + pwsh dev/verify_migrations.ps1 -BaseRef main + shell: pwsh + + - name: Validate new migrations for push + if: github.event_name == 'push' || github.event_name == 'workflow_dispatch' + run: pwsh dev/verify_migrations.ps1 -BaseRef HEAD~1 + shell: pwsh diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4eed6df7ab..550d943dbc 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -27,10 +27,20 @@ jobs: steps: - name: Check out repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + persist-credentials: false - name: Set up .NET - uses: actions/setup-dotnet@67a3573c9a986a3f9c594539f4ab511d57bb3ce9 # v4.3.1 + uses: actions/setup-dotnet@d4c94342e560b34958eacfc5d055d21461ed1c5d # v5.0.0 + + - name: Install rust + uses: dtolnay/rust-toolchain@b3b07ba8b418998c39fb20f53e8b695cdcc8de1b # stable + with: + toolchain: stable + + - name: Cache cargo registry + uses: Swatinem/rust-cache@779680da715d629ac1d338a641029a2f4372abb5 # v2.8.2 - name: Print environment run: | @@ -49,7 +59,7 @@ jobs: run: dotnet test ./bitwarden_license/test --configuration Debug --logger "trx;LogFileName=bw-test-results.trx" /p:CoverletOutputFormatter="cobertura" --collect:"XPlat Code Coverage" - name: Report test results - uses: dorny/test-reporter@890a17cecf52a379fc869ab770a71657660be727 # v2.1.0 + uses: dorny/test-reporter@fe45e9537387dac839af0d33ba56eed8e24189e8 # v2.3.0 if: ${{ github.event.pull_request.head.repo.full_name == github.repository && !cancelled() }} with: name: Test Results @@ -58,4 +68,4 @@ jobs: fail-on-error: true - name: Upload to codecov.io - uses: codecov/codecov-action@18283e04ce6e62d37312384ff67231eb8fd56d24 # v5.4.3 + uses: codecov/codecov-action@671740ac38dd9b0130fbe1cec585b89eea48d3de # v5.5.2 diff --git a/.gitignore b/.gitignore index 3b1f40e673..db8cb50f84 100644 --- a/.gitignore +++ b/.gitignore @@ -215,6 +215,9 @@ bitwarden_license/src/Sso/wwwroot/assets **/**.swp .mono src/Core/MailTemplates/Mjml/out +src/Core/MailTemplates/Mjml/out-hbs +NativeMethods.g.cs +util/RustSdk/rust/target src/Admin/Admin.zip src/Api/Api.zip @@ -231,3 +234,7 @@ bitwarden_license/src/Sso/Sso.zip /identity.json /api.json /api.public.json +.serena/ + +# Serena +.serena/ diff --git a/Directory.Build.props b/Directory.Build.props index 76f35e297e..db3ccf40f5 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -3,12 +3,10 @@ net8.0 - 2025.10.0 + 2025.12.2 Bit.$(MSBuildProjectName) enable - false - true annotations enable @@ -18,7 +16,7 @@ - 17.8.0 + 18.0.1 2.6.6 @@ -32,19 +30,4 @@ 4.18.1 - - - - - - - - - - - <_Parameter1>GitHash - <_Parameter2>$(SourceRevisionId) - - - \ No newline at end of file diff --git a/README.md b/README.md index c817931c67..6aa609bc8c 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,42 @@ Invoke-RestMethod -OutFile bitwarden.ps1 ` .\bitwarden.ps1 -start ``` +## Production Container Images + +
+View Current Production Image Hashes (click to expand) +
+ +### US Production Cluster + +| Service | Image Hash | +|---------|------------| +| **Admin** | ![admin](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fraw.githubusercontent.com%2Fbitwarden%2Fserver%2Frefs%2Fheads%2Fmetadata%2Fbadges%2Fshieldsio-badge-us.json&query=%24.admin&style=flat-square&logo=docker&logoColor=white&label=&color=2496ED) | +| **API** | ![api](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fraw.githubusercontent.com%2Fbitwarden%2Fserver%2Frefs%2Fheads%2Fmetadata%2Fbadges%2Fshieldsio-badge-us.json&query=%24.api&style=flat-square&logo=docker&logoColor=white&label=&color=2496ED) | +| **Billing** | ![billing](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fraw.githubusercontent.com%2Fbitwarden%2Fserver%2Frefs%2Fheads%2Fmetadata%2Fbadges%2Fshieldsio-badge-us.json&query=%24.billing&style=flat-square&logo=docker&logoColor=white&label=&color=2496ED) | +| **Events** | ![events](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fraw.githubusercontent.com%2Fbitwarden%2Fserver%2Frefs%2Fheads%2Fmetadata%2Fbadges%2Fshieldsio-badge-us.json&query=%24.events&style=flat-square&logo=docker&logoColor=white&label=&color=2496ED) | +| **EventsProcessor** | ![eventsprocessor](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fraw.githubusercontent.com%2Fbitwarden%2Fserver%2Frefs%2Fheads%2Fmetadata%2Fbadges%2Fshieldsio-badge-us.json&query=%24.eventsprocessor&style=flat-square&logo=docker&logoColor=white&label=&color=2496ED) | +| **Identity** | ![identity](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fraw.githubusercontent.com%2Fbitwarden%2Fserver%2Frefs%2Fheads%2Fmetadata%2Fbadges%2Fshieldsio-badge-us.json&query=%24.identity&style=flat-square&logo=docker&logoColor=white&label=&color=2496ED) | +| **Notifications** | ![notifications](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fraw.githubusercontent.com%2Fbitwarden%2Fserver%2Frefs%2Fheads%2Fmetadata%2Fbadges%2Fshieldsio-badge-us.json&query=%24.notifications&style=flat-square&logo=docker&logoColor=white&label=&color=2496ED) | +| **SCIM** | ![scim](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fraw.githubusercontent.com%2Fbitwarden%2Fserver%2Frefs%2Fheads%2Fmetadata%2Fbadges%2Fshieldsio-badge-us.json&query=%24.scim&style=flat-square&logo=docker&logoColor=white&label=&color=2496ED) | +| **SSO** | ![sso](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fraw.githubusercontent.com%2Fbitwarden%2Fserver%2Frefs%2Fheads%2Fmetadata%2Fbadges%2Fshieldsio-badge-us.json&query=%24.sso&style=flat-square&logo=docker&logoColor=white&label=&color=2496ED) | + +### EU Production Cluster + +| Service | Image Hash | +|---------|------------| +| **Admin** | ![admin](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fraw.githubusercontent.com%2Fbitwarden%2Fserver%2Frefs%2Fheads%2Fmetadata%2Fbadges%2Fshieldsio-badge-eu.json&query=%24.admin&style=flat-square&logo=docker&logoColor=white&label=&color=2496ED) | +| **API** | ![api](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fraw.githubusercontent.com%2Fbitwarden%2Fserver%2Frefs%2Fheads%2Fmetadata%2Fbadges%2Fshieldsio-badge-eu.json&query=%24.api&style=flat-square&logo=docker&logoColor=white&label=&color=2496ED) | +| **Billing** | ![billing](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fraw.githubusercontent.com%2Fbitwarden%2Fserver%2Frefs%2Fheads%2Fmetadata%2Fbadges%2Fshieldsio-badge-eu.json&query=%24.billing&style=flat-square&logo=docker&logoColor=white&label=&color=2496ED) | +| **Events** | ![events](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fraw.githubusercontent.com%2Fbitwarden%2Fserver%2Frefs%2Fheads%2Fmetadata%2Fbadges%2Fshieldsio-badge-eu.json&query=%24.events&style=flat-square&logo=docker&logoColor=white&label=&color=2496ED) | +| **EventsProcessor** | ![eventsprocessor](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fraw.githubusercontent.com%2Fbitwarden%2Fserver%2Frefs%2Fheads%2Fmetadata%2Fbadges%2Fshieldsio-badge-eu.json&query=%24.eventsprocessor&style=flat-square&logo=docker&logoColor=white&label=&color=2496ED) | +| **Identity** | ![identity](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fraw.githubusercontent.com%2Fbitwarden%2Fserver%2Frefs%2Fheads%2Fmetadata%2Fbadges%2Fshieldsio-badge-eu.json&query=%24.identity&style=flat-square&logo=docker&logoColor=white&label=&color=2496ED) | +| **Notifications** | ![notifications](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fraw.githubusercontent.com%2Fbitwarden%2Fserver%2Frefs%2Fheads%2Fmetadata%2Fbadges%2Fshieldsio-badge-eu.json&query=%24.notifications&style=flat-square&logo=docker&logoColor=white&label=&color=2496ED) | +| **SCIM** | ![scim](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fraw.githubusercontent.com%2Fbitwarden%2Fserver%2Frefs%2Fheads%2Fmetadata%2Fbadges%2Fshieldsio-badge-eu.json&query=%24.scim&style=flat-square&logo=docker&logoColor=white&label=&color=2496ED) | +| **SSO** | ![sso](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fraw.githubusercontent.com%2Fbitwarden%2Fserver%2Frefs%2Fheads%2Fmetadata%2Fbadges%2Fshieldsio-badge-eu.json&query=%24.sso&style=flat-square&logo=docker&logoColor=white&label=&color=2496ED) | + +
+ ## We're Hiring! Interested in contributing in a big way? Consider joining our team! We're hiring for many positions. Please take a look at our [Careers page](https://bitwarden.com/careers/) to see what opportunities are currently open as well as what it's like to work at Bitwarden. diff --git a/bitwarden-server.sln b/bitwarden-server.sln index dbc37372a1..6786ad610c 100644 --- a/bitwarden-server.sln +++ b/bitwarden-server.sln @@ -133,8 +133,11 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Seeder", "util\Seeder\Seede EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "DbSeederUtility", "util\DbSeederUtility\DbSeederUtility.csproj", "{17A89266-260A-4A03-81AE-C0468C6EE06E}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "RustSdk", "util\RustSdk\RustSdk.csproj", "{D1513D90-E4F5-44A9-9121-5E46E3E4A3F7}" Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SharedWeb.Test", "test\SharedWeb.Test\SharedWeb.Test.csproj", "{AD59537D-5259-4B7A-948F-0CF58E80B359}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SSO.Test", "bitwarden_license\test\SSO.Test\SSO.Test.csproj", "{7D98784C-C253-43FB-9873-25B65C6250D6}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -339,10 +342,18 @@ Global {17A89266-260A-4A03-81AE-C0468C6EE06E}.Debug|Any CPU.Build.0 = Debug|Any CPU {17A89266-260A-4A03-81AE-C0468C6EE06E}.Release|Any CPU.ActiveCfg = Release|Any CPU {17A89266-260A-4A03-81AE-C0468C6EE06E}.Release|Any CPU.Build.0 = Release|Any CPU + {D1513D90-E4F5-44A9-9121-5E46E3E4A3F7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {D1513D90-E4F5-44A9-9121-5E46E3E4A3F7}.Debug|Any CPU.Build.0 = Debug|Any CPU + {D1513D90-E4F5-44A9-9121-5E46E3E4A3F7}.Release|Any CPU.ActiveCfg = Release|Any CPU + {D1513D90-E4F5-44A9-9121-5E46E3E4A3F7}.Release|Any CPU.Build.0 = Release|Any CPU {AD59537D-5259-4B7A-948F-0CF58E80B359}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {AD59537D-5259-4B7A-948F-0CF58E80B359}.Debug|Any CPU.Build.0 = Debug|Any CPU {AD59537D-5259-4B7A-948F-0CF58E80B359}.Release|Any CPU.ActiveCfg = Release|Any CPU {AD59537D-5259-4B7A-948F-0CF58E80B359}.Release|Any CPU.Build.0 = Release|Any CPU + {7D98784C-C253-43FB-9873-25B65C6250D6}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {7D98784C-C253-43FB-9873-25B65C6250D6}.Debug|Any CPU.Build.0 = Debug|Any CPU + {7D98784C-C253-43FB-9873-25B65C6250D6}.Release|Any CPU.ActiveCfg = Release|Any CPU + {7D98784C-C253-43FB-9873-25B65C6250D6}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -397,7 +408,9 @@ Global {3631BA42-6731-4118-A917-DAA43C5032B9} = {DD5BD056-4AAE-43EF-BBD2-0B569B8DA84F} {9A612EBA-1C0E-42B8-982B-62F0EE81000A} = {DD5BD056-4AAE-43EF-BBD2-0B569B8DA84E} {17A89266-260A-4A03-81AE-C0468C6EE06E} = {DD5BD056-4AAE-43EF-BBD2-0B569B8DA84E} + {D1513D90-E4F5-44A9-9121-5E46E3E4A3F7} = {DD5BD056-4AAE-43EF-BBD2-0B569B8DA84E} {AD59537D-5259-4B7A-948F-0CF58E80B359} = {DD5BD056-4AAE-43EF-BBD2-0B569B8DA84F} + {7D98784C-C253-43FB-9873-25B65C6250D6} = {287CFF34-BBDB-4BC4-AF88-1E19A5A4679B} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {E01CBF68-2E20-425F-9EDB-E0A6510CA92F} diff --git a/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs b/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs index 9ade2d660a..12d370395c 100644 --- a/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs +++ b/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs @@ -113,7 +113,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv await _providerBillingService.CreateCustomerForClientOrganization(provider, organization); } - var customer = await _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, new CustomerUpdateOptions + var customer = await _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, new CustomerUpdateOptions { Description = string.Empty, Email = organization.BillingEmail, @@ -138,7 +138,7 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; - var subscription = await _stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); + var subscription = await _stripeAdapter.CreateSubscriptionAsync(subscriptionCreateOptions); organization.GatewaySubscriptionId = subscription.Id; organization.Status = OrganizationStatusType.Created; @@ -148,22 +148,29 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv } else if (organization.IsStripeEnabled()) { - var subscription = await _stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId); + var subscription = await _stripeAdapter.GetSubscriptionAsync(organization.GatewaySubscriptionId, new SubscriptionGetOptions + { + Expand = ["customer"] + }); if (subscription.Status is StripeConstants.SubscriptionStatus.Canceled or StripeConstants.SubscriptionStatus.IncompleteExpired) { return; } - await _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, new CustomerUpdateOptions + await _stripeAdapter.UpdateCustomerAsync(subscription.CustomerId, new CustomerUpdateOptions { - Coupon = string.Empty, Email = organization.BillingEmail }); - await _stripeAdapter.SubscriptionUpdateAsync(organization.GatewaySubscriptionId, new SubscriptionUpdateOptions + if (subscription.Customer.Discount?.Coupon != null) + { + await _stripeAdapter.DeleteCustomerDiscountAsync(subscription.CustomerId); + } + + await _stripeAdapter.UpdateSubscriptionAsync(organization.GatewaySubscriptionId, new SubscriptionUpdateOptions { CollectionMethod = StripeConstants.CollectionMethod.SendInvoice, - DaysUntilDue = 30 + DaysUntilDue = 30, }); await _subscriberService.RemovePaymentSource(organization); diff --git a/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs b/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs index aaf0050b63..4e8a23cf4e 100644 --- a/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs +++ b/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs @@ -9,12 +9,16 @@ using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Models.Business.Provider; using Bit.Core.AdminConsole.Models.Business.Tokenables; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Providers.Services; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -35,8 +39,9 @@ public class ProviderService : IProviderService { private static readonly PlanType[] _resellerDisallowedOrganizationTypes = [ PlanType.Free, - PlanType.FamiliesAnnually, - PlanType.FamiliesAnnually2019 + PlanType.FamiliesAnnually2025, + PlanType.FamiliesAnnually2019, + PlanType.FamiliesAnnually ]; private readonly IDataProtector _dataProtector; @@ -58,6 +63,7 @@ public class ProviderService : IProviderService private readonly IProviderBillingService _providerBillingService; private readonly IPricingClient _pricingClient; private readonly IProviderClientOrganizationSignUpCommand _providerClientOrganizationSignUpCommand; + private readonly IPolicyRequirementQuery _policyRequirementQuery; public ProviderService(IProviderRepository providerRepository, IProviderUserRepository providerUserRepository, IProviderOrganizationRepository providerOrganizationRepository, IUserRepository userRepository, @@ -67,7 +73,8 @@ public class ProviderService : IProviderService ICurrentContext currentContext, IStripeAdapter stripeAdapter, IFeatureService featureService, IDataProtectorTokenFactory providerDeleteTokenDataFactory, IApplicationCacheService applicationCacheService, IProviderBillingService providerBillingService, IPricingClient pricingClient, - IProviderClientOrganizationSignUpCommand providerClientOrganizationSignUpCommand) + IProviderClientOrganizationSignUpCommand providerClientOrganizationSignUpCommand, + IPolicyRequirementQuery policyRequirementQuery) { _providerRepository = providerRepository; _providerUserRepository = providerUserRepository; @@ -88,6 +95,7 @@ public class ProviderService : IProviderService _providerBillingService = providerBillingService; _pricingClient = pricingClient; _providerClientOrganizationSignUpCommand = providerClientOrganizationSignUpCommand; + _policyRequirementQuery = policyRequirementQuery; } public async Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TokenizedPaymentMethod paymentMethod, BillingAddress billingAddress) @@ -115,6 +123,18 @@ public class ProviderService : IProviderService throw new BadRequestException("Invalid owner."); } + if (_featureService.IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers)) + { + var organizationAutoConfirmPolicyRequirement = await _policyRequirementQuery + .GetAsync(ownerUserId); + + if (organizationAutoConfirmPolicyRequirement + .CannotCreateProvider()) + { + throw new BadRequestException(new UserCannotJoinProvider().Message); + } + } + var customer = await _providerBillingService.SetupCustomer(provider, paymentMethod, billingAddress); provider.GatewayCustomerId = customer.Id; var subscription = await _providerBillingService.SetupSubscription(provider); @@ -247,6 +267,18 @@ public class ProviderService : IProviderService throw new BadRequestException("User email does not match invite."); } + if (_featureService.IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers)) + { + var organizationAutoConfirmPolicyRequirement = await _policyRequirementQuery + .GetAsync(user.Id); + + if (organizationAutoConfirmPolicyRequirement + .CannotJoinProvider()) + { + throw new BadRequestException(new UserCannotJoinProvider().Message); + } + } + providerUser.Status = ProviderUserStatusType.Accepted; providerUser.UserId = user.Id; providerUser.Email = null; @@ -292,6 +324,19 @@ public class ProviderService : IProviderService throw new BadRequestException("Invalid user."); } + if (_featureService.IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers)) + { + var organizationAutoConfirmPolicyRequirement = await _policyRequirementQuery + .GetAsync(user.Id); + + if (organizationAutoConfirmPolicyRequirement + .CannotJoinProvider()) + { + result.Add(Tuple.Create(providerUser, new UserCannotJoinProvider().Message)); + continue; + } + } + providerUser.Status = ProviderUserStatusType.Confirmed; providerUser.Key = keys[providerUser.Id]; providerUser.Email = null; @@ -426,7 +471,7 @@ public class ProviderService : IProviderService if (!string.IsNullOrEmpty(organization.GatewayCustomerId)) { - await _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, new CustomerUpdateOptions + await _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, new CustomerUpdateOptions { Email = provider.BillingEmail }); @@ -486,7 +531,7 @@ public class ProviderService : IProviderService private async Task GetSubscriptionItemAsync(string subscriptionId, string oldPlanId) { - var subscriptionDetails = await _stripeAdapter.SubscriptionGetAsync(subscriptionId); + var subscriptionDetails = await _stripeAdapter.GetSubscriptionAsync(subscriptionId); return subscriptionDetails.Items.Data.FirstOrDefault(item => item.Price.Id == oldPlanId); } @@ -496,7 +541,7 @@ public class ProviderService : IProviderService { if (subscriptionItem.Price.Id != extractedPlanType) { - await _stripeAdapter.SubscriptionUpdateAsync(subscriptionItem.Subscription, + await _stripeAdapter.UpdateSubscriptionAsync(subscriptionItem.Subscription, new Stripe.SubscriptionUpdateOptions { Items = new List diff --git a/bitwarden_license/src/Commercial.Core/Billing/Providers/Queries/GetProviderWarningsQuery.cs b/bitwarden_license/src/Commercial.Core/Billing/Providers/Queries/GetProviderWarningsQuery.cs index cc77797307..e140a13841 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/Providers/Queries/GetProviderWarningsQuery.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/Providers/Queries/GetProviderWarningsQuery.cs @@ -4,7 +4,6 @@ using Bit.Core.Billing.Providers.Models; using Bit.Core.Billing.Providers.Queries; using Bit.Core.Billing.Services; using Bit.Core.Context; -using Bit.Core.Services; using Stripe; using Stripe.Tax; @@ -76,8 +75,8 @@ public class GetProviderWarningsQuery( // Get active and scheduled registrations var registrations = (await Task.WhenAll( - stripeAdapter.TaxRegistrationsListAsync(new RegistrationListOptions { Status = TaxRegistrationStatus.Active }), - stripeAdapter.TaxRegistrationsListAsync(new RegistrationListOptions { Status = TaxRegistrationStatus.Scheduled }))) + stripeAdapter.ListTaxRegistrationsAsync(new RegistrationListOptions { Status = TaxRegistrationStatus.Active }), + stripeAdapter.ListTaxRegistrationsAsync(new RegistrationListOptions { Status = TaxRegistrationStatus.Scheduled }))) .SelectMany(registrations => registrations.Data); // Find the matching registration for the customer diff --git a/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/BusinessUnitConverter.cs b/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/BusinessUnitConverter.cs index 8e8a89ae58..ce2f7a941f 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/BusinessUnitConverter.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/BusinessUnitConverter.cs @@ -101,7 +101,7 @@ public class BusinessUnitConverter( providerUser.Status = ProviderUserStatusType.Confirmed; // Stripe requires that we clear all the custom fields from the invoice settings if we want to replace them. - await stripeAdapter.CustomerUpdateAsync(subscription.CustomerId, new CustomerUpdateOptions + await stripeAdapter.UpdateCustomerAsync(subscription.CustomerId, new CustomerUpdateOptions { InvoiceSettings = new CustomerInvoiceSettingsOptions { @@ -116,7 +116,7 @@ public class BusinessUnitConverter( ["convertedFrom"] = organization.Id.ToString() }; - var updateCustomer = stripeAdapter.CustomerUpdateAsync(subscription.CustomerId, new CustomerUpdateOptions + var updateCustomer = stripeAdapter.UpdateCustomerAsync(subscription.CustomerId, new CustomerUpdateOptions { InvoiceSettings = new CustomerInvoiceSettingsOptions { @@ -148,7 +148,7 @@ public class BusinessUnitConverter( // Replace the existing password manager price with the new business unit price. var updateSubscription = - stripeAdapter.SubscriptionUpdateAsync(subscription.Id, + stripeAdapter.UpdateSubscriptionAsync(subscription.Id, new SubscriptionUpdateOptions { Items = [ diff --git a/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs b/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs index c9851eb403..7042a531d0 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs @@ -61,11 +61,11 @@ public class ProviderBillingService( Organization organization, string key) { - await stripeAdapter.SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + await stripeAdapter.UpdateSubscriptionAsync(organization.GatewaySubscriptionId, new SubscriptionUpdateOptions { CancelAtPeriodEnd = false }); var subscription = - await stripeAdapter.SubscriptionCancelAsync(organization.GatewaySubscriptionId, + await stripeAdapter.CancelSubscriptionAsync(organization.GatewaySubscriptionId, new SubscriptionCancelOptions { CancellationDetails = new SubscriptionCancellationDetailsOptions @@ -83,7 +83,7 @@ public class ProviderBillingService( if (!wasTrialing && subscription.LatestInvoice.Status == InvoiceStatus.Draft) { - await stripeAdapter.InvoiceFinalizeInvoiceAsync(subscription.LatestInvoiceId, + await stripeAdapter.FinalizeInvoiceAsync(subscription.LatestInvoiceId, new InvoiceFinalizeOptions { AutoAdvance = true }); } @@ -138,7 +138,7 @@ public class ProviderBillingService( if (clientCustomer.Balance != 0) { - await stripeAdapter.CustomerBalanceTransactionCreate(provider.GatewayCustomerId, + await stripeAdapter.CreateCustomerBalanceTransactionAsync(provider.GatewayCustomerId, new CustomerBalanceTransactionCreateOptions { Amount = clientCustomer.Balance, @@ -187,7 +187,7 @@ public class ProviderBillingService( ] }; - await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, updateOptions); + await stripeAdapter.UpdateSubscriptionAsync(provider.GatewaySubscriptionId, updateOptions); // Refactor later to ?ChangeClientPlanCommand? (ProviderPlanId, ProviderId, OrganizationId) // 1. Retrieve PlanType and PlanName for ProviderPlan @@ -275,7 +275,7 @@ public class ProviderBillingService( customerCreateOptions.TaxExempt = TaxExempt.Reverse; } - var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions); + var customer = await stripeAdapter.CreateCustomerAsync(customerCreateOptions); organization.GatewayCustomerId = customer.Id; @@ -481,7 +481,6 @@ public class ProviderBillingService( City = billingAddress.City, State = billingAddress.State }, - Coupon = !string.IsNullOrEmpty(provider.DiscountId) ? provider.DiscountId : null, Description = provider.DisplayBusinessName(), Email = provider.BillingEmail, InvoiceSettings = new CustomerInvoiceSettingsOptions @@ -526,7 +525,7 @@ public class ProviderBillingService( case TokenizablePaymentMethodType.BankAccount: { var setupIntent = - (await stripeAdapter.SetupIntentList(new SetupIntentListOptions + (await stripeAdapter.ListSetupIntentsAsync(new SetupIntentListOptions { PaymentMethod = paymentMethod.Token })) @@ -559,7 +558,7 @@ public class ProviderBillingService( try { - return await stripeAdapter.CustomerCreateAsync(options); + return await stripeAdapter.CreateCustomerAsync(options); } catch (StripeException stripeException) when (stripeException.StripeError?.Code == ErrorCodes.TaxIdInvalid) { @@ -581,7 +580,7 @@ public class ProviderBillingService( case TokenizablePaymentMethodType.BankAccount: { var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(provider.Id); - await stripeAdapter.SetupIntentCancel(setupIntentId, + await stripeAdapter.CancelSetupIntentAsync(setupIntentId, new SetupIntentCancelOptions { CancellationReason = "abandoned" }); await setupIntentCache.RemoveSetupIntentForSubscriber(provider.Id); break; @@ -639,7 +638,7 @@ public class ProviderBillingService( var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(provider.Id); var setupIntent = !string.IsNullOrEmpty(setupIntentId) - ? await stripeAdapter.SetupIntentGet(setupIntentId, + ? await stripeAdapter.GetSetupIntentAsync(setupIntentId, new SetupIntentGetOptions { Expand = ["payment_method"] }) : null; @@ -663,6 +662,7 @@ public class ProviderBillingService( : CollectionMethod.SendInvoice, Customer = customer.Id, DaysUntilDue = usePaymentMethod ? null : 30, + Discounts = !string.IsNullOrEmpty(provider.DiscountId) ? [new SubscriptionDiscountOptions { Coupon = provider.DiscountId }] : null, Items = subscriptionItemOptionsList, Metadata = new Dictionary { { "providerId", provider.Id.ToString() } }, OffSession = true, @@ -671,10 +671,9 @@ public class ProviderBillingService( AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } }; - try { - var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); + var subscription = await stripeAdapter.CreateSubscriptionAsync(subscriptionCreateOptions); if (subscription is { @@ -709,7 +708,7 @@ public class ProviderBillingService( subscriberService.UpdatePaymentSource(provider, tokenizedPaymentSource), subscriberService.UpdateTaxInformation(provider, taxInformation)); - await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + await stripeAdapter.UpdateSubscriptionAsync(provider.GatewaySubscriptionId, new SubscriptionUpdateOptions { CollectionMethod = CollectionMethod.ChargeAutomatically }); } @@ -792,11 +791,49 @@ public class ProviderBillingService( if (subscriptionItemOptionsList.Count > 0) { - await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + await stripeAdapter.UpdateSubscriptionAsync(provider.GatewaySubscriptionId, new SubscriptionUpdateOptions { Items = subscriptionItemOptionsList }); } } + public async Task UpdateProviderNameAndEmail(Provider provider) + { + if (string.IsNullOrWhiteSpace(provider.GatewayCustomerId)) + { + logger.LogWarning( + "Provider ({ProviderId}) has no Stripe customer to update", + provider.Id); + return; + } + + var newDisplayName = provider.DisplayName(); + + // Provider.DisplayName() can return null - handle gracefully + if (string.IsNullOrWhiteSpace(newDisplayName)) + { + logger.LogWarning( + "Provider ({ProviderId}) has no name to update in Stripe", + provider.Id); + return; + } + + await stripeAdapter.UpdateCustomerAsync(provider.GatewayCustomerId, + new CustomerUpdateOptions + { + Email = provider.BillingEmail, + Description = newDisplayName, + InvoiceSettings = new CustomerInvoiceSettingsOptions + { + CustomFields = [ + new CustomerInvoiceSettingsCustomFieldOptions + { + Name = provider.SubscriberType(), + Value = newDisplayName + }] + }, + }); + } + private Func CurrySeatScalingUpdate( Provider provider, ProviderPlan providerPlan, @@ -808,7 +845,7 @@ public class ProviderBillingService( var item = subscription.Items.First(item => item.Price.Id == priceId); - await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, new SubscriptionUpdateOptions + await stripeAdapter.UpdateSubscriptionAsync(provider.GatewaySubscriptionId, new SubscriptionUpdateOptions { Items = [ diff --git a/bitwarden_license/src/Commercial.Infrastructure.EntityFramework/SecretsManager/Repositories/SecretVersionRepository.cs b/bitwarden_license/src/Commercial.Infrastructure.EntityFramework/SecretsManager/Repositories/SecretVersionRepository.cs new file mode 100644 index 0000000000..22421f9921 --- /dev/null +++ b/bitwarden_license/src/Commercial.Infrastructure.EntityFramework/SecretsManager/Repositories/SecretVersionRepository.cs @@ -0,0 +1,94 @@ +using AutoMapper; +using Bit.Core.SecretsManager.Repositories; +using Bit.Infrastructure.EntityFramework.Repositories; +using Bit.Infrastructure.EntityFramework.SecretsManager.Models; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; + +namespace Bit.Commercial.Infrastructure.EntityFramework.SecretsManager.Repositories; + +public class SecretVersionRepository : Repository, ISecretVersionRepository +{ + public SecretVersionRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, db => db.SecretVersion) + { } + + public override async Task GetByIdAsync(Guid id) + { + using var scope = ServiceScopeFactory.CreateScope(); + var dbContext = GetDatabaseContext(scope); + var secretVersion = await dbContext.SecretVersion + .Where(sv => sv.Id == id) + .FirstOrDefaultAsync(); + return Mapper.Map(secretVersion); + } + + public async Task> GetManyBySecretIdAsync(Guid secretId) + { + using var scope = ServiceScopeFactory.CreateScope(); + var dbContext = GetDatabaseContext(scope); + var secretVersions = await dbContext.SecretVersion + .Where(sv => sv.SecretId == secretId) + .OrderByDescending(sv => sv.VersionDate) + .ToListAsync(); + return Mapper.Map>(secretVersions); + } + + public async Task> GetManyByIdsAsync(IEnumerable ids) + { + using var scope = ServiceScopeFactory.CreateScope(); + var dbContext = GetDatabaseContext(scope); + var versionIds = ids.ToList(); + var secretVersions = await dbContext.SecretVersion + .Where(sv => versionIds.Contains(sv.Id)) + .OrderByDescending(sv => sv.VersionDate) + .ToListAsync(); + return Mapper.Map>(secretVersions); + } + + public override async Task CreateAsync(Core.SecretsManager.Entities.SecretVersion secretVersion) + { + const int maxVersionsToKeep = 10; + + await using var scope = ServiceScopeFactory.CreateAsyncScope(); + var dbContext = GetDatabaseContext(scope); + + await using var transaction = await dbContext.Database.BeginTransactionAsync(); + + // Get the IDs of the most recent (maxVersionsToKeep - 1) versions to keep + var versionsToKeepIds = await dbContext.SecretVersion + .Where(sv => sv.SecretId == secretVersion.SecretId) + .OrderByDescending(sv => sv.VersionDate) + .Take(maxVersionsToKeep - 1) + .Select(sv => sv.Id) + .ToListAsync(); + + // Delete all versions for this secret that are not in the "keep" list + if (versionsToKeepIds.Any()) + { + await dbContext.SecretVersion + .Where(sv => sv.SecretId == secretVersion.SecretId && !versionsToKeepIds.Contains(sv.Id)) + .ExecuteDeleteAsync(); + } + + secretVersion.SetNewId(); + var entity = Mapper.Map(secretVersion); + + await dbContext.AddAsync(entity); + await dbContext.SaveChangesAsync(); + await transaction.CommitAsync(); + + return secretVersion; + } + + public async Task DeleteManyByIdAsync(IEnumerable ids) + { + await using var scope = ServiceScopeFactory.CreateAsyncScope(); + var dbContext = GetDatabaseContext(scope); + + var secretVersionIds = ids.ToList(); + await dbContext.SecretVersion + .Where(sv => secretVersionIds.Contains(sv.Id)) + .ExecuteDeleteAsync(); + } +} diff --git a/bitwarden_license/src/Commercial.Infrastructure.EntityFramework/SecretsManager/SecretsManagerEFServiceCollectionExtensions.cs b/bitwarden_license/src/Commercial.Infrastructure.EntityFramework/SecretsManager/SecretsManagerEFServiceCollectionExtensions.cs index d6c8848079..ac52c40ba6 100644 --- a/bitwarden_license/src/Commercial.Infrastructure.EntityFramework/SecretsManager/SecretsManagerEFServiceCollectionExtensions.cs +++ b/bitwarden_license/src/Commercial.Infrastructure.EntityFramework/SecretsManager/SecretsManagerEFServiceCollectionExtensions.cs @@ -10,6 +10,7 @@ public static class SecretsManagerEfServiceCollectionExtensions { services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); } diff --git a/bitwarden_license/src/Scim/Controllers/v2/GroupsController.cs b/bitwarden_license/src/Scim/Controllers/v2/GroupsController.cs index e3c290c85f..88d6858cb8 100644 --- a/bitwarden_license/src/Scim/Controllers/v2/GroupsController.cs +++ b/bitwarden_license/src/Scim/Controllers/v2/GroupsController.cs @@ -61,17 +61,15 @@ public class GroupsController : Controller [HttpGet("")] public async Task Get( Guid organizationId, - [FromQuery] string filter, - [FromQuery] int? count, - [FromQuery] int? startIndex) + [FromQuery] GetGroupsQueryParamModel model) { - var groupsListQueryResult = await _getGroupsListQuery.GetGroupsListAsync(organizationId, filter, count, startIndex); + var groupsListQueryResult = await _getGroupsListQuery.GetGroupsListAsync(organizationId, model); var scimListResponseModel = new ScimListResponseModel { Resources = groupsListQueryResult.groupList.Select(g => new ScimGroupResponseModel(g)).ToList(), - ItemsPerPage = count.GetValueOrDefault(groupsListQueryResult.groupList.Count()), + ItemsPerPage = model.Count, TotalResults = groupsListQueryResult.totalResults, - StartIndex = startIndex.GetValueOrDefault(1), + StartIndex = model.StartIndex, }; return Ok(scimListResponseModel); } diff --git a/bitwarden_license/src/Scim/Controllers/v2/UsersController.cs b/bitwarden_license/src/Scim/Controllers/v2/UsersController.cs index afbfa50bb4..91d79542b5 100644 --- a/bitwarden_license/src/Scim/Controllers/v2/UsersController.cs +++ b/bitwarden_license/src/Scim/Controllers/v2/UsersController.cs @@ -3,6 +3,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RestoreUser.v1; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v1; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; diff --git a/bitwarden_license/src/Scim/Groups/GetGroupsListQuery.cs b/bitwarden_license/src/Scim/Groups/GetGroupsListQuery.cs index cc6546700b..f0a561a29f 100644 --- a/bitwarden_license/src/Scim/Groups/GetGroupsListQuery.cs +++ b/bitwarden_license/src/Scim/Groups/GetGroupsListQuery.cs @@ -4,6 +4,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Repositories; using Bit.Scim.Groups.Interfaces; +using Bit.Scim.Models; namespace Bit.Scim.Groups; @@ -16,10 +17,16 @@ public class GetGroupsListQuery : IGetGroupsListQuery _groupRepository = groupRepository; } - public async Task<(IEnumerable groupList, int totalResults)> GetGroupsListAsync(Guid organizationId, string filter, int? count, int? startIndex) + public async Task<(IEnumerable groupList, int totalResults)> GetGroupsListAsync( + Guid organizationId, GetGroupsQueryParamModel groupQueryParams) { string nameFilter = null; string externalIdFilter = null; + + int count = groupQueryParams.Count; + int startIndex = groupQueryParams.StartIndex; + string filter = groupQueryParams.Filter; + if (!string.IsNullOrWhiteSpace(filter)) { if (filter.StartsWith("displayName eq ")) @@ -53,11 +60,11 @@ public class GetGroupsListQuery : IGetGroupsListQuery } totalResults = groupList.Count; } - else if (string.IsNullOrWhiteSpace(filter) && startIndex.HasValue && count.HasValue) + else if (string.IsNullOrWhiteSpace(filter)) { groupList = groups.OrderBy(g => g.Name) - .Skip(startIndex.Value - 1) - .Take(count.Value) + .Skip(startIndex - 1) + .Take(count) .ToList(); totalResults = groups.Count; } diff --git a/bitwarden_license/src/Scim/Groups/Interfaces/IGetGroupsListQuery.cs b/bitwarden_license/src/Scim/Groups/Interfaces/IGetGroupsListQuery.cs index 07ff044701..4b4ba09e1d 100644 --- a/bitwarden_license/src/Scim/Groups/Interfaces/IGetGroupsListQuery.cs +++ b/bitwarden_license/src/Scim/Groups/Interfaces/IGetGroupsListQuery.cs @@ -1,8 +1,9 @@ using Bit.Core.AdminConsole.Entities; +using Bit.Scim.Models; namespace Bit.Scim.Groups.Interfaces; public interface IGetGroupsListQuery { - Task<(IEnumerable groupList, int totalResults)> GetGroupsListAsync(Guid organizationId, string filter, int? count, int? startIndex); + Task<(IEnumerable groupList, int totalResults)> GetGroupsListAsync(Guid organizationId, GetGroupsQueryParamModel model); } diff --git a/bitwarden_license/src/Scim/Models/GetGroupsQueryParamModel.cs b/bitwarden_license/src/Scim/Models/GetGroupsQueryParamModel.cs new file mode 100644 index 0000000000..5389727917 --- /dev/null +++ b/bitwarden_license/src/Scim/Models/GetGroupsQueryParamModel.cs @@ -0,0 +1,14 @@ +using System.ComponentModel.DataAnnotations; + +namespace Bit.Scim.Models; + +public class GetGroupsQueryParamModel +{ + public string Filter { get; init; } = string.Empty; + + [Range(1, int.MaxValue)] + public int Count { get; init; } = 50; + + [Range(1, int.MaxValue)] + public int StartIndex { get; init; } = 1; +} diff --git a/bitwarden_license/src/Scim/Models/GetUserQueryParamModel.cs b/bitwarden_license/src/Scim/Models/GetUsersQueryParamModel.cs similarity index 91% rename from bitwarden_license/src/Scim/Models/GetUserQueryParamModel.cs rename to bitwarden_license/src/Scim/Models/GetUsersQueryParamModel.cs index 27d7b6d9a1..cd50dbca61 100644 --- a/bitwarden_license/src/Scim/Models/GetUserQueryParamModel.cs +++ b/bitwarden_license/src/Scim/Models/GetUsersQueryParamModel.cs @@ -1,5 +1,7 @@ using System.ComponentModel.DataAnnotations; +namespace Bit.Scim.Models; + public class GetUsersQueryParamModel { public string Filter { get; init; } = string.Empty; diff --git a/bitwarden_license/src/Scim/Program.cs b/bitwarden_license/src/Scim/Program.cs index 92f12f59dd..02f2e00d32 100644 --- a/bitwarden_license/src/Scim/Program.cs +++ b/bitwarden_license/src/Scim/Program.cs @@ -11,21 +11,8 @@ public class Program .ConfigureWebHostDefaults(webBuilder => { webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, (e, globalSettings) => - { - var context = e.Properties["SourceContext"].ToString(); - - if (e.Properties.TryGetValue("RequestPath", out var requestPath) && - !string.IsNullOrWhiteSpace(requestPath?.ToString()) && - (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) - { - return false; - } - - return e.Level >= globalSettings.MinLogLevel.ScimSettings.Default; - })); }) + .AddSerilogFileLogging() .Build() .Run(); } diff --git a/bitwarden_license/src/Scim/Startup.cs b/bitwarden_license/src/Scim/Startup.cs index edbbf34aea..2a84faa8dd 100644 --- a/bitwarden_license/src/Scim/Startup.cs +++ b/bitwarden_license/src/Scim/Startup.cs @@ -94,11 +94,8 @@ public class Startup public void Configure( IApplicationBuilder app, IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, GlobalSettings globalSettings) { - app.UseSerilog(env, appLifetime, globalSettings); - // Add general security headers app.UseMiddleware(); diff --git a/bitwarden_license/src/Scim/Users/GetUsersListQuery.cs b/bitwarden_license/src/Scim/Users/GetUsersListQuery.cs index a734635ebf..c7085eb6b9 100644 --- a/bitwarden_license/src/Scim/Users/GetUsersListQuery.cs +++ b/bitwarden_license/src/Scim/Users/GetUsersListQuery.cs @@ -3,6 +3,7 @@ using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; +using Bit.Scim.Models; using Bit.Scim.Users.Interfaces; namespace Bit.Scim.Users; diff --git a/bitwarden_license/src/Scim/Users/Interfaces/IGetUsersListQuery.cs b/bitwarden_license/src/Scim/Users/Interfaces/IGetUsersListQuery.cs index f584cb8e7b..04133c89eb 100644 --- a/bitwarden_license/src/Scim/Users/Interfaces/IGetUsersListQuery.cs +++ b/bitwarden_license/src/Scim/Users/Interfaces/IGetUsersListQuery.cs @@ -1,4 +1,5 @@ using Bit.Core.Models.Data.Organizations.OrganizationUsers; +using Bit.Scim.Models; namespace Bit.Scim.Users.Interfaces; diff --git a/bitwarden_license/src/Scim/Users/PatchUserCommand.cs b/bitwarden_license/src/Scim/Users/PatchUserCommand.cs index 6c983611ee..474557a9cb 100644 --- a/bitwarden_license/src/Scim/Users/PatchUserCommand.cs +++ b/bitwarden_license/src/Scim/Users/PatchUserCommand.cs @@ -1,5 +1,5 @@ -using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; -using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RestoreUser.v1; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RestoreUser.v1; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v1; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; diff --git a/bitwarden_license/src/Scim/Users/PostUserCommand.cs b/bitwarden_license/src/Scim/Users/PostUserCommand.cs index 5b4a0c29cd..696d600348 100644 --- a/bitwarden_license/src/Scim/Users/PostUserCommand.cs +++ b/bitwarden_license/src/Scim/Users/PostUserCommand.cs @@ -8,6 +8,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.E using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models; using Bit.Core.AdminConsole.Utilities.Commands; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Data.Organizations.OrganizationUsers; @@ -24,7 +25,7 @@ public class PostUserCommand( IOrganizationRepository organizationRepository, IOrganizationUserRepository organizationUserRepository, IOrganizationService organizationService, - IPaymentService paymentService, + IStripePaymentService paymentService, IScimContext scimContext, IFeatureService featureService, IInviteOrganizationUsersCommand inviteOrganizationUsersCommand, diff --git a/bitwarden_license/src/Scim/appsettings.Development.json b/bitwarden_license/src/Scim/appsettings.Development.json index 32253a93c1..496d0c075f 100644 --- a/bitwarden_license/src/Scim/appsettings.Development.json +++ b/bitwarden_license/src/Scim/appsettings.Development.json @@ -30,6 +30,7 @@ }, "storage": { "connectionString": "UseDevelopmentStorage=true" - } + }, + "pricingUri": "https://billingpricing.qa.bitwarden.pw" } } diff --git a/bitwarden_license/src/Scim/appsettings.json b/bitwarden_license/src/Scim/appsettings.json index dcdfeb3ede..18b7a7ca7b 100644 --- a/bitwarden_license/src/Scim/appsettings.json +++ b/bitwarden_license/src/Scim/appsettings.json @@ -30,9 +30,6 @@ "connectionString": "SECRET", "applicationCacheTopicName": "SECRET" }, - "sentry": { - "dsn": "SECRET" - }, "notificationHub": { "connectionString": "SECRET", "hubName": "SECRET" diff --git a/bitwarden_license/src/Sso/Controllers/AccountController.cs b/bitwarden_license/src/Sso/Controllers/AccountController.cs index 98a581e8ca..7141f8429d 100644 --- a/bitwarden_license/src/Sso/Controllers/AccountController.cs +++ b/bitwarden_license/src/Sso/Controllers/AccountController.cs @@ -1,7 +1,4 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Security.Claims; +using System.Security.Claims; using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; @@ -57,6 +54,7 @@ public class AccountController : Controller private readonly IDataProtectorTokenFactory _dataProtector; private readonly IOrganizationDomainRepository _organizationDomainRepository; private readonly IRegisterUserCommand _registerUserCommand; + private readonly IFeatureService _featureService; public AccountController( IAuthenticationSchemeProvider schemeProvider, @@ -77,7 +75,8 @@ public class AccountController : Controller Core.Services.IEventService eventService, IDataProtectorTokenFactory dataProtector, IOrganizationDomainRepository organizationDomainRepository, - IRegisterUserCommand registerUserCommand) + IRegisterUserCommand registerUserCommand, + IFeatureService featureService) { _schemeProvider = schemeProvider; _clientStore = clientStore; @@ -98,10 +97,11 @@ public class AccountController : Controller _dataProtector = dataProtector; _organizationDomainRepository = organizationDomainRepository; _registerUserCommand = registerUserCommand; + _featureService = featureService; } [HttpGet] - public async Task PreValidate(string domainHint) + public async Task PreValidateAsync(string domainHint) { try { @@ -160,10 +160,12 @@ public class AccountController : Controller } [HttpGet] - public async Task Login(string returnUrl) + public async Task LoginAsync(string returnUrl) { var context = await _interaction.GetAuthorizationContextAsync(returnUrl); + // FIXME: Update this file to be null safe and then delete the line below +#nullable disable if (!context.Parameters.AllKeys.Contains("domain_hint") || string.IsNullOrWhiteSpace(context.Parameters["domain_hint"])) { @@ -179,6 +181,7 @@ public class AccountController : Controller var domainHint = context.Parameters["domain_hint"]; var organization = await _organizationRepository.GetByIdentifierAsync(domainHint); +#nullable restore if (organization == null) { @@ -198,12 +201,15 @@ public class AccountController : Controller returnUrl, state = context.Parameters["state"], userIdentifier = context.Parameters["session_state"], + ssoToken }); } [HttpGet] - public IActionResult ExternalChallenge(string scheme, string returnUrl, string state, string userIdentifier) + public IActionResult ExternalChallenge(string scheme, string returnUrl, string state, string userIdentifier, string ssoToken) { + ValidateSchemeAgainstSsoToken(scheme, ssoToken); + if (string.IsNullOrEmpty(returnUrl)) { returnUrl = "~/"; @@ -232,40 +238,101 @@ public class AccountController : Controller return Challenge(props, scheme); } + /// + /// Validates the scheme (organization ID) against the organization ID found in the ssoToken. + /// + /// The authentication scheme (organization ID) to validate. + /// The SSO token to validate against. + /// Thrown if the scheme (organization ID) does not match the organization ID found in the ssoToken. + private void ValidateSchemeAgainstSsoToken(string scheme, string ssoToken) + { + SsoTokenable tokenable; + + try + { + tokenable = _dataProtector.Unprotect(ssoToken); + } + catch + { + throw new Exception(_i18nService.T("InvalidSsoToken")); + } + + if (!Guid.TryParse(scheme, out var schemeOrgId) || tokenable.OrganizationId != schemeOrgId) + { + throw new Exception(_i18nService.T("SsoOrganizationIdMismatch")); + } + } + [HttpGet] public async Task ExternalCallback() { + // Feature flag (PM-24579): Prevent SSO on existing non-compliant users. + var preventOrgUserLoginIfStatusInvalid = + _featureService.IsEnabled(FeatureFlagKeys.PM24579_PreventSsoOnExistingNonCompliantUsers); + // Read external identity from the temporary cookie var result = await HttpContext.AuthenticateAsync( AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); - if (result?.Succeeded != true) - { - throw new Exception(_i18nService.T("ExternalAuthenticationError")); - } - // Debugging - var externalClaims = result.Principal.Claims.Select(c => $"{c.Type}: {c.Value}"); - _logger.LogDebug("External claims: {@claims}", externalClaims); + if (preventOrgUserLoginIfStatusInvalid) + { + if (!result.Succeeded) + { + throw new Exception(_i18nService.T("ExternalAuthenticationError")); + } + } + else + { + if (result?.Succeeded != true) + { + throw new Exception(_i18nService.T("ExternalAuthenticationError")); + } + } // See if the user has logged in with this SSO provider before and has already been provisioned. // This is signified by the user existing in the User table and the SSOUser table for the SSO provider they're using. - var (user, provider, providerUserId, claims, ssoConfigData) = await FindUserFromExternalProviderAsync(result); + var (possibleSsoLinkedUser, provider, providerUserId, claims, ssoConfigData) = await FindUserFromExternalProviderAsync(result); + + // We will look these up as required (lazy resolution) to avoid multiple DB hits. + Organization? organization = null; + OrganizationUser? orgUser = null; // The user has not authenticated with this SSO provider before. // They could have an existing Bitwarden account in the User table though. - if (user == null) + if (possibleSsoLinkedUser == null) { + // FIXME: Update this file to be null safe and then delete the line below +#nullable disable // If we're manually linking to SSO, the user's external identifier will be passed as query string parameter. - var userIdentifier = result.Properties.Items.Keys.Contains("user_identifier") ? - result.Properties.Items["user_identifier"] : null; - user = await AutoProvisionUserAsync(provider, providerUserId, claims, userIdentifier, ssoConfigData); + var userIdentifier = result.Properties.Items.Keys.Contains("user_identifier") + ? result.Properties.Items["user_identifier"] + : null; + + var (resolvedUser, foundOrganization, foundOrCreatedOrgUser) = + await CreateUserAndOrgUserConditionallyAsync( + provider, + providerUserId, + claims, + userIdentifier, + ssoConfigData); +#nullable restore + + possibleSsoLinkedUser = resolvedUser; + + if (preventOrgUserLoginIfStatusInvalid) + { + organization = foundOrganization; + orgUser = foundOrCreatedOrgUser; + } } - // Either the user already authenticated with the SSO provider, or we've just provisioned them. - // Either way, we have associated the SSO login with a Bitwarden user. - // We will now sign the Bitwarden user in. - if (user != null) + if (preventOrgUserLoginIfStatusInvalid) { + User resolvedSsoLinkedUser = possibleSsoLinkedUser + ?? throw new Exception(_i18nService.T("UserShouldBeFound")); + + await PreventOrgUserLoginIfStatusInvalidAsync(organization, provider, orgUser, resolvedSsoLinkedUser); + // This allows us to collect any additional claims or properties // for the specific protocols used and store them in the local auth cookie. // this is typically used to store data needed for signout from those protocols. @@ -278,19 +345,52 @@ public class AccountController : Controller ProcessLoginCallback(result, additionalLocalClaims, localSignInProps); // Issue authentication cookie for user - await HttpContext.SignInAsync(new IdentityServerUser(user.Id.ToString()) + await HttpContext.SignInAsync( + new IdentityServerUser(resolvedSsoLinkedUser.Id.ToString()) + { + DisplayName = resolvedSsoLinkedUser.Email, + IdentityProvider = provider, + AdditionalClaims = additionalLocalClaims.ToArray() + }, localSignInProps); + } + else + { + // PM-24579: remove this else block with feature flag removal. + // Either the user already authenticated with the SSO provider, or we've just provisioned them. + // Either way, we have associated the SSO login with a Bitwarden user. + // We will now sign the Bitwarden user in. + if (possibleSsoLinkedUser != null) { - DisplayName = user.Email, - IdentityProvider = provider, - AdditionalClaims = additionalLocalClaims.ToArray() - }, localSignInProps); + // This allows us to collect any additional claims or properties + // for the specific protocols used and store them in the local auth cookie. + // this is typically used to store data needed for signout from those protocols. + var additionalLocalClaims = new List(); + var localSignInProps = new AuthenticationProperties + { + IsPersistent = true, + ExpiresUtc = DateTimeOffset.UtcNow.AddMinutes(1) + }; + ProcessLoginCallback(result, additionalLocalClaims, localSignInProps); + + // Issue authentication cookie for user + await HttpContext.SignInAsync( + new IdentityServerUser(possibleSsoLinkedUser.Id.ToString()) + { + DisplayName = possibleSsoLinkedUser.Email, + IdentityProvider = provider, + AdditionalClaims = additionalLocalClaims.ToArray() + }, localSignInProps); + } } // Delete temporary cookie used during external authentication await HttpContext.SignOutAsync(AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); + // FIXME: Update this file to be null safe and then delete the line below +#nullable disable // Retrieve return URL var returnUrl = result.Properties.Items["return_url"] ?? "~/"; +#nullable restore // Check if external login is in the context of an OIDC request var context = await _interaction.GetAuthorizationContextAsync(returnUrl); @@ -309,8 +409,10 @@ public class AccountController : Controller return Redirect(returnUrl); } + // FIXME: Update this file to be null safe and then delete the line below +#nullable disable [HttpGet] - public async Task Logout(string logoutId) + public async Task LogoutAsync(string logoutId) { // Build a model so the logged out page knows what to display var (updatedLogoutId, redirectUri, externalAuthenticationScheme) = await GetLoggedOutDataAsync(logoutId); @@ -333,6 +435,7 @@ public class AccountController : Controller // This triggers a redirect to the external provider for sign-out return SignOut(new AuthenticationProperties { RedirectUri = url }, externalAuthenticationScheme); } + if (redirectUri != null) { return View("Redirect", new RedirectViewModel { RedirectUrl = redirectUri }); @@ -342,14 +445,22 @@ public class AccountController : Controller return Redirect("~/"); } } +#nullable restore /// /// Attempts to map the external identity to a Bitwarden user, through the SsoUser table, which holds the `externalId`. /// The claims on the external identity are used to determine an `externalId`, and that is used to find the appropriate `SsoUser` and `User` records. /// - private async Task<(User user, string provider, string providerUserId, IEnumerable claims, SsoConfigurationData config)> - FindUserFromExternalProviderAsync(AuthenticateResult result) + private async Task<( + User? possibleSsoUser, + string provider, + string providerUserId, + IEnumerable claims, + SsoConfigurationData config + )> FindUserFromExternalProviderAsync(AuthenticateResult result) { + // FIXME: Update this file to be null safe and then delete the line below +#nullable disable var provider = result.Properties.Items["scheme"]; var orgId = new Guid(provider); var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(orgId); @@ -374,9 +485,10 @@ public class AccountController : Controller // Ensure the NameIdentifier used is not a transient name ID, if so, we need a different attribute // for the user identifier. static bool nameIdIsNotTransient(Claim c) => c.Type == ClaimTypes.NameIdentifier - && (c.Properties == null - || !c.Properties.TryGetValue(SamlPropertyKeys.ClaimFormat, out var claimFormat) - || claimFormat != SamlNameIdFormats.Transient); + && (c.Properties == null + || !c.Properties.TryGetValue(SamlPropertyKeys.ClaimFormat, + out var claimFormat) + || claimFormat != SamlNameIdFormats.Transient); // Try to determine the unique id of the external user (issued by the provider) // the most common claim type for that are the sub claim and the NameIdentifier @@ -391,6 +503,7 @@ public class AccountController : Controller externalUser.FindFirst("upn") ?? externalUser.FindFirst("eppn") ?? throw new Exception(_i18nService.T("UnknownUserId")); +#nullable restore // Remove the user id claim so we don't include it as an extra claim if/when we provision the user var claims = externalUser.Claims.ToList(); @@ -399,13 +512,15 @@ public class AccountController : Controller // find external user var providerUserId = userIdClaim.Value; - var user = await _userRepository.GetBySsoUserAsync(providerUserId, orgId); + var possibleSsoUser = await _userRepository.GetBySsoUserAsync(providerUserId, orgId); - return (user, provider, providerUserId, claims, ssoConfigData); + return (possibleSsoUser, provider, providerUserId, claims, ssoConfigData); } /// - /// Provision an SSO-linked Bitwarden user. + /// This function seeks to set up the org user record or create a new user record based on the conditions + /// below. + /// /// This handles three different scenarios: /// 1. Creating an SsoUser link for an existing User and OrganizationUser /// - User is a member of the organization, but hasn't authenticated with the org's SSO provider before. @@ -418,77 +533,100 @@ public class AccountController : Controller /// The external identity provider's user identifier. /// The claims from the external IdP. /// The user identifier used for manual SSO linking. - /// The SSO configuration for the organization. - /// The User to sign in. + /// The SSO configuration for the organization. + /// Guaranteed to return the user to sign in as well as the found organization and org user. /// An exception if the user cannot be provisioned as requested. - private async Task AutoProvisionUserAsync(string provider, string providerUserId, - IEnumerable claims, string userIdentifier, SsoConfigurationData config) + private async Task<(User resolvedUser, Organization foundOrganization, OrganizationUser foundOrgUser)> CreateUserAndOrgUserConditionallyAsync( + string provider, + string providerUserId, + IEnumerable claims, + string userIdentifier, + SsoConfigurationData ssoConfigData + ) { - var name = GetName(claims, config.GetAdditionalNameClaimTypes()); - var email = GetEmailAddress(claims, config.GetAdditionalEmailClaimTypes()); - if (string.IsNullOrWhiteSpace(email) && providerUserId.Contains("@")) - { - email = providerUserId; - } + // Try to get the email from the claims as we don't know if we have a user record yet. + var name = GetName(claims, ssoConfigData.GetAdditionalNameClaimTypes()); + var email = TryGetEmailAddress(claims, ssoConfigData, providerUserId); - if (!Guid.TryParse(provider, out var orgId)) - { - // TODO: support non-org (server-wide) SSO in the future? - throw new Exception(_i18nService.T("SSOProviderIsNotAnOrgId", provider)); - } - - User existingUser = null; + User? possibleExistingUser; if (string.IsNullOrWhiteSpace(userIdentifier)) { if (string.IsNullOrWhiteSpace(email)) { throw new Exception(_i18nService.T("CannotFindEmailClaim")); } - existingUser = await _userRepository.GetByEmailAsync(email); + + possibleExistingUser = await _userRepository.GetByEmailAsync(email); } else { - existingUser = await GetUserFromManualLinkingData(userIdentifier); + possibleExistingUser = await GetUserFromManualLinkingDataAsync(userIdentifier); } - // Try to find the OrganizationUser if it exists. - var (organization, orgUser) = await FindOrganizationUser(existingUser, email, orgId); + // Find the org (we error if we can't find an org because no org is not valid) + var organization = await GetOrganizationByProviderAsync(provider); + + // Try to find an org user (null org user possible and valid here) + var possibleOrgUser = await GetOrganizationUserByUserAndOrgIdOrEmailAsync(possibleExistingUser, organization.Id, email); //---------------------------------------------------- // Scenario 1: We've found the user in the User table //---------------------------------------------------- - if (existingUser != null) + if (possibleExistingUser != null) { - if (existingUser.UsesKeyConnector && - (orgUser == null || orgUser.Status == OrganizationUserStatusType.Invited)) + User guaranteedExistingUser = possibleExistingUser; + + if (guaranteedExistingUser.UsesKeyConnector && + (possibleOrgUser == null || possibleOrgUser.Status == OrganizationUserStatusType.Invited)) { throw new Exception(_i18nService.T("UserAlreadyExistsKeyConnector")); } - // If the user already exists in Bitwarden, we require that the user already be in the org, - // and that they are either Accepted or Confirmed. - if (orgUser == null) + OrganizationUser guaranteedOrgUser = possibleOrgUser ?? throw new Exception(_i18nService.T("UserAlreadyExistsInviteProcess")); + + /* + * ---------------------------------------------------- + * Critical Code Check Here + * + * We want to ensure a user is not in the invited state + * explicitly. User's in the invited state should not + * be able to authenticate via SSO. + * + * See internal doc called "Added Context for SSO Login + * Flows" for further details. + * ---------------------------------------------------- + */ + if (guaranteedOrgUser.Status == OrganizationUserStatusType.Invited) { - // Org User is not created - no invite has been sent - throw new Exception(_i18nService.T("UserAlreadyExistsInviteProcess")); + // Org User is invited – must accept via email first + throw new Exception( + _i18nService.T("AcceptInviteBeforeUsingSSO", organization.DisplayName())); } - EnsureOrgUserStatusAllowed(orgUser.Status, organization.DisplayName(), - allowedStatuses: [OrganizationUserStatusType.Accepted, OrganizationUserStatusType.Confirmed]); - + // If the user already exists in Bitwarden, we require that the user already be in the org, + // and that they are either Accepted or Confirmed. + EnforceAllowedOrgUserStatus( + guaranteedOrgUser.Status, + allowedStatuses: [ + OrganizationUserStatusType.Accepted, + OrganizationUserStatusType.Confirmed + ], + organization.DisplayName()); // Since we're in the auto-provisioning logic, this means that the user exists, but they have not // authenticated with the org's SSO provider before now (otherwise we wouldn't be auto-provisioning them). // We've verified that the user is Accepted or Confnirmed, so we can create an SsoUser link and proceed // with authentication. - await CreateSsoUserRecord(providerUserId, existingUser.Id, orgId, orgUser); - return existingUser; + await CreateSsoUserRecordAsync(providerUserId, guaranteedExistingUser.Id, organization.Id, guaranteedOrgUser); + + return (guaranteedExistingUser, organization, guaranteedOrgUser); } // Before any user creation - if Org User doesn't exist at this point - make sure there are enough seats to add one - if (orgUser == null && organization.Seats.HasValue) + if (possibleOrgUser == null && organization.Seats.HasValue) { - var occupiedSeats = await _organizationRepository.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id); + var occupiedSeats = + await _organizationRepository.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id); var initialSeatCount = organization.Seats.Value; var availableSeats = initialSeatCount - occupiedSeats.Total; if (availableSeats < 1) @@ -506,8 +644,10 @@ public class AccountController : Controller { if (organization.Seats.Value != initialSeatCount) { - await _organizationService.AdjustSeatsAsync(orgId, initialSeatCount - organization.Seats.Value); + await _organizationService.AdjustSeatsAsync(organization.Id, + initialSeatCount - organization.Seats.Value); } + _logger.LogInformation(e, "SSO auto provisioning failed"); throw new Exception(_i18nService.T("NoSeatsAvailable", organization.DisplayName())); } @@ -515,40 +655,62 @@ public class AccountController : Controller } // If the email domain is verified, we can mark the email as verified + if (string.IsNullOrWhiteSpace(email)) + { + throw new Exception(_i18nService.T("CannotFindEmailClaim")); + } + var emailVerified = false; var emailDomain = CoreHelpers.GetEmailDomain(email); if (!string.IsNullOrWhiteSpace(emailDomain)) { - var organizationDomain = await _organizationDomainRepository.GetDomainByOrgIdAndDomainNameAsync(orgId, emailDomain); + var organizationDomain = + await _organizationDomainRepository.GetDomainByOrgIdAndDomainNameAsync(organization.Id, emailDomain); emailVerified = organizationDomain?.VerifiedDate.HasValue ?? false; } //-------------------------------------------------- // Scenarios 2 and 3: We need to register a new user //-------------------------------------------------- - var user = new User + var newUser = new User { Name = name, Email = email, EmailVerified = emailVerified, ApiKey = CoreHelpers.SecureRandomString(30) }; - await _registerUserCommand.RegisterUser(user); + + /* + The feature flag is checked here so that we can send the new MJML welcome email templates. + The other organization invites flows have an OrganizationUser allowing the RegisterUserCommand the ability + to fetch the Organization. The old method RegisterUser(User) here does not have that context, so we need + to use a new method RegisterSSOAutoProvisionedUserAsync(User, Organization) to send the correct email. + [PM-28057]: Prefer RegisterSSOAutoProvisionedUserAsync for SSO auto-provisioned users. + TODO: Remove Feature flag: PM-28221 + */ + if (_featureService.IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates)) + { + await _registerUserCommand.RegisterSSOAutoProvisionedUserAsync(newUser, organization); + } + else + { + await _registerUserCommand.RegisterUser(newUser); + } // If the organization has 2fa policy enabled, make sure to default jit user 2fa to email var twoFactorPolicy = - await _policyRepository.GetByOrganizationIdTypeAsync(orgId, PolicyType.TwoFactorAuthentication); + await _policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.TwoFactorAuthentication); if (twoFactorPolicy != null && twoFactorPolicy.Enabled) { - user.SetTwoFactorProviders(new Dictionary + newUser.SetTwoFactorProviders(new Dictionary { [TwoFactorProviderType.Email] = new TwoFactorProvider { - MetaData = new Dictionary { ["Email"] = user.Email.ToLowerInvariant() }, + MetaData = new Dictionary { ["Email"] = newUser.Email.ToLowerInvariant() }, Enabled = true } }); - await _userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.Email); + await _userService.UpdateTwoFactorProviderAsync(newUser, TwoFactorProviderType.Email); } //----------------------------------------------------------------- @@ -556,17 +718,18 @@ public class AccountController : Controller // This means that an invitation was not sent for this user and we // need to establish their invited status now. //----------------------------------------------------------------- - if (orgUser == null) + if (possibleOrgUser == null) { - orgUser = new OrganizationUser + possibleOrgUser = new OrganizationUser { - OrganizationId = orgId, - UserId = user.Id, + OrganizationId = organization.Id, + UserId = newUser.Id, Type = OrganizationUserType.User, Status = OrganizationUserStatusType.Invited }; - await _organizationUserRepository.CreateAsync(orgUser); + await _organizationUserRepository.CreateAsync(possibleOrgUser); } + //----------------------------------------------------------------- // Scenario 3: There is already an existing OrganizationUser // That was established through an invitation. We just need to @@ -574,24 +737,68 @@ public class AccountController : Controller //----------------------------------------------------------------- else { - orgUser.UserId = user.Id; - await _organizationUserRepository.ReplaceAsync(orgUser); + possibleOrgUser.UserId = newUser.Id; + await _organizationUserRepository.ReplaceAsync(possibleOrgUser); } // Create the SsoUser record to link the user to the SSO provider. - await CreateSsoUserRecord(providerUserId, user.Id, orgId, orgUser); + await CreateSsoUserRecordAsync(providerUserId, newUser.Id, organization.Id, possibleOrgUser); - return user; + return (newUser, organization, possibleOrgUser); } - private async Task GetUserFromManualLinkingData(string userIdentifier) + /// + /// Validates an organization user is allowed to log in via SSO and blocks invalid statuses. + /// Lazily resolves the organization and organization user if not provided. + /// + /// The target organization; if null, resolved from provider. + /// The SSO scheme provider value (organization id as a GUID string). + /// The organization-user record; if null, looked up by user/org or user email for invited users. + /// The user attempting to sign in (existing or newly provisioned). + /// Thrown if the organization cannot be resolved from provider; + /// the organization user cannot be found; or the organization user status is not allowed. + private async Task PreventOrgUserLoginIfStatusInvalidAsync( + Organization? organization, + string provider, + OrganizationUser? orgUser, + User user) { - User user = null; + // Lazily get organization if not already known + organization ??= await GetOrganizationByProviderAsync(provider); + + // Lazily get the org user if not already known + orgUser ??= await GetOrganizationUserByUserAndOrgIdOrEmailAsync( + user, + organization.Id, + user.Email); + + if (orgUser != null) + { + // Invited is allowed at this point because we know the user is trying to accept an org invite. + EnforceAllowedOrgUserStatus( + orgUser.Status, + allowedStatuses: [ + OrganizationUserStatusType.Invited, + OrganizationUserStatusType.Accepted, + OrganizationUserStatusType.Confirmed, + ], + organization.DisplayName()); + } + else + { + throw new Exception(_i18nService.T("CouldNotFindOrganizationUser", user.Id, organization.Id)); + } + } + + private async Task GetUserFromManualLinkingDataAsync(string userIdentifier) + { + User? user = null; var split = userIdentifier.Split(","); if (split.Length < 2) { throw new Exception(_i18nService.T("InvalidUserIdentifier")); } + var userId = split[0]; var token = split[1]; @@ -611,64 +818,94 @@ public class AccountController : Controller throw new Exception(_i18nService.T("UserIdAndTokenMismatch")); } } + return user; } - private async Task<(Organization, OrganizationUser)> FindOrganizationUser(User existingUser, string email, Guid orgId) + /// + /// Tries to get the organization by the provider which is org id for us as we use the scheme + /// to identify organizations - not identity providers. + /// + /// Org id string from SSO scheme property + /// Errors if the provider string is not a valid org id guid or if the org cannot be found by the id. + private async Task GetOrganizationByProviderAsync(string provider) { - OrganizationUser orgUser = null; - var organization = await _organizationRepository.GetByIdAsync(orgId); + if (!Guid.TryParse(provider, out var organizationId)) + { + // TODO: support non-org (server-wide) SSO in the future? + throw new Exception(_i18nService.T("SSOProviderIsNotAnOrgId", provider)); + } + + var organization = await _organizationRepository.GetByIdAsync(organizationId); + if (organization == null) { - throw new Exception(_i18nService.T("CouldNotFindOrganization", orgId)); + throw new Exception(_i18nService.T("CouldNotFindOrganization", organizationId)); } + return organization; + } + + /// + /// Attempts to get an for a given organization + /// by first checking for an existing user relationship, and if none is found, + /// by looking up an invited user via their email address. + /// + /// The existing user entity to be looked up in OrganizationUsers table. + /// Organization id from the provider data. + /// Email to use as a fallback in case of an invited user not in the Org Users + /// table yet. + private async Task GetOrganizationUserByUserAndOrgIdOrEmailAsync( + User? user, + Guid organizationId, + string? email) + { + OrganizationUser? orgUser = null; + // Try to find OrgUser via existing User Id. // This covers any OrganizationUser state after they have accepted an invite. - if (existingUser != null) + if (user != null) { - var orgUsersByUserId = await _organizationUserRepository.GetManyByUserAsync(existingUser.Id); - orgUser = orgUsersByUserId.SingleOrDefault(u => u.OrganizationId == orgId); + var orgUsersByUserId = await _organizationUserRepository.GetManyByUserAsync(user.Id); + orgUser = orgUsersByUserId.SingleOrDefault(u => u.OrganizationId == organizationId); } // If no Org User found by Existing User Id - search all the organization's users via email. // This covers users who are Invited but haven't accepted their invite yet. - orgUser ??= await _organizationUserRepository.GetByOrganizationEmailAsync(orgId, email); + if (email != null) + { + orgUser ??= await _organizationUserRepository.GetByOrganizationEmailAsync(organizationId, email); + } - return (organization, orgUser); + return orgUser; } - private void EnsureOrgUserStatusAllowed( - OrganizationUserStatusType status, - string organizationDisplayName, - params OrganizationUserStatusType[] allowedStatuses) + private void EnforceAllowedOrgUserStatus( + OrganizationUserStatusType statusToCheckAgainst, + OrganizationUserStatusType[] allowedStatuses, + string organizationDisplayNameForLogging) { // if this status is one of the allowed ones, just return - if (allowedStatuses.Contains(status)) + if (allowedStatuses.Contains(statusToCheckAgainst)) { return; } // otherwise throw the appropriate exception - switch (status) + switch (statusToCheckAgainst) { - case OrganizationUserStatusType.Invited: - // Org User is invited – must accept via email first - throw new Exception( - _i18nService.T("AcceptInviteBeforeUsingSSO", organizationDisplayName)); case OrganizationUserStatusType.Revoked: // Revoked users may not be (auto)‑provisioned throw new Exception( - _i18nService.T("OrganizationUserAccessRevoked", organizationDisplayName)); + _i18nService.T("OrganizationUserAccessRevoked", organizationDisplayNameForLogging)); default: // anything else is “unknown” throw new Exception( - _i18nService.T("OrganizationUserUnknownStatus", organizationDisplayName)); + _i18nService.T("OrganizationUserUnknownStatus", organizationDisplayNameForLogging)); } } - - private IActionResult InvalidJson(string errorMessageKey, Exception ex = null) + private IActionResult InvalidJson(string errorMessageKey, Exception? ex = null) { Response.StatusCode = ex == null ? 400 : 500; return Json(new ErrorResponseModel(_i18nService.T(errorMessageKey)) @@ -679,13 +916,13 @@ public class AccountController : Controller }); } - private string GetEmailAddress(IEnumerable claims, IEnumerable additionalClaimTypes) + private string? TryGetEmailAddressFromClaims(IEnumerable claims, IEnumerable additionalClaimTypes) { var filteredClaims = claims.Where(c => !string.IsNullOrWhiteSpace(c.Value) && c.Value.Contains("@")); var email = filteredClaims.GetFirstMatch(additionalClaimTypes.ToArray()) ?? - filteredClaims.GetFirstMatch(JwtClaimTypes.Email, ClaimTypes.Email, - SamlClaimTypes.Email, "mail", "emailaddress"); + filteredClaims.GetFirstMatch(JwtClaimTypes.Email, ClaimTypes.Email, + SamlClaimTypes.Email, "mail", "emailaddress"); if (!string.IsNullOrWhiteSpace(email)) { return email; @@ -701,13 +938,15 @@ public class AccountController : Controller return null; } + // FIXME: Update this file to be null safe and then delete the line below +#nullable disable private string GetName(IEnumerable claims, IEnumerable additionalClaimTypes) { var filteredClaims = claims.Where(c => !string.IsNullOrWhiteSpace(c.Value)); var name = filteredClaims.GetFirstMatch(additionalClaimTypes.ToArray()) ?? - filteredClaims.GetFirstMatch(JwtClaimTypes.Name, ClaimTypes.Name, - SamlClaimTypes.DisplayName, SamlClaimTypes.CommonName, "displayname", "cn"); + filteredClaims.GetFirstMatch(JwtClaimTypes.Name, ClaimTypes.Name, + SamlClaimTypes.DisplayName, SamlClaimTypes.CommonName, "displayname", "cn"); if (!string.IsNullOrWhiteSpace(name)) { return name; @@ -724,8 +963,10 @@ public class AccountController : Controller return null; } +#nullable restore - private async Task CreateSsoUserRecord(string providerUserId, Guid userId, Guid orgId, OrganizationUser orgUser) + private async Task CreateSsoUserRecordAsync(string providerUserId, Guid userId, Guid orgId, + OrganizationUser orgUser) { // Delete existing SsoUser (if any) - avoids error if providerId has changed and the sso link is stale var existingSsoUser = await _ssoUserRepository.GetByUserIdOrganizationIdAsync(orgId, userId); @@ -740,15 +981,12 @@ public class AccountController : Controller await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_FirstSsoLogin); } - var ssoUser = new SsoUser - { - ExternalId = providerUserId, - UserId = userId, - OrganizationId = orgId, - }; + var ssoUser = new SsoUser { ExternalId = providerUserId, UserId = userId, OrganizationId = orgId, }; await _ssoUserRepository.CreateAsync(ssoUser); } + // FIXME: Update this file to be null safe and then delete the line below +#nullable disable private void ProcessLoginCallback(AuthenticateResult externalResult, List localClaims, AuthenticationProperties localSignInProps) { @@ -769,18 +1007,6 @@ public class AccountController : Controller } } - private async Task GetProviderAsync(string returnUrl) - { - var context = await _interaction.GetAuthorizationContextAsync(returnUrl); - if (context?.IdP != null && await _schemeProvider.GetSchemeAsync(context.IdP) != null) - { - return context.IdP; - } - var schemes = await _schemeProvider.GetAllSchemesAsync(); - var providers = schemes.Select(x => x.Name).ToList(); - return providers.FirstOrDefault(); - } - private async Task<(string, string, string)> GetLoggedOutDataAsync(string logoutId) { // Get context information (client name, post logout redirect URI and iframe for federated signout) @@ -811,10 +1037,31 @@ public class AccountController : Controller return (logoutId, logout?.PostLogoutRedirectUri, externalAuthenticationScheme); } +#nullable restore + + /** + * Tries to get a user's email from the claims and SSO configuration data or the provider user id if + * the claims email extraction returns null. + */ + private string? TryGetEmailAddress( + IEnumerable claims, + SsoConfigurationData config, + string providerUserId) + { + var email = TryGetEmailAddressFromClaims(claims, config.GetAdditionalEmailClaimTypes()); + + // If email isn't populated from claims and providerUserId has @, assume it is the email. + if (string.IsNullOrWhiteSpace(email) && providerUserId.Contains("@")) + { + email = providerUserId; + } + + return email; + } public bool IsNativeClient(DIM.AuthorizationRequest context) { return !context.RedirectUri.StartsWith("https", StringComparison.Ordinal) - && !context.RedirectUri.StartsWith("http", StringComparison.Ordinal); + && !context.RedirectUri.StartsWith("http", StringComparison.Ordinal); } } diff --git a/bitwarden_license/src/Sso/Program.cs b/bitwarden_license/src/Sso/Program.cs index 1a8ce6eb88..bac3bb3d13 100644 --- a/bitwarden_license/src/Sso/Program.cs +++ b/bitwarden_license/src/Sso/Program.cs @@ -1,5 +1,4 @@ using Bit.Core.Utilities; -using Serilog; namespace Bit.Sso; @@ -13,19 +12,8 @@ public class Program .ConfigureWebHostDefaults(webBuilder => { webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, (e, globalSettings) => - { - var context = e.Properties["SourceContext"].ToString(); - if (e.Properties.TryGetValue("RequestPath", out var requestPath) && - !string.IsNullOrWhiteSpace(requestPath?.ToString()) && - (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) - { - return false; - } - return e.Level >= globalSettings.MinLogLevel.SsoSettings.Default; - })); }) + .AddSerilogFileLogging() .Build() .Run(); } diff --git a/bitwarden_license/src/Sso/Startup.cs b/bitwarden_license/src/Sso/Startup.cs index 3aeb9c6beb..2f83f3dad0 100644 --- a/bitwarden_license/src/Sso/Startup.cs +++ b/bitwarden_license/src/Sso/Startup.cs @@ -100,8 +100,6 @@ public class Startup IdentityModelEventSource.ShowPII = true; } - app.UseSerilog(env, appLifetime, globalSettings); - // Add general security headers app.UseMiddleware(); @@ -157,6 +155,6 @@ public class Startup app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); // Log startup - logger.LogInformation(Constants.BypassFiltersEventId, globalSettings.ProjectName + " started."); + logger.LogInformation(Constants.BypassFiltersEventId, "{Project} started.", globalSettings.ProjectName); } } diff --git a/bitwarden_license/src/Sso/appsettings.Development.json b/bitwarden_license/src/Sso/appsettings.Development.json index 8aae281068..8e24d82528 100644 --- a/bitwarden_license/src/Sso/appsettings.Development.json +++ b/bitwarden_license/src/Sso/appsettings.Development.json @@ -24,6 +24,13 @@ "storage": { "connectionString": "UseDevelopmentStorage=true" }, - "developmentDirectory": "../../../dev" + "developmentDirectory": "../../../dev", + "pricingUri": "https://billingpricing.qa.bitwarden.pw", + "mail": { + "smtp": { + "host": "localhost", + "port": 10250 + } + } } } diff --git a/bitwarden_license/src/Sso/appsettings.json b/bitwarden_license/src/Sso/appsettings.json index 73c85044cc..9a5df42f7f 100644 --- a/bitwarden_license/src/Sso/appsettings.json +++ b/bitwarden_license/src/Sso/appsettings.json @@ -13,7 +13,11 @@ "mail": { "sendGridApiKey": "SECRET", "amazonConfigSetName": "Email", - "replyToEmail": "no-reply@bitwarden.com" + "replyToEmail": "no-reply@bitwarden.com", + "smtp": { + "host": "localhost", + "port": 10250 + } }, "identityServer": { "certificateThumbprint": "SECRET" diff --git a/bitwarden_license/src/Sso/package-lock.json b/bitwarden_license/src/Sso/package-lock.json index aeefbd69d7..f5e0468f87 100644 --- a/bitwarden_license/src/Sso/package-lock.json +++ b/bitwarden_license/src/Sso/package-lock.json @@ -17,9 +17,9 @@ "css-loader": "7.1.2", "expose-loader": "5.0.1", "mini-css-extract-plugin": "2.9.2", - "sass": "1.91.0", + "sass": "1.93.2", "sass-loader": "16.0.5", - "webpack": "5.101.3", + "webpack": "5.102.1", "webpack-cli": "5.1.4" } }, @@ -678,6 +678,7 @@ "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", "dev": true, "license": "MIT", + "peer": true, "bin": { "acorn": "bin/acorn" }, @@ -704,6 +705,7 @@ "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "fast-deep-equal": "^3.1.3", "fast-uri": "^3.0.1", @@ -746,6 +748,16 @@ "ajv": "^8.8.2" } }, + "node_modules/baseline-browser-mapping": { + "version": "2.8.18", + "resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.8.18.tgz", + "integrity": "sha512-UYmTpOBwgPScZpS4A+YbapwWuBwasxvO/2IOHArSsAhL/+ZdmATBXTex3t+l2hXwLVYK382ibr/nKoY9GKe86w==", + "dev": true, + "license": "Apache-2.0", + "bin": { + "baseline-browser-mapping": "dist/cli.js" + } + }, "node_modules/bootstrap": { "version": "5.3.6", "resolved": "https://registry.npmjs.org/bootstrap/-/bootstrap-5.3.6.tgz", @@ -780,9 +792,9 @@ } }, "node_modules/browserslist": { - "version": "4.25.4", - "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.25.4.tgz", - "integrity": "sha512-4jYpcjabC606xJ3kw2QwGEZKX0Aw7sgQdZCvIK9dhVSPh76BKo+C+btT1RRofH7B+8iNpEbgGNVWiLki5q93yg==", + "version": "4.26.3", + "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.26.3.tgz", + "integrity": "sha512-lAUU+02RFBuCKQPj/P6NgjlbCnLBMp4UtgTx7vNHd3XSIJF87s9a5rA3aH2yw3GS9DqZAUbOtZdCCiZeVRqt0w==", "dev": true, "funding": [ { @@ -799,10 +811,12 @@ } ], "license": "MIT", + "peer": true, "dependencies": { - "caniuse-lite": "^1.0.30001737", - "electron-to-chromium": "^1.5.211", - "node-releases": "^2.0.19", + "baseline-browser-mapping": "^2.8.9", + "caniuse-lite": "^1.0.30001746", + "electron-to-chromium": "^1.5.227", + "node-releases": "^2.0.21", "update-browserslist-db": "^1.1.3" }, "bin": { @@ -820,9 +834,9 @@ "license": "MIT" }, "node_modules/caniuse-lite": { - "version": "1.0.30001741", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001741.tgz", - "integrity": "sha512-QGUGitqsc8ARjLdgAfxETDhRbJ0REsP6O3I96TAth/mVjh2cYzN2u+3AzPP3aVSm2FehEItaJw1xd+IGBXWeSw==", + "version": "1.0.30001751", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001751.tgz", + "integrity": "sha512-A0QJhug0Ly64Ii3eIqHu5X51ebln3k4yTUkY1j8drqpWHVreg/VLijN48cZ1bYPiqOQuqpkIKnzr/Ul8V+p6Cw==", "dev": true, "funding": [ { @@ -974,9 +988,9 @@ } }, "node_modules/electron-to-chromium": { - "version": "1.5.215", - "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.215.tgz", - "integrity": "sha512-TIvGp57UpeNetj/wV/xpFNpWGb0b/ROw372lHPx5Aafx02gjTBtWnEEcaSX3W2dLM3OSdGGyHX/cHl01JQsLaQ==", + "version": "1.5.237", + "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.237.tgz", + "integrity": "sha512-icUt1NvfhGLar5lSWH3tHNzablaA5js3HVHacQimfP8ViEBOQv+L7DKEuHdbTZ0SKCO1ogTJTIL1Gwk9S6Qvcg==", "dev": true, "license": "ISC" }, @@ -1527,9 +1541,9 @@ "optional": true }, "node_modules/node-releases": { - "version": "2.0.20", - "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.20.tgz", - "integrity": "sha512-7gK6zSXEH6neM212JgfYFXe+GmZQM+fia5SsusuBIUgnPheLFBmIPhtFoAQRj8/7wASYQnbDlHPVwY0BefoFgA==", + "version": "2.0.26", + "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.26.tgz", + "integrity": "sha512-S2M9YimhSjBSvYnlr5/+umAnPHE++ODwt5e2Ij6FoX45HA/s4vHdkDx1eax2pAPeAOqu4s9b7ppahsyEFdVqQA==", "dev": true, "license": "MIT" }, @@ -1653,6 +1667,7 @@ } ], "license": "MIT", + "peer": true, "dependencies": { "nanoid": "^3.3.11", "picocolors": "^1.1.1", @@ -1859,11 +1874,12 @@ "license": "MIT" }, "node_modules/sass": { - "version": "1.91.0", - "resolved": "https://registry.npmjs.org/sass/-/sass-1.91.0.tgz", - "integrity": "sha512-aFOZHGf+ur+bp1bCHZ+u8otKGh77ZtmFyXDo4tlYvT7PWql41Kwd8wdkPqhhT+h2879IVblcHFglIMofsFd1EA==", + "version": "1.93.2", + "resolved": "https://registry.npmjs.org/sass/-/sass-1.93.2.tgz", + "integrity": "sha512-t+YPtOQHpGW1QWsh1CHQ5cPIr9lbbGZLZnbihP/D/qZj/yuV68m8qarcV17nvkOX81BCrvzAlq2klCQFZghyTg==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "chokidar": "^4.0.0", "immutable": "^5.0.2", @@ -1921,9 +1937,9 @@ } }, "node_modules/schema-utils": { - "version": "4.3.2", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.2.tgz", - "integrity": "sha512-Gn/JaSk/Mt9gYubxTtSn/QCV4em9mpAPiR1rqy/Ocu19u/G9J5WWdNoUT4SiV6mFC3y6cxyFcFwdzPM3FgxGAQ==", + "version": "4.3.3", + "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.3.tgz", + "integrity": "sha512-eflK8wEtyOE6+hsaRVPxvUKYCpRgzLqDTb8krvAsRIwOGlHoSgYLgBXoubGgLd2fT41/OUYdb48v4k4WWHQurA==", "dev": true, "license": "MIT", "dependencies": { @@ -2060,9 +2076,9 @@ } }, "node_modules/tapable": { - "version": "2.2.3", - "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.2.3.tgz", - "integrity": "sha512-ZL6DDuAlRlLGghwcfmSn9sK3Hr6ArtyudlSAiCqQ6IfE+b+HHbydbYDIG15IfS5do+7XQQBdBiubF/cV2dnDzg==", + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.3.0.tgz", + "integrity": "sha512-g9ljZiwki/LfxmQADO3dEY1CbpmXT5Hm2fJ+QaGKwSXUylMybePR7/67YW7jOrrvjEgL1Fmz5kzyAjWVWLlucg==", "dev": true, "license": "MIT", "engines": { @@ -2201,11 +2217,12 @@ } }, "node_modules/webpack": { - "version": "5.101.3", - "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.101.3.tgz", - "integrity": "sha512-7b0dTKR3Ed//AD/6kkx/o7duS8H3f1a4w3BYpIriX4BzIhjkn4teo05cptsxvLesHFKK5KObnadmCHBwGc+51A==", + "version": "5.102.1", + "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.102.1.tgz", + "integrity": "sha512-7h/weGm9d/ywQ6qzJ+Xy+r9n/3qgp/thalBbpOi5i223dPXKi04IBtqPN9nTd+jBc7QKfvDbaBnFipYp4sJAUQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@types/eslint-scope": "^3.7.7", "@types/estree": "^1.0.8", @@ -2215,7 +2232,7 @@ "@webassemblyjs/wasm-parser": "^1.14.1", "acorn": "^8.15.0", "acorn-import-phases": "^1.0.3", - "browserslist": "^4.24.0", + "browserslist": "^4.26.3", "chrome-trace-event": "^1.0.2", "enhanced-resolve": "^5.17.3", "es-module-lexer": "^1.2.1", @@ -2227,10 +2244,10 @@ "loader-runner": "^4.2.0", "mime-types": "^2.1.27", "neo-async": "^2.6.2", - "schema-utils": "^4.3.2", - "tapable": "^2.1.1", + "schema-utils": "^4.3.3", + "tapable": "^2.3.0", "terser-webpack-plugin": "^5.3.11", - "watchpack": "^2.4.1", + "watchpack": "^2.4.4", "webpack-sources": "^3.3.3" }, "bin": { @@ -2255,6 +2272,7 @@ "integrity": "sha512-pIDJHIEI9LR0yxHXQ+Qh95k2EvXpWzZ5l+d+jIo+RdSm9MiHfzazIxwwni/p7+x4eJZuvG1AJwgC4TNQ7NRgsg==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@discoveryjs/json-ext": "^0.5.0", "@webpack-cli/configtest": "^2.1.1", diff --git a/bitwarden_license/src/Sso/package.json b/bitwarden_license/src/Sso/package.json index 28f40f0d25..df46444aca 100644 --- a/bitwarden_license/src/Sso/package.json +++ b/bitwarden_license/src/Sso/package.json @@ -16,9 +16,9 @@ "css-loader": "7.1.2", "expose-loader": "5.0.1", "mini-css-extract-plugin": "2.9.2", - "sass": "1.91.0", + "sass": "1.93.2", "sass-loader": "16.0.5", - "webpack": "5.101.3", + "webpack": "5.102.1", "webpack-cli": "5.1.4" } } diff --git a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs index 9b9c41048b..810429d658 100644 --- a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs @@ -13,7 +13,7 @@ using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -131,7 +131,7 @@ public class RemoveOrganizationFromProviderCommandTests Arg.Is>(emails => emails.FirstOrDefault() == "a@example.com")); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CustomerUpdateAsync(Arg.Any(), Arg.Any()); + .UpdateCustomerAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] @@ -156,18 +156,22 @@ public class RemoveOrganizationFromProviderCommandTests "b@example.com" ]); - sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) - .Returns(GetSubscription(organization.GatewaySubscriptionId)); + sutProvider.GetDependency().GetSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Is( + options => options.Expand.Contains("customer"))) + .Returns(GetSubscription(organization.GatewaySubscriptionId, organization.GatewayCustomerId)); await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization); var stripeAdapter = sutProvider.GetDependency(); - await stripeAdapter.Received(1).CustomerUpdateAsync(organization.GatewayCustomerId, - Arg.Is(options => - options.Coupon == string.Empty && options.Email == "a@example.com")); + await stripeAdapter.Received(1).UpdateCustomerAsync(organization.GatewayCustomerId, + Arg.Is(options => options.Email == "a@example.com")); - await stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + await stripeAdapter.Received(1).DeleteCustomerDiscountAsync(organization.GatewayCustomerId); + + await stripeAdapter.Received(1).DeleteCustomerDiscountAsync(organization.GatewayCustomerId); + + await stripeAdapter.Received(1).UpdateSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Is(options => options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && options.DaysUntilDue == 30)); @@ -205,7 +209,7 @@ public class RemoveOrganizationFromProviderCommandTests organization.PlanType = PlanType.TeamsMonthly; - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); sutProvider.GetDependency().GetPlanOrThrow(PlanType.TeamsMonthly).Returns(teamsMonthlyPlan); @@ -224,7 +228,7 @@ public class RemoveOrganizationFromProviderCommandTests var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Description == string.Empty && options.Email == organization.BillingEmail && options.Expand[0] == "tax" && @@ -237,14 +241,14 @@ public class RemoveOrganizationFromProviderCommandTests } }); - stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(new Subscription + stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(new Subscription { Id = "subscription_id" }); await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization); - await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is(options => + await stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Is(options => options.Customer == organization.GatewayCustomerId && options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && options.DaysUntilDue == 30 && @@ -294,7 +298,7 @@ public class RemoveOrganizationFromProviderCommandTests organization.PlanType = PlanType.TeamsMonthly; - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); sutProvider.GetDependency().GetPlanOrThrow(PlanType.TeamsMonthly).Returns(teamsMonthlyPlan); @@ -313,7 +317,7 @@ public class RemoveOrganizationFromProviderCommandTests var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Description == string.Empty && options.Email == organization.BillingEmail && options.Expand[0] == "tax" && @@ -326,14 +330,14 @@ public class RemoveOrganizationFromProviderCommandTests } }); - stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(new Subscription + stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(new Subscription { Id = "subscription_id" }); await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization); - await stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is(options => + await stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Is(options => options.Customer == organization.GatewayCustomerId && options.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && options.DaysUntilDue == 30 && @@ -368,10 +372,21 @@ public class RemoveOrganizationFromProviderCommandTests Arg.Is>(emails => emails.FirstOrDefault() == "a@example.com")); } - private static Subscription GetSubscription(string subscriptionId) => + private static Subscription GetSubscription(string subscriptionId, string customerId) => new() { Id = subscriptionId, + CustomerId = customerId, + Customer = new Customer + { + Discount = new Discount + { + Coupon = new Coupon + { + Id = "coupon-id" + } + } + }, Status = StripeConstants.SubscriptionStatus.Active, Items = new StripeList { @@ -403,7 +418,7 @@ public class RemoveOrganizationFromProviderCommandTests organization.PlanType = PlanType.TeamsMonthly; organization.Enabled = false; // Start with a disabled organization - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); sutProvider.GetDependency().GetPlanOrThrow(PlanType.TeamsMonthly).Returns(teamsMonthlyPlan); @@ -421,7 +436,7 @@ public class RemoveOrganizationFromProviderCommandTests var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Any()) + stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Any()) .Returns(new Customer { Id = "customer_id", @@ -431,7 +446,7 @@ public class RemoveOrganizationFromProviderCommandTests } }); - stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(new Subscription + stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(new Subscription { Id = "new_subscription_id" }); diff --git a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs index e61cf5f97e..7ec11894ad 100644 --- a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs @@ -1,17 +1,23 @@ using Bit.Commercial.Core.AdminConsole.Services; using Bit.Commercial.Core.Test.AdminConsole.AutoFixture; +using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Models.Business.Provider; using Bit.Core.AdminConsole.Models.Business.Tokenables; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.Models.Data.Provider; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Providers.Services; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -20,6 +26,7 @@ using Bit.Core.Models.Business; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Test.AutoFixture.OrganizationFixtures; +using Bit.Core.Test.Billing.Mocks; using Bit.Core.Tokens; using Bit.Core.Utilities; using Bit.Test.Common.AutoFixture; @@ -99,6 +106,57 @@ public class ProviderServiceTests .ReplaceAsync(Arg.Is(pu => pu.UserId == user.Id && pu.ProviderId == provider.Id && pu.Key == key)); } + [Theory, BitAutoData] + public async Task CompleteSetupAsync_WithAutoConfirmEnabled_ThrowsUserCannotJoinProviderError(User user, Provider provider, + string key, + TokenizedPaymentMethod tokenizedPaymentMethod, BillingAddress billingAddress, + [ProviderUser] ProviderUser providerUser, + SutProvider sutProvider) + { + providerUser.ProviderId = provider.Id; + providerUser.UserId = user.Id; + var userService = sutProvider.GetDependency(); + userService.GetUserByIdAsync(user.Id).Returns(user); + + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id).Returns(providerUser); + + var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName"); + var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); + sutProvider.GetDependency().CreateProtector("ProviderServiceDataProtector") + .Returns(protector); + + var providerBillingService = sutProvider.GetDependency(); + + var customer = new Customer { Id = "customer_id" }; + providerBillingService.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress).Returns(customer); + + var subscription = new Subscription { Id = "subscription_id" }; + providerBillingService.SetupSubscription(provider).Returns(subscription); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + var policyDetails = new List { new() { OrganizationId = Guid.NewGuid(), IsProvider = false } }; + var policyRequirement = new AutomaticUserConfirmationPolicyRequirement(policyDetails); + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(policyRequirement); + + sutProvider.Create(); + + var token = protector.Protect( + $"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, tokenizedPaymentMethod, + billingAddress)); + + Assert.Equal(new UserCannotJoinProvider().Message, exception.Message); + } + [Theory, BitAutoData] public async Task UpdateAsync_ProviderIdIsInvalid_Throws(Provider provider, SutProvider sutProvider) { @@ -578,6 +636,132 @@ public class ProviderServiceTests Assert.Equal(user.Id, pu.UserId); } + [Theory, BitAutoData] + public async Task AcceptUserAsync_WithAutoConfirmEnabledAndPolicyExists_Throws( + [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser providerUser, + User user, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .GetByIdAsync(providerUser.Id) + .Returns(providerUser); + + var protector = DataProtectionProvider + .Create("ApplicationName") + .CreateProtector("ProviderServiceDataProtector"); + + sutProvider.GetDependency() + .CreateProtector("ProviderServiceDataProtector") + .Returns(protector); + + sutProvider.Create(); + + providerUser.Email = user.Email; + var token = protector.Protect($"ProviderUserInvite {providerUser.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + var policyDetails = new List + { + new() { OrganizationId = Guid.NewGuid(), IsProvider = false } + }; + var policyRequirement = new AutomaticUserConfirmationPolicyRequirement(policyDetails); + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(policyRequirement); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.AcceptUserAsync(providerUser.Id, user, token)); + + Assert.Equal(new UserCannotJoinProvider().Message, exception.Message); + } + + [Theory, BitAutoData] + public async Task AcceptUserAsync_WithAutoConfirmEnabledButNoPolicyExists_Success( + [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser providerUser, + User user, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .GetByIdAsync(providerUser.Id) + .Returns(providerUser); + + var protector = DataProtectionProvider + .Create("ApplicationName") + .CreateProtector("ProviderServiceDataProtector"); + + sutProvider.GetDependency() + .CreateProtector("ProviderServiceDataProtector") + .Returns(protector); + sutProvider.Create(); + + providerUser.Email = user.Email; + var token = protector.Protect($"ProviderUserInvite {providerUser.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + var policyRequirement = new AutomaticUserConfirmationPolicyRequirement([]); + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(policyRequirement); + + // Act + var pu = await sutProvider.Sut.AcceptUserAsync(providerUser.Id, user, token); + + // Assert + Assert.Null(pu.Email); + Assert.Equal(ProviderUserStatusType.Accepted, pu.Status); + Assert.Equal(user.Id, pu.UserId); + } + + [Theory, BitAutoData] + public async Task AcceptUserAsync_WithAutoConfirmDisabled_Success( + [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser providerUser, + User user, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .GetByIdAsync(providerUser.Id) + .Returns(providerUser); + + var protector = DataProtectionProvider + .Create("ApplicationName") + .CreateProtector("ProviderServiceDataProtector"); + + sutProvider.GetDependency() + .CreateProtector("ProviderServiceDataProtector") + .Returns(protector); + sutProvider.Create(); + + providerUser.Email = user.Email; + var token = protector.Protect($"ProviderUserInvite {providerUser.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(false); + + // Act + var pu = await sutProvider.Sut.AcceptUserAsync(providerUser.Id, user, token); + + // Assert + Assert.Null(pu.Email); + Assert.Equal(ProviderUserStatusType.Accepted, pu.Status); + Assert.Equal(user.Id, pu.UserId); + + // Verify that policy check was never called when feature flag is disabled + await sutProvider.GetDependency() + .DidNotReceive() + .GetAsync(user.Id); + } + [Theory, BitAutoData] public async Task ConfirmUsersAsync_NoValid( [ProviderUser(ProviderUserStatusType.Invited)] ProviderUser pu1, @@ -624,13 +808,131 @@ public class ProviderServiceTests Assert.Equal("Invalid user.", result[2].Item2); } + [Theory, BitAutoData] + public async Task ConfirmUsersAsync_WithAutoConfirmEnabledAndPolicyExists_ReturnsError( + [ProviderUser(ProviderUserStatusType.Accepted)] ProviderUser pu1, User u1, + Provider provider, User confirmingUser, SutProvider sutProvider) + { + // Arrange + pu1.ProviderId = provider.Id; + pu1.UserId = u1.Id; + var providerUsers = new[] { pu1 }; + + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetManyAsync([]).ReturnsForAnyArgs(providerUsers); + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + sutProvider.GetDependency().GetManyAsync([]).ReturnsForAnyArgs([u1]); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + var policyDetails = new List + { + new() { OrganizationId = Guid.NewGuid(), IsProvider = false } + }; + var policyRequirement = new AutomaticUserConfirmationPolicyRequirement(policyDetails); + sutProvider.GetDependency() + .GetAsync(u1.Id) + .Returns(policyRequirement); + + var dict = providerUsers.ToDictionary(pu => pu.Id, _ => "key"); + + // Act + var result = await sutProvider.Sut.ConfirmUsersAsync(pu1.ProviderId, dict, confirmingUser.Id); + + // Assert + Assert.Single(result); + Assert.Equal(new UserCannotJoinProvider().Message, result[0].Item2); + + // Verify user was not confirmed + await providerUserRepository.DidNotReceive().ReplaceAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task ConfirmUsersAsync_WithAutoConfirmEnabledButNoPolicyExists_Success( + [ProviderUser(ProviderUserStatusType.Accepted)] ProviderUser pu1, User u1, + Provider provider, User confirmingUser, SutProvider sutProvider) + { + // Arrange + pu1.ProviderId = provider.Id; + pu1.UserId = u1.Id; + var providerUsers = new[] { pu1 }; + + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetManyAsync([]).ReturnsForAnyArgs(providerUsers); + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + sutProvider.GetDependency().GetManyAsync([]).ReturnsForAnyArgs([u1]); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + var policyRequirement = new AutomaticUserConfirmationPolicyRequirement(new List()); + sutProvider.GetDependency() + .GetAsync(u1.Id) + .Returns(policyRequirement); + + var dict = providerUsers.ToDictionary(pu => pu.Id, _ => "key"); + + // Act + var result = await sutProvider.Sut.ConfirmUsersAsync(pu1.ProviderId, dict, confirmingUser.Id); + + // Assert + Assert.Single(result); + Assert.Equal("", result[0].Item2); + + // Verify user was confirmed + await providerUserRepository.Received(1).ReplaceAsync(Arg.Is(pu => + pu.Status == ProviderUserStatusType.Confirmed)); + } + + [Theory, BitAutoData] + public async Task ConfirmUsersAsync_WithAutoConfirmDisabled_Success( + [ProviderUser(ProviderUserStatusType.Accepted)] ProviderUser pu1, User u1, + Provider provider, User confirmingUser, SutProvider sutProvider) + { + // Arrange + pu1.ProviderId = provider.Id; + pu1.UserId = u1.Id; + var providerUsers = new[] { pu1 }; + + var providerUserRepository = sutProvider.GetDependency(); + providerUserRepository.GetManyAsync([]).ReturnsForAnyArgs(providerUsers); + + sutProvider.GetDependency().GetByIdAsync(provider.Id).Returns(provider); + sutProvider.GetDependency().GetManyAsync([]).ReturnsForAnyArgs([u1]); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(false); + + var dict = providerUsers.ToDictionary(pu => pu.Id, _ => "key"); + + // Act + var result = await sutProvider.Sut.ConfirmUsersAsync(pu1.ProviderId, dict, confirmingUser.Id); + + // Assert + Assert.Single(result); + Assert.Equal("", result[0].Item2); + + // Verify user was confirmed + await providerUserRepository.Received(1).ReplaceAsync(Arg.Is(pu => + pu.Status == ProviderUserStatusType.Confirmed)); + + // Verify that policy check was never called when feature flag is disabled + await sutProvider.GetDependency() + .DidNotReceive() + .GetAsync(Arg.Any()); + } + [Theory, BitAutoData] public async Task SaveUserAsync_UserIdIsInvalid_Throws(ProviderUser providerUser, SutProvider sutProvider) { - providerUser.Id = default; - var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveUserAsync(providerUser, default)); + providerUser.Id = Guid.Empty; + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.SaveUserAsync(providerUser, Guid.Empty)); Assert.Equal("Invite the user first.", exception.Message); } @@ -756,7 +1058,7 @@ public class ProviderServiceTests await organizationRepository.Received(1) .ReplaceAsync(Arg.Is(org => org.BillingEmail == provider.BillingEmail)); - await sutProvider.GetDependency().Received(1).CustomerUpdateAsync( + await sutProvider.GetDependency().Received(1).UpdateCustomerAsync( organization.GatewayCustomerId, Arg.Is(options => options.Email == provider.BillingEmail)); @@ -811,12 +1113,12 @@ public class ProviderServiceTests organization.Plan = "Enterprise (Monthly)"; sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); var expectedPlanType = PlanType.EnterpriseMonthly2020; sutProvider.GetDependency().GetPlanOrThrow(expectedPlanType) - .Returns(StaticStore.GetPlan(expectedPlanType)); + .Returns(MockPlans.Get(expectedPlanType)); var expectedPlanId = "2020-enterprise-org-seat-monthly"; @@ -827,9 +1129,9 @@ public class ProviderServiceTests sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); var subscriptionItem = GetSubscription(organization.GatewaySubscriptionId); - sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) + sutProvider.GetDependency().GetSubscriptionAsync(organization.GatewaySubscriptionId) .Returns(GetSubscription(organization.GatewaySubscriptionId)); - await sutProvider.GetDependency().SubscriptionUpdateAsync( + await sutProvider.GetDependency().UpdateSubscriptionAsync( organization.GatewaySubscriptionId, SubscriptionUpdateRequest(expectedPlanId, subscriptionItem)); await sutProvider.Sut.AddOrganization(provider.Id, organization.Id, key); diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Queries/GetProviderWarningsQueryTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Queries/GetProviderWarningsQueryTests.cs index a7f896ef7a..96dbacfa92 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Queries/GetProviderWarningsQueryTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Queries/GetProviderWarningsQueryTests.cs @@ -3,7 +3,6 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Services; using Bit.Core.Context; -using Bit.Core.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -63,7 +62,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [] }); var response = await sutProvider.Sut.Run(provider); @@ -95,7 +94,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [] }); var response = await sutProvider.Sut.Run(provider); @@ -129,7 +128,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(false); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [] }); var response = await sutProvider.Sut.Run(provider); @@ -163,7 +162,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [] }); var response = await sutProvider.Sut.Run(provider); @@ -224,7 +223,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [new Registration { Country = "GB" }] @@ -257,7 +256,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [new Registration { Country = "CA" }] @@ -296,7 +295,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [new Registration { Country = "CA" }] @@ -338,7 +337,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [new Registration { Country = "CA" }] @@ -383,7 +382,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [new Registration { Country = "CA" }] @@ -428,7 +427,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [new Registration { Country = "CA" }] @@ -461,7 +460,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Is(opt => opt.Status == TaxRegistrationStatus.Active)) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Is(opt => opt.Status == TaxRegistrationStatus.Active)) .Returns(new StripeList { Data = [ @@ -470,7 +469,7 @@ public class GetProviderWarningsQueryTests new Registration { Country = "FR" } ] }); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Is(opt => opt.Status == TaxRegistrationStatus.Scheduled)) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Is(opt => opt.Status == TaxRegistrationStatus.Scheduled)) .Returns(new StripeList { Data = [] }); var response = await sutProvider.Sut.Run(provider); @@ -505,7 +504,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [new Registration { Country = "CA" }] @@ -543,7 +542,7 @@ public class GetProviderWarningsQueryTests }); sutProvider.GetDependency().ProviderProviderAdmin(provider.Id).Returns(true); - sutProvider.GetDependency().TaxRegistrationsListAsync(Arg.Any()) + sutProvider.GetDependency().ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = [new Registration { Country = "US" }] diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/BusinessUnitConverterTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/BusinessUnitConverterTests.cs index ec52650097..48b971a032 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/BusinessUnitConverterTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/BusinessUnitConverterTests.cs @@ -18,6 +18,7 @@ using Bit.Core.Enums; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; +using Bit.Core.Test.Billing.Mocks; using Bit.Core.Utilities; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.AspNetCore.DataProtection; @@ -72,7 +73,7 @@ public class BusinessUnitConverterTests { organization.PlanType = PlanType.EnterpriseAnnually2020; - var enterpriseAnnually2020 = StaticStore.GetPlan(PlanType.EnterpriseAnnually2020); + var enterpriseAnnually2020 = MockPlans.Get(PlanType.EnterpriseAnnually2020); var subscription = new Subscription { @@ -134,7 +135,7 @@ public class BusinessUnitConverterTests _pricingClient.GetPlanOrThrow(PlanType.EnterpriseAnnually2020) .Returns(enterpriseAnnually2020); - var enterpriseAnnually = StaticStore.GetPlan(PlanType.EnterpriseAnnually); + var enterpriseAnnually = MockPlans.Get(PlanType.EnterpriseAnnually); _pricingClient.GetPlanOrThrow(PlanType.EnterpriseAnnually) .Returns(enterpriseAnnually); @@ -143,11 +144,11 @@ public class BusinessUnitConverterTests await businessUnitConverter.FinalizeConversion(organization, userId, token, providerKey, organizationKey); - await _stripeAdapter.Received(2).CustomerUpdateAsync(subscription.CustomerId, Arg.Any()); + await _stripeAdapter.Received(2).UpdateCustomerAsync(subscription.CustomerId, Arg.Any()); var updatedPriceId = ProviderPriceAdapter.GetActivePriceId(provider, enterpriseAnnually.Type); - await _stripeAdapter.Received(1).SubscriptionUpdateAsync(subscription.Id, Arg.Is( + await _stripeAdapter.Received(1).UpdateSubscriptionAsync(subscription.Id, Arg.Is( arguments => arguments.Items.Count == 2 && arguments.Items[0].Id == "subscription_item_id" && @@ -242,7 +243,7 @@ public class BusinessUnitConverterTests argument.Status == ProviderStatusType.Pending && argument.Type == ProviderType.BusinessUnit)).Returns(provider); - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); _pricingClient.GetPlanOrThrow(organization.PlanType).Returns(plan); diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs index 18c71364e6..93ce33edc4 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs @@ -20,9 +20,8 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Core.Settings; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Braintree; @@ -85,7 +84,7 @@ public class ProviderBillingServiceTests // Assert await providerPlanRepository.Received(0).ReplaceAsync(Arg.Any()); - await stripeAdapter.Received(0).SubscriptionUpdateAsync(Arg.Any(), Arg.Any()); + await stripeAdapter.Received(0).UpdateSubscriptionAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] @@ -113,7 +112,7 @@ public class ProviderBillingServiceTests // Assert await providerPlanRepository.Received(0).ReplaceAsync(Arg.Any()); - await stripeAdapter.Received(0).SubscriptionUpdateAsync(Arg.Any(), Arg.Any()); + await stripeAdapter.Received(0).UpdateSubscriptionAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] @@ -140,7 +139,7 @@ public class ProviderBillingServiceTests .Returns(existingPlan); sutProvider.GetDependency().GetPlanOrThrow(existingPlan.PlanType) - .Returns(StaticStore.GetPlan(existingPlan.PlanType)); + .Returns(MockPlans.Get(existingPlan.PlanType)); sutProvider.GetDependency().GetSubscriptionOrThrow(provider) .Returns(new Subscription @@ -155,7 +154,7 @@ public class ProviderBillingServiceTests Id = "si_ent_annual", Price = new Price { - Id = StaticStore.GetPlan(PlanType.EnterpriseAnnually).PasswordManager + Id = MockPlans.Get(PlanType.EnterpriseAnnually).PasswordManager .StripeProviderPortalSeatPlanId }, Quantity = 10 @@ -168,7 +167,7 @@ public class ProviderBillingServiceTests new ChangeProviderPlanCommand(provider, providerPlanId, PlanType.EnterpriseMonthly); sutProvider.GetDependency().GetPlanOrThrow(command.NewPlan) - .Returns(StaticStore.GetPlan(command.NewPlan)); + .Returns(MockPlans.Get(command.NewPlan)); // Act await sutProvider.Sut.ChangePlan(command); @@ -180,14 +179,14 @@ public class ProviderBillingServiceTests var stripeAdapter = sutProvider.GetDependency(); await stripeAdapter.Received(1) - .SubscriptionUpdateAsync( + .UpdateSubscriptionAsync( Arg.Is(provider.GatewaySubscriptionId), Arg.Is(p => p.Items.Count(si => si.Id == "si_ent_annual" && si.Deleted == true) == 1)); - var newPlanCfg = StaticStore.GetPlan(command.NewPlan); + var newPlanCfg = MockPlans.Get(command.NewPlan); await stripeAdapter.Received(1) - .SubscriptionUpdateAsync( + .UpdateSubscriptionAsync( Arg.Is(provider.GatewaySubscriptionId), Arg.Is(p => p.Items.Count(si => @@ -268,7 +267,7 @@ public class ProviderBillingServiceTests CloudRegion = "US" }); - sutProvider.GetDependency().CustomerCreateAsync(Arg.Is( + sutProvider.GetDependency().CreateCustomerAsync(Arg.Is( options => options.Address.Country == providerCustomer.Address.Country && options.Address.PostalCode == providerCustomer.Address.PostalCode && @@ -288,7 +287,7 @@ public class ProviderBillingServiceTests await sutProvider.Sut.CreateCustomerForClientOrganization(provider, organization); - await sutProvider.GetDependency().Received(1).CustomerCreateAsync(Arg.Is( + await sutProvider.GetDependency().Received(1).CreateCustomerAsync(Arg.Is( options => options.Address.Country == providerCustomer.Address.Country && options.Address.PostalCode == providerCustomer.Address.PostalCode && @@ -349,7 +348,7 @@ public class ProviderBillingServiceTests CloudRegion = "US" }); - sutProvider.GetDependency().CustomerCreateAsync(Arg.Is( + sutProvider.GetDependency().CreateCustomerAsync(Arg.Is( options => options.Address.Country == providerCustomer.Address.Country && options.Address.PostalCode == providerCustomer.Address.PostalCode && @@ -370,7 +369,7 @@ public class ProviderBillingServiceTests await sutProvider.Sut.CreateCustomerForClientOrganization(provider, organization); - await sutProvider.GetDependency().Received(1).CustomerCreateAsync(Arg.Is( + await sutProvider.GetDependency().Received(1).CreateCustomerAsync(Arg.Is( options => options.Address.Country == providerCustomer.Address.Country && options.Address.PostalCode == providerCustomer.Address.PostalCode && @@ -491,7 +490,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); @@ -514,7 +513,7 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); // 50 seats currently assigned with a seat minimum of 100 - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); sutProvider.GetDependency().GetManyDetailsByProviderAsync(provider.Id).Returns( [ @@ -535,7 +534,7 @@ public class ProviderBillingServiceTests await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, 10); // 50 assigned seats + 10 seat scale up = 60 seats, well below the 100 minimum - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().SubscriptionUpdateAsync( + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().UpdateSubscriptionAsync( Arg.Any(), Arg.Any()); @@ -573,7 +572,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } var providerPlan = providerPlans.First(); @@ -598,7 +597,7 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); // 95 seats currently assigned with a seat minimum of 100 - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); sutProvider.GetDependency().GetManyDetailsByProviderAsync(provider.Id).Returns( [ @@ -619,7 +618,7 @@ public class ProviderBillingServiceTests await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, 10); // 95 current + 10 seat scale = 105 seats, 5 above the minimum - await sutProvider.GetDependency().Received(1).SubscriptionUpdateAsync( + await sutProvider.GetDependency().Received(1).UpdateSubscriptionAsync( provider.GatewaySubscriptionId, Arg.Is( options => @@ -661,7 +660,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } var providerPlan = providerPlans.First(); @@ -686,7 +685,7 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); // 110 seats currently assigned with a seat minimum of 100 - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); sutProvider.GetDependency().GetManyDetailsByProviderAsync(provider.Id).Returns( [ @@ -707,7 +706,7 @@ public class ProviderBillingServiceTests await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, 10); // 110 current + 10 seat scale up = 120 seats - await sutProvider.GetDependency().Received(1).SubscriptionUpdateAsync( + await sutProvider.GetDependency().Received(1).UpdateSubscriptionAsync( provider.GatewaySubscriptionId, Arg.Is( options => @@ -749,7 +748,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } var providerPlan = providerPlans.First(); @@ -774,7 +773,7 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetSubscriptionOrThrow(provider).Returns(subscription); // 110 seats currently assigned with a seat minimum of 100 - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); sutProvider.GetDependency().GetManyDetailsByProviderAsync(provider.Id).Returns( [ @@ -795,7 +794,7 @@ public class ProviderBillingServiceTests await sutProvider.Sut.ScaleSeats(provider, PlanType.TeamsMonthly, -30); // 110 seats - 30 scale down seats = 80 seats, below the 100 seat minimum. - await sutProvider.GetDependency().Received(1).SubscriptionUpdateAsync( + await sutProvider.GetDependency().Received(1).UpdateSubscriptionAsync( provider.GatewaySubscriptionId, Arg.Is( options => @@ -827,13 +826,13 @@ public class ProviderBillingServiceTests } ]); - sutProvider.GetDependency().GetPlanOrThrow(planType).Returns(StaticStore.GetPlan(planType)); + sutProvider.GetDependency().GetPlanOrThrow(planType).Returns(MockPlans.Get(planType)); sutProvider.GetDependency().GetManyDetailsByProviderAsync(provider.Id).Returns( [ new ProviderOrganizationOrganizationDetails { - Plan = StaticStore.GetPlan(planType).Name, + Plan = MockPlans.Get(planType).Name, Status = OrganizationStatusType.Managed, Seats = 5 } @@ -865,13 +864,13 @@ public class ProviderBillingServiceTests } ]); - sutProvider.GetDependency().GetPlanOrThrow(planType).Returns(StaticStore.GetPlan(planType)); + sutProvider.GetDependency().GetPlanOrThrow(planType).Returns(MockPlans.Get(planType)); sutProvider.GetDependency().GetManyDetailsByProviderAsync(provider.Id).Returns( [ new ProviderOrganizationOrganizationDetails { - Plan = StaticStore.GetPlan(planType).Name, + Plan = MockPlans.Get(planType).Name, Status = OrganizationStatusType.Managed, Seats = 15 } @@ -914,12 +913,12 @@ public class ProviderBillingServiceTests var stripeAdapter = sutProvider.GetDependency(); var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.BankAccount, Token = "token" }; - stripeAdapter.SetupIntentList(Arg.Is(options => + stripeAdapter.ListSetupIntentsAsync(Arg.Is(options => options.PaymentMethod == tokenizedPaymentMethod.Token)).Returns([ new SetupIntent { Id = "setup_intent_id" } ]); - stripeAdapter.CustomerCreateAsync(Arg.Is(o => + stripeAdapter.CreateCustomerAsync(Arg.Is(o => o.Address.Country == billingAddress.Country && o.Address.PostalCode == billingAddress.PostalCode && o.Address.Line1 == billingAddress.Line1 && @@ -942,7 +941,7 @@ public class ProviderBillingServiceTests await sutProvider.GetDependency().Received(1).Set(provider.Id, "setup_intent_id"); - await stripeAdapter.Received(1).SetupIntentCancel("setup_intent_id", Arg.Is(options => + await stripeAdapter.Received(1).CancelSetupIntentAsync("setup_intent_id", Arg.Is(options => options.CancellationReason == "abandoned")); await sutProvider.GetDependency().Received(1).RemoveSetupIntentForSubscriber(provider.Id); @@ -964,7 +963,7 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().CreateBraintreeCustomer(provider, tokenizedPaymentMethod.Token) .Returns("braintree_customer_id"); - stripeAdapter.CustomerCreateAsync(Arg.Is(o => + stripeAdapter.CreateCustomerAsync(Arg.Is(o => o.Address.Country == billingAddress.Country && o.Address.PostalCode == billingAddress.PostalCode && o.Address.Line1 == billingAddress.Line1 && @@ -1007,12 +1006,12 @@ public class ProviderBillingServiceTests var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.BankAccount, Token = "token" }; - stripeAdapter.SetupIntentList(Arg.Is(options => + stripeAdapter.ListSetupIntentsAsync(Arg.Is(options => options.PaymentMethod == tokenizedPaymentMethod.Token)).Returns([ new SetupIntent { Id = "setup_intent_id" } ]); - stripeAdapter.CustomerCreateAsync(Arg.Is(o => + stripeAdapter.CreateCustomerAsync(Arg.Is(o => o.Address.Country == billingAddress.Country && o.Address.PostalCode == billingAddress.PostalCode && o.Address.Line1 == billingAddress.Line1 && @@ -1058,7 +1057,7 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().CreateBraintreeCustomer(provider, tokenizedPaymentMethod.Token) .Returns("braintree_customer_id"); - stripeAdapter.CustomerCreateAsync(Arg.Is(o => + stripeAdapter.CreateCustomerAsync(Arg.Is(o => o.Address.Country == billingAddress.Country && o.Address.PostalCode == billingAddress.PostalCode && o.Address.Line1 == billingAddress.Line1 && @@ -1100,7 +1099,7 @@ public class ProviderBillingServiceTests var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = "token" }; - stripeAdapter.CustomerCreateAsync(Arg.Is(o => + stripeAdapter.CreateCustomerAsync(Arg.Is(o => o.Address.Country == billingAddress.Country && o.Address.PostalCode == billingAddress.PostalCode && o.Address.Line1 == billingAddress.Line1 && @@ -1142,7 +1141,7 @@ public class ProviderBillingServiceTests var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = "token" }; - stripeAdapter.CustomerCreateAsync(Arg.Is(o => + stripeAdapter.CreateCustomerAsync(Arg.Is(o => o.Address.Country == billingAddress.Country && o.Address.PostalCode == billingAddress.PostalCode && o.Address.Line1 == billingAddress.Line1 && @@ -1178,7 +1177,7 @@ public class ProviderBillingServiceTests var stripeAdapter = sutProvider.GetDependency(); var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = "token" }; - stripeAdapter.CustomerCreateAsync(Arg.Any()) + stripeAdapter.CreateCustomerAsync(Arg.Any()) .Throws(new StripeException("Invalid tax ID") { StripeError = new StripeError { Code = "tax_id_invalid" } }); var actual = await Assert.ThrowsAsync(async () => @@ -1216,7 +1215,7 @@ public class ProviderBillingServiceTests await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() - .SubscriptionCreateAsync(Arg.Any()); + .CreateSubscriptionAsync(Arg.Any()); } [Theory, BitAutoData] @@ -1238,13 +1237,13 @@ public class ProviderBillingServiceTests .Returns(providerPlans); sutProvider.GetDependency().GetPlanOrThrow(PlanType.EnterpriseMonthly) - .Returns(StaticStore.GetPlan(PlanType.EnterpriseMonthly)); + .Returns(MockPlans.Get(PlanType.EnterpriseMonthly)); await ThrowsBillingExceptionAsync(() => sutProvider.Sut.SetupSubscription(provider)); await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() - .SubscriptionCreateAsync(Arg.Any()); + .CreateSubscriptionAsync(Arg.Any()); } [Theory, BitAutoData] @@ -1266,13 +1265,13 @@ public class ProviderBillingServiceTests .Returns(providerPlans); sutProvider.GetDependency().GetPlanOrThrow(PlanType.TeamsMonthly) - .Returns(StaticStore.GetPlan(PlanType.TeamsMonthly)); + .Returns(MockPlans.Get(PlanType.TeamsMonthly)); await ThrowsBillingExceptionAsync(() => sutProvider.Sut.SetupSubscription(provider)); await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() - .SubscriptionCreateAsync(Arg.Any()); + .CreateSubscriptionAsync(Arg.Any()); } [Theory, BitAutoData] @@ -1317,13 +1316,13 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } sutProvider.GetDependency().GetByProviderId(provider.Id) .Returns(providerPlans); - sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Any()) + sutProvider.GetDependency().CreateSubscriptionAsync(Arg.Any()) .Returns( new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Incomplete }); @@ -1373,7 +1372,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } sutProvider.GetDependency().GetByProviderId(provider.Id) @@ -1381,7 +1380,7 @@ public class ProviderBillingServiceTests var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; - sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sutProvider.GetDependency().CreateSubscriptionAsync(Arg.Is( sub => sub.AutomaticTax.Enabled == true && sub.CollectionMethod == StripeConstants.CollectionMethod.SendInvoice && @@ -1449,7 +1448,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } sutProvider.GetDependency().GetByProviderId(provider.Id) @@ -1458,7 +1457,7 @@ public class ProviderBillingServiceTests var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; - sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sutProvider.GetDependency().CreateSubscriptionAsync(Arg.Is( sub => sub.AutomaticTax.Enabled == true && sub.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically && @@ -1525,7 +1524,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } sutProvider.GetDependency().GetByProviderId(provider.Id) @@ -1538,7 +1537,7 @@ public class ProviderBillingServiceTests sutProvider.GetDependency().GetSetupIntentIdForSubscriber(provider.Id).Returns(setupIntentId); - sutProvider.GetDependency().SetupIntentGet(setupIntentId, Arg.Is(options => + sutProvider.GetDependency().GetSetupIntentAsync(setupIntentId, Arg.Is(options => options.Expand.Contains("payment_method"))).Returns(new SetupIntent { Id = setupIntentId, @@ -1553,7 +1552,7 @@ public class ProviderBillingServiceTests } }); - sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sutProvider.GetDependency().CreateSubscriptionAsync(Arg.Is( sub => sub.AutomaticTax.Enabled == true && sub.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically && @@ -1626,7 +1625,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } sutProvider.GetDependency().GetByProviderId(provider.Id) @@ -1635,7 +1634,7 @@ public class ProviderBillingServiceTests var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; - sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sutProvider.GetDependency().CreateSubscriptionAsync(Arg.Is( sub => sub.AutomaticTax.Enabled == true && sub.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically && @@ -1704,7 +1703,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } sutProvider.GetDependency().GetByProviderId(provider.Id) @@ -1713,7 +1712,7 @@ public class ProviderBillingServiceTests var expected = new Subscription { Id = "subscription_id", Status = StripeConstants.SubscriptionStatus.Active }; - sutProvider.GetDependency().SubscriptionCreateAsync(Arg.Is( + sutProvider.GetDependency().CreateSubscriptionAsync(Arg.Is( sub => sub.AutomaticTax.Enabled == true && sub.CollectionMethod == StripeConstants.CollectionMethod.ChargeAutomatically && @@ -1772,8 +1771,8 @@ public class ProviderBillingServiceTests const string enterpriseLineItemId = "enterprise_line_item_id"; const string teamsLineItemId = "teams_line_item_id"; - var enterprisePriceId = StaticStore.GetPlan(PlanType.EnterpriseMonthly).PasswordManager.StripeProviderPortalSeatPlanId; - var teamsPriceId = StaticStore.GetPlan(PlanType.TeamsMonthly).PasswordManager.StripeProviderPortalSeatPlanId; + var enterprisePriceId = MockPlans.Get(PlanType.EnterpriseMonthly).PasswordManager.StripeProviderPortalSeatPlanId; + var teamsPriceId = MockPlans.Get(PlanType.TeamsMonthly).PasswordManager.StripeProviderPortalSeatPlanId; var subscription = new Subscription { @@ -1806,7 +1805,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); @@ -1828,7 +1827,7 @@ public class ProviderBillingServiceTests await providerPlanRepository.Received(1).ReplaceAsync(Arg.Is( providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly && providerPlan.SeatMinimum == 20 && providerPlan.PurchasedSeats == 5)); - await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + await stripeAdapter.Received(1).UpdateSubscriptionAsync(provider.GatewaySubscriptionId, Arg.Is( options => options.Items.Count == 2 && @@ -1852,8 +1851,8 @@ public class ProviderBillingServiceTests const string enterpriseLineItemId = "enterprise_line_item_id"; const string teamsLineItemId = "teams_line_item_id"; - var enterprisePriceId = StaticStore.GetPlan(PlanType.EnterpriseMonthly).PasswordManager.StripeProviderPortalSeatPlanId; - var teamsPriceId = StaticStore.GetPlan(PlanType.TeamsMonthly).PasswordManager.StripeProviderPortalSeatPlanId; + var enterprisePriceId = MockPlans.Get(PlanType.EnterpriseMonthly).PasswordManager.StripeProviderPortalSeatPlanId; + var teamsPriceId = MockPlans.Get(PlanType.TeamsMonthly).PasswordManager.StripeProviderPortalSeatPlanId; var subscription = new Subscription { @@ -1886,7 +1885,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); @@ -1908,7 +1907,7 @@ public class ProviderBillingServiceTests await providerPlanRepository.Received(1).ReplaceAsync(Arg.Is( providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly && providerPlan.SeatMinimum == 50)); - await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + await stripeAdapter.Received(1).UpdateSubscriptionAsync(provider.GatewaySubscriptionId, Arg.Is( options => options.Items.Count == 2 && @@ -1932,8 +1931,8 @@ public class ProviderBillingServiceTests const string enterpriseLineItemId = "enterprise_line_item_id"; const string teamsLineItemId = "teams_line_item_id"; - var enterprisePriceId = StaticStore.GetPlan(PlanType.EnterpriseMonthly).PasswordManager.StripeProviderPortalSeatPlanId; - var teamsPriceId = StaticStore.GetPlan(PlanType.TeamsMonthly).PasswordManager.StripeProviderPortalSeatPlanId; + var enterprisePriceId = MockPlans.Get(PlanType.EnterpriseMonthly).PasswordManager.StripeProviderPortalSeatPlanId; + var teamsPriceId = MockPlans.Get(PlanType.TeamsMonthly).PasswordManager.StripeProviderPortalSeatPlanId; var subscription = new Subscription { @@ -1966,7 +1965,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); @@ -1989,7 +1988,7 @@ public class ProviderBillingServiceTests providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly && providerPlan.SeatMinimum == 60 && providerPlan.PurchasedSeats == 10)); await stripeAdapter.DidNotReceiveWithAnyArgs() - .SubscriptionUpdateAsync(Arg.Any(), Arg.Any()); + .UpdateSubscriptionAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] @@ -2006,8 +2005,8 @@ public class ProviderBillingServiceTests const string enterpriseLineItemId = "enterprise_line_item_id"; const string teamsLineItemId = "teams_line_item_id"; - var enterprisePriceId = StaticStore.GetPlan(PlanType.EnterpriseMonthly).PasswordManager.StripeProviderPortalSeatPlanId; - var teamsPriceId = StaticStore.GetPlan(PlanType.TeamsMonthly).PasswordManager.StripeProviderPortalSeatPlanId; + var enterprisePriceId = MockPlans.Get(PlanType.EnterpriseMonthly).PasswordManager.StripeProviderPortalSeatPlanId; + var teamsPriceId = MockPlans.Get(PlanType.TeamsMonthly).PasswordManager.StripeProviderPortalSeatPlanId; var subscription = new Subscription { @@ -2040,7 +2039,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); @@ -2062,7 +2061,7 @@ public class ProviderBillingServiceTests await providerPlanRepository.Received(1).ReplaceAsync(Arg.Is( providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly && providerPlan.SeatMinimum == 80 && providerPlan.PurchasedSeats == 0)); - await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + await stripeAdapter.Received(1).UpdateSubscriptionAsync(provider.GatewaySubscriptionId, Arg.Is( options => options.Items.Count == 2 && @@ -2086,8 +2085,8 @@ public class ProviderBillingServiceTests const string enterpriseLineItemId = "enterprise_line_item_id"; const string teamsLineItemId = "teams_line_item_id"; - var enterprisePriceId = StaticStore.GetPlan(PlanType.EnterpriseMonthly).PasswordManager.StripeProviderPortalSeatPlanId; - var teamsPriceId = StaticStore.GetPlan(PlanType.TeamsMonthly).PasswordManager.StripeProviderPortalSeatPlanId; + var enterprisePriceId = MockPlans.Get(PlanType.EnterpriseMonthly).PasswordManager.StripeProviderPortalSeatPlanId; + var teamsPriceId = MockPlans.Get(PlanType.TeamsMonthly).PasswordManager.StripeProviderPortalSeatPlanId; var subscription = new Subscription { @@ -2120,7 +2119,7 @@ public class ProviderBillingServiceTests foreach (var plan in providerPlans) { sutProvider.GetDependency().GetPlanOrThrow(plan.PlanType) - .Returns(StaticStore.GetPlan(plan.PlanType)); + .Returns(MockPlans.Get(plan.PlanType)); } providerPlanRepository.GetByProviderId(provider.Id).Returns(providerPlans); @@ -2142,7 +2141,7 @@ public class ProviderBillingServiceTests await providerPlanRepository.DidNotReceive().ReplaceAsync(Arg.Is( providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly)); - await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + await stripeAdapter.Received(1).UpdateSubscriptionAsync(provider.GatewaySubscriptionId, Arg.Is( options => options.Items.Count == 1 && @@ -2151,4 +2150,151 @@ public class ProviderBillingServiceTests } #endregion + + #region UpdateProviderNameAndEmail + + [Theory, BitAutoData] + public async Task UpdateProviderNameAndEmail_NullGatewayCustomerId_LogsWarningAndReturns( + Provider provider, + SutProvider sutProvider) + { + // Arrange + provider.GatewayCustomerId = null; + var stripeAdapter = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.UpdateProviderNameAndEmail(provider); + + // Assert + await stripeAdapter.DidNotReceive().UpdateCustomerAsync( + Arg.Any(), + Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateProviderNameAndEmail_EmptyGatewayCustomerId_LogsWarningAndReturns( + Provider provider, + SutProvider sutProvider) + { + // Arrange + provider.GatewayCustomerId = ""; + var stripeAdapter = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.UpdateProviderNameAndEmail(provider); + + // Assert + await stripeAdapter.DidNotReceive().UpdateCustomerAsync( + Arg.Any(), + Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateProviderNameAndEmail_NullProviderName_LogsWarningAndReturns( + Provider provider, + SutProvider sutProvider) + { + // Arrange + provider.Name = null; + provider.GatewayCustomerId = "cus_test123"; + var stripeAdapter = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.UpdateProviderNameAndEmail(provider); + + // Assert + await stripeAdapter.DidNotReceive().UpdateCustomerAsync( + Arg.Any(), + Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateProviderNameAndEmail_EmptyProviderName_LogsWarningAndReturns( + Provider provider, + SutProvider sutProvider) + { + // Arrange + provider.Name = ""; + provider.GatewayCustomerId = "cus_test123"; + var stripeAdapter = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.UpdateProviderNameAndEmail(provider); + + // Assert + await stripeAdapter.DidNotReceive().UpdateCustomerAsync( + Arg.Any(), + Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateProviderNameAndEmail_ValidProvider_CallsStripeWithCorrectParameters( + Provider provider, + SutProvider sutProvider) + { + // Arrange + provider.Name = "Test Provider"; + provider.BillingEmail = "billing@test.com"; + provider.GatewayCustomerId = "cus_test123"; + var stripeAdapter = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.UpdateProviderNameAndEmail(provider); + + // Assert + await stripeAdapter.Received(1).UpdateCustomerAsync( + provider.GatewayCustomerId, + Arg.Is(options => + options.Email == provider.BillingEmail && + options.Description == provider.Name && + options.InvoiceSettings.CustomFields.Count == 1 && + options.InvoiceSettings.CustomFields[0].Name == "Provider" && + options.InvoiceSettings.CustomFields[0].Value == provider.Name)); + } + + [Theory, BitAutoData] + public async Task UpdateProviderNameAndEmail_LongProviderName_UsesFullName( + Provider provider, + SutProvider sutProvider) + { + // Arrange + var longName = new string('A', 50); // 50 characters + provider.Name = longName; + provider.BillingEmail = "billing@test.com"; + provider.GatewayCustomerId = "cus_test123"; + var stripeAdapter = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.UpdateProviderNameAndEmail(provider); + + // Assert + await stripeAdapter.Received(1).UpdateCustomerAsync( + provider.GatewayCustomerId, + Arg.Is(options => + options.InvoiceSettings.CustomFields[0].Value == longName)); + } + + [Theory, BitAutoData] + public async Task UpdateProviderNameAndEmail_NullBillingEmail_UpdatesWithNull( + Provider provider, + SutProvider sutProvider) + { + // Arrange + provider.Name = "Test Provider"; + provider.BillingEmail = null; + provider.GatewayCustomerId = "cus_test123"; + var stripeAdapter = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.UpdateProviderNameAndEmail(provider); + + // Assert + await stripeAdapter.Received(1).UpdateCustomerAsync( + provider.GatewayCustomerId, + Arg.Is(options => + options.Email == null && + options.Description == provider.Name)); + } + + #endregion } diff --git a/bitwarden_license/test/Commercial.Core.Test/SecretsManager/Queries/Projects/MaxProjectsQueryTests.cs b/bitwarden_license/test/Commercial.Core.Test/SecretsManager/Queries/Projects/MaxProjectsQueryTests.cs index 16ae8f7f2c..776403fdd5 100644 --- a/bitwarden_license/test/Commercial.Core.Test/SecretsManager/Queries/Projects/MaxProjectsQueryTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/SecretsManager/Queries/Projects/MaxProjectsQueryTests.cs @@ -6,7 +6,7 @@ using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.SecretsManager.Repositories; using Bit.Core.Settings; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -69,7 +69,7 @@ public class MaxProjectsQueryTests sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); sutProvider.GetDependency().GetPlan(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); var (limit, overLimit) = await sutProvider.Sut.GetByOrgIdAsync(organization.Id, 1); @@ -114,7 +114,7 @@ public class MaxProjectsQueryTests .Returns(projects); sutProvider.GetDependency().GetPlan(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); var (max, overMax) = await sutProvider.Sut.GetByOrgIdAsync(organization.Id, projectsToAdd); diff --git a/bitwarden_license/test/Commercial.Core.Test/SecretsManager/Repositories/SecretVersionRepositoryTests.cs b/bitwarden_license/test/Commercial.Core.Test/SecretsManager/Repositories/SecretVersionRepositoryTests.cs new file mode 100644 index 0000000000..659a6d1233 --- /dev/null +++ b/bitwarden_license/test/Commercial.Core.Test/SecretsManager/Repositories/SecretVersionRepositoryTests.cs @@ -0,0 +1,130 @@ +using Bit.Core.SecretsManager.Entities; +using Bit.Test.Common.AutoFixture.Attributes; +using Xunit; + +namespace Bit.Commercial.Core.Test.SecretsManager.Repositories; + +public class SecretVersionRepositoryTests +{ + [Theory] + [BitAutoData] + public void SecretVersion_EntityCreation_Success(SecretVersion secretVersion) + { + // Arrange & Act + secretVersion.SetNewId(); + + // Assert + Assert.NotEqual(Guid.Empty, secretVersion.Id); + Assert.NotEqual(Guid.Empty, secretVersion.SecretId); + Assert.NotNull(secretVersion.Value); + Assert.NotEqual(default, secretVersion.VersionDate); + } + + [Theory] + [BitAutoData] + public void SecretVersion_WithServiceAccountEditor_Success(SecretVersion secretVersion, Guid serviceAccountId) + { + // Arrange & Act + secretVersion.EditorServiceAccountId = serviceAccountId; + secretVersion.EditorOrganizationUserId = null; + + // Assert + Assert.Equal(serviceAccountId, secretVersion.EditorServiceAccountId); + Assert.Null(secretVersion.EditorOrganizationUserId); + } + + [Theory] + [BitAutoData] + public void SecretVersion_WithOrganizationUserEditor_Success(SecretVersion secretVersion, Guid organizationUserId) + { + // Arrange & Act + secretVersion.EditorOrganizationUserId = organizationUserId; + secretVersion.EditorServiceAccountId = null; + + // Assert + Assert.Equal(organizationUserId, secretVersion.EditorOrganizationUserId); + Assert.Null(secretVersion.EditorServiceAccountId); + } + + [Theory] + [BitAutoData] + public void SecretVersion_NullableEditors_Success(SecretVersion secretVersion) + { + // Arrange & Act + secretVersion.EditorServiceAccountId = null; + secretVersion.EditorOrganizationUserId = null; + + // Assert + Assert.Null(secretVersion.EditorServiceAccountId); + Assert.Null(secretVersion.EditorOrganizationUserId); + } + + [Theory] + [BitAutoData] + public void SecretVersion_VersionDateSet_Success(SecretVersion secretVersion) + { + // Arrange + var versionDate = DateTime.UtcNow; + + // Act + secretVersion.VersionDate = versionDate; + + // Assert + Assert.Equal(versionDate, secretVersion.VersionDate); + } + + [Theory] + [BitAutoData] + public void SecretVersion_ValueEncrypted_Success(SecretVersion secretVersion, string encryptedValue) + { + // Arrange & Act + secretVersion.Value = encryptedValue; + + // Assert + Assert.Equal(encryptedValue, secretVersion.Value); + Assert.NotEmpty(secretVersion.Value); + } + + [Theory] + [BitAutoData] + public void SecretVersion_MultipleVersions_DifferentIds(List secretVersions, Guid secretId) + { + // Arrange & Act + foreach (var version in secretVersions) + { + version.SecretId = secretId; + version.SetNewId(); + } + + // Assert + var distinctIds = secretVersions.Select(v => v.Id).Distinct(); + Assert.Equal(secretVersions.Count, distinctIds.Count()); + Assert.All(secretVersions, v => Assert.Equal(secretId, v.SecretId)); + } + + [Theory] + [BitAutoData] + public void SecretVersion_VersionDateOrdering_Success(SecretVersion version1, SecretVersion version2, SecretVersion version3, Guid secretId) + { + // Arrange + var now = DateTime.UtcNow; + version1.SecretId = secretId; + version1.VersionDate = now.AddDays(-2); + + version2.SecretId = secretId; + version2.VersionDate = now.AddDays(-1); + + version3.SecretId = secretId; + version3.VersionDate = now; + + var versions = new List { version2, version3, version1 }; + + // Act + var orderedVersions = versions.OrderByDescending(v => v.VersionDate).ToList(); + + // Assert + Assert.Equal(version3.Id, orderedVersions[0].Id); // Most recent + Assert.Equal(version2.Id, orderedVersions[1].Id); + Assert.Equal(version1.Id, orderedVersions[2].Id); // Oldest + } +} diff --git a/bitwarden_license/test/SSO.Test/Controllers/AccountControllerTest.cs b/bitwarden_license/test/SSO.Test/Controllers/AccountControllerTest.cs new file mode 100644 index 0000000000..b276174814 --- /dev/null +++ b/bitwarden_license/test/SSO.Test/Controllers/AccountControllerTest.cs @@ -0,0 +1,1267 @@ +using System.Reflection; +using System.Security.Claims; +using Bit.Core; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Auth.Entities; +using Bit.Core.Auth.Models.Business.Tokenables; +using Bit.Core.Auth.Models.Data; +using Bit.Core.Auth.Repositories; +using Bit.Core.Auth.UserFeatures.Registration; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Tokens; +using Bit.Sso.Controllers; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Duende.IdentityModel; +using Duende.IdentityServer.Configuration; +using Duende.IdentityServer.Models; +using Duende.IdentityServer.Services; +using Microsoft.AspNetCore.Authentication; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Identity; +using Microsoft.AspNetCore.Mvc; +using Microsoft.Extensions.DependencyInjection; +using NSubstitute; +using Xunit.Abstractions; +using AuthenticationOptions = Duende.IdentityServer.Configuration.AuthenticationOptions; + +namespace Bit.SSO.Test.Controllers; + +[ControllerCustomize(typeof(AccountController)), SutProviderCustomize] +public class AccountControllerTest +{ + private readonly ITestOutputHelper _output; + + public AccountControllerTest(ITestOutputHelper output) + { + _output = output; + } + + private static IAuthenticationService SetupHttpContextWithAuth( + SutProvider sutProvider, + AuthenticateResult authResult, + IAuthenticationService? authService = null) + { + var schemeProvider = Substitute.For(); + schemeProvider.GetDefaultAuthenticateSchemeAsync() + .Returns(new AuthenticationScheme("idsrv", "idsrv", typeof(IAuthenticationHandler))); + + var resolvedAuthService = authService ?? Substitute.For(); + resolvedAuthService.AuthenticateAsync( + Arg.Any(), + AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme) + .Returns(authResult); + + var services = new ServiceCollection(); + services.AddSingleton(resolvedAuthService); + services.AddSingleton(schemeProvider); + services.AddSingleton(new IdentityServerOptions + { + Authentication = new AuthenticationOptions + { + CookieAuthenticationScheme = "idsrv" + } + }); + var sp = services.BuildServiceProvider(); + + sutProvider.Sut.ControllerContext = new ControllerContext + { + HttpContext = new DefaultHttpContext + { + RequestServices = sp + } + }; + + return resolvedAuthService; + } + + private static AuthenticateResult BuildSuccessfulExternalAuth(Guid orgId, string providerUserId, string email) + { + var claims = new[] + { + new Claim(JwtClaimTypes.Subject, providerUserId), + new Claim(JwtClaimTypes.Email, email) + }; + var principal = new ClaimsPrincipal(new ClaimsIdentity(claims, "External")); + var properties = new AuthenticationProperties + { + Items = + { + ["scheme"] = orgId.ToString(), + ["return_url"] = "~/", + ["state"] = "state", + ["user_identifier"] = string.Empty + } + }; + var ticket = new AuthenticationTicket(principal, properties, AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); + return AuthenticateResult.Success(ticket); + } + + private static void ConfigureSsoAndUser( + SutProvider sutProvider, + Guid orgId, + string providerUserId, + User user, + Organization? organization = null, + OrganizationUser? orgUser = null) + { + var ssoConfigRepository = sutProvider.GetDependency(); + var userRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + + var ssoConfig = new SsoConfig { OrganizationId = orgId, Enabled = true }; + var ssoData = new SsoConfigurationData(); + ssoConfig.SetData(ssoData); + ssoConfigRepository.GetByOrganizationIdAsync(orgId).Returns(ssoConfig); + + userRepository.GetBySsoUserAsync(providerUserId, orgId).Returns(user); + + if (organization != null) + { + organizationRepository.GetByIdAsync(orgId).Returns(organization); + } + if (organization != null && orgUser != null) + { + organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id).Returns(orgUser); + organizationUserRepository.GetManyByUserAsync(user.Id).Returns([orgUser]); + } + } + + private enum MeasurementScenario + { + ExistingSsoLinkedAccepted, + ExistingUserNoOrgUser, + JitProvision + } + + private sealed class LookupCounts + { + public int UserGetBySso { get; init; } + public int UserGetByEmail { get; init; } + public int OrgGetById { get; init; } + public int OrgUserGetByOrg { get; init; } + public int OrgUserGetByEmail { get; init; } + } + + private async Task MeasureCountsForScenarioAsync( + SutProvider sutProvider, + MeasurementScenario scenario, + bool preventNonCompliant) + { + var orgId = Guid.NewGuid(); + var providerUserId = $"meas-{scenario}-{(preventNonCompliant ? "on" : "off")}"; + var email = scenario == MeasurementScenario.JitProvision + ? "jit.compare@example.com" + : "existing.compare@example.com"; + + var organization = new Organization { Id = orgId, Name = "Org" }; + var user = new User { Id = Guid.NewGuid(), Email = email }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, email); + SetupHttpContextWithAuth(sutProvider, authResult); + + // SSO config present + var ssoConfigRepository = sutProvider.GetDependency(); + var userRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + var featureService = sutProvider.GetDependency(); + var interactionService = sutProvider.GetDependency(); + + var ssoConfig = new SsoConfig { OrganizationId = orgId, Enabled = true }; + var ssoData = new SsoConfigurationData(); + ssoConfig.SetData(ssoData); + ssoConfigRepository.GetByOrganizationIdAsync(orgId).Returns(ssoConfig); + + switch (scenario) + { + case MeasurementScenario.ExistingSsoLinkedAccepted: + userRepository.GetBySsoUserAsync(providerUserId, orgId).Returns(user); + organizationRepository.GetByIdAsync(orgId).Returns(organization); + organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id) + .Returns(new OrganizationUser + { + OrganizationId = orgId, + UserId = user.Id, + Status = OrganizationUserStatusType.Accepted, + Type = OrganizationUserType.User + }); + break; + case MeasurementScenario.ExistingUserNoOrgUser: + userRepository.GetBySsoUserAsync(providerUserId, orgId).Returns(user); + organizationRepository.GetByIdAsync(orgId).Returns(organization); + organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id) + .Returns((OrganizationUser?)null); + break; + case MeasurementScenario.JitProvision: + userRepository.GetBySsoUserAsync(providerUserId, orgId).Returns((User?)null); + userRepository.GetByEmailAsync(email).Returns((User?)null); + organizationRepository.GetByIdAsync(orgId).Returns(organization); + organizationUserRepository.GetByOrganizationEmailAsync(orgId, email) + .Returns((OrganizationUser?)null); + break; + } + + featureService.IsEnabled(Arg.Any()).Returns(preventNonCompliant); + interactionService.GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + try + { + _ = await sutProvider.Sut.ExternalCallback(); + } + catch + { + // Ignore exceptions for measurement; some flows can throw based on status enforcement + } + + var counts = new LookupCounts + { + UserGetBySso = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetBySsoUserAsync)), + UserGetByEmail = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetByEmailAsync)), + OrgGetById = organizationRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationRepository.GetByIdAsync)), + OrgUserGetByOrg = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationAsync)), + OrgUserGetByEmail = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationEmailAsync)), + }; + + userRepository.ClearReceivedCalls(); + organizationRepository.ClearReceivedCalls(); + organizationUserRepository.ClearReceivedCalls(); + + return counts; + } + + [Theory, BitAutoData] + public async Task ExternalCallback_PreventNonCompliantTrue_ExistingUser_NoOrgUser_ThrowsCouldNotFindOrganizationUser( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-missing-orguser"; + var user = new User { Id = Guid.NewGuid(), Email = "missing.orguser@example.com" }; + var organization = new Organization { Id = orgId, Name = "Org" }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, user.Email!); + SetupHttpContextWithAuth(sutProvider, authResult); + + // i18n returns the key so we can assert on message contents + sutProvider.GetDependency() + .T(Arg.Any(), Arg.Any()) + .Returns(ci => (string)ci[0]!); + + // SSO config + user link exists, but no org user membership + ConfigureSsoAndUser( + sutProvider, + orgId, + providerUserId, + user, + organization, + orgUser: null); + + sutProvider.GetDependency() + .GetByOrganizationAsync(organization.Id, user.Id).Returns((OrganizationUser?)null); + + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(true); + sutProvider.GetDependency() + .GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + // Act + Assert + var ex = await Assert.ThrowsAsync(() => sutProvider.Sut.ExternalCallback()); + Assert.Equal("CouldNotFindOrganizationUser", ex.Message); + } + + [Theory, BitAutoData] + public async Task ExternalCallback_PreventNonCompliantTrue_ExistingUser_OrgUserInvited_AllowsLogin( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-invited-orguser"; + var user = new User { Id = Guid.NewGuid(), Email = "invited.orguser@example.com" }; + var organization = new Organization { Id = orgId, Name = "Org" }; + var orgUser = new OrganizationUser + { + OrganizationId = orgId, + UserId = user.Id, + Status = OrganizationUserStatusType.Invited, + Type = OrganizationUserType.User + }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, user.Email!); + var authService = SetupHttpContextWithAuth(sutProvider, authResult); + + sutProvider.GetDependency() + .T(Arg.Any(), Arg.Any()) + .Returns(ci => (string)ci[0]!); + + ConfigureSsoAndUser( + sutProvider, + orgId, + providerUserId, + user, + organization, + orgUser); + + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(true); + sutProvider.GetDependency() + .GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + // Act + var result = await sutProvider.Sut.ExternalCallback(); + + // Assert + var redirect = Assert.IsType(result); + Assert.Equal("~/", redirect.Url); + + await authService.Received().SignInAsync( + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()); + + await authService.Received().SignOutAsync( + Arg.Any(), + AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme, + Arg.Any()); + } + + [Theory, BitAutoData] + public async Task ExternalCallback_PreventNonCompliantTrue_ExistingUser_OrgUserRevoked_ThrowsAccessRevoked( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-revoked-orguser"; + var user = new User { Id = Guid.NewGuid(), Email = "revoked.orguser@example.com" }; + var organization = new Organization { Id = orgId, Name = "Org" }; + var orgUser = new OrganizationUser + { + OrganizationId = orgId, + UserId = user.Id, + Status = OrganizationUserStatusType.Revoked, + Type = OrganizationUserType.User + }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, user.Email!); + SetupHttpContextWithAuth(sutProvider, authResult); + + sutProvider.GetDependency() + .T(Arg.Any(), Arg.Any()) + .Returns(ci => (string)ci[0]!); + + ConfigureSsoAndUser( + sutProvider, + orgId, + providerUserId, + user, + organization, + orgUser); + + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(true); + sutProvider.GetDependency() + .GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + // Act + Assert + var ex = await Assert.ThrowsAsync(() => sutProvider.Sut.ExternalCallback()); + Assert.Equal("OrganizationUserAccessRevoked", ex.Message); + } + + [Theory, BitAutoData] + public async Task ExternalCallback_PreventNonCompliantTrue_ExistingUser_OrgUserUnknown_ThrowsUnknown( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-unknown-orguser"; + var user = new User { Id = Guid.NewGuid(), Email = "unknown.orguser@example.com" }; + var organization = new Organization { Id = orgId, Name = "Org" }; + var unknownStatus = (OrganizationUserStatusType)999; + var orgUser = new OrganizationUser + { + OrganizationId = orgId, + UserId = user.Id, + Status = unknownStatus, + Type = OrganizationUserType.User + }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, user.Email!); + SetupHttpContextWithAuth(sutProvider, authResult); + + sutProvider.GetDependency() + .T(Arg.Any(), Arg.Any()) + .Returns(ci => (string)ci[0]!); + + ConfigureSsoAndUser( + sutProvider, + orgId, + providerUserId, + user, + organization, + orgUser); + + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(true); + sutProvider.GetDependency() + .GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + // Act + Assert + var ex = await Assert.ThrowsAsync(() => sutProvider.Sut.ExternalCallback()); + Assert.Equal("OrganizationUserUnknownStatus", ex.Message); + } + + [Theory, BitAutoData] + public async Task ExternalCallback_WithExistingUserAndAcceptedMembership_RedirectsToReturnUrl( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-123"; + var user = new User { Id = Guid.NewGuid(), Email = "user@example.com" }; + var organization = new Organization { Id = orgId, Name = "Test Org" }; + var orgUser = new OrganizationUser + { + OrganizationId = orgId, + UserId = user.Id, + Status = OrganizationUserStatusType.Accepted, + Type = OrganizationUserType.User + }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, user.Email!); + var authService = SetupHttpContextWithAuth(sutProvider, authResult); + + ConfigureSsoAndUser( + sutProvider, + orgId, + providerUserId, + user, + organization, + orgUser); + + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(true); + sutProvider.GetDependency() + .GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + // Act + var result = await sutProvider.Sut.ExternalCallback(); + + // Assert + var redirect = Assert.IsType(result); + Assert.Equal("~/", redirect.Url); + + await authService.Received().SignInAsync( + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()); + + await authService.Received().SignOutAsync( + Arg.Any(), + AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme, + Arg.Any()); + } + + /// + /// PM-24579: Temporary test, remove with feature flag. + /// + [Theory, BitAutoData] + public async Task ExternalCallback_PreventNonCompliantFalse_SkipsOrgLookupAndSignsIn( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-flag-off"; + var user = new User { Id = Guid.NewGuid(), Email = "flagoff@example.com" }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, user.Email!); + var authService = SetupHttpContextWithAuth(sutProvider, authResult); + + ConfigureSsoAndUser( + sutProvider, + orgId, + providerUserId, + user); + + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(false); + sutProvider.GetDependency() + .GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + // Act + var result = await sutProvider.Sut.ExternalCallback(); + + // Assert + var redirect = Assert.IsType(result); + Assert.Equal("~/", redirect.Url); + + await authService.Received().SignInAsync( + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()); + + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .GetByOrganizationAsync(Guid.Empty, Guid.Empty); + } + + /// + /// PM-24579: Permanent test, remove the True in PreventNonCompliantTrue and remove the configure for the feature + /// flag. + /// + [Theory, BitAutoData] + public async Task ExternalCallback_PreventNonCompliantTrue_ExistingSsoLinkedAccepted_MeasureLookups( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-measure-existing"; + var user = new User { Id = Guid.NewGuid(), Email = "existing@example.com" }; + var organization = new Organization { Id = orgId, Name = "Org" }; + var orgUser = new OrganizationUser + { + OrganizationId = orgId, + UserId = user.Id, + Status = OrganizationUserStatusType.Accepted, + Type = OrganizationUserType.User + }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, user.Email); + SetupHttpContextWithAuth(sutProvider, authResult); + + ConfigureSsoAndUser( + sutProvider, + orgId, + providerUserId, + user, + organization, + orgUser); + + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(true); + sutProvider.GetDependency() + .GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + // Act + try + { + _ = await sutProvider.Sut.ExternalCallback(); + } + catch + { + // ignore for measurement only + } + + // Assert (measurement only - no asserts on counts) + var userRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + + var userGetBySso = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetBySsoUserAsync)); + var userGetByEmail = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetByEmailAsync)); + var orgGet = organizationRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationRepository.GetByIdAsync)); + var orgUserGetByOrg = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationAsync)) + + organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetManyByUserAsync)); + var orgUserGetByEmail = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationEmailAsync)); + + _output.WriteLine($"GetBySsoUserAsync: {userGetBySso}"); + _output.WriteLine($"GetByEmailAsync: {userGetByEmail}"); + _output.WriteLine($"GetByIdAsync (Org): {orgGet}"); + _output.WriteLine($"GetByOrganizationAsync (OrgUser): {orgUserGetByOrg}"); + _output.WriteLine($"GetByOrganizationEmailAsync (OrgUser): {orgUserGetByEmail}"); + + // Snapshot assertions + Assert.Equal(1, userGetBySso); + Assert.Equal(0, userGetByEmail); + Assert.Equal(1, orgGet); + Assert.Equal(1, orgUserGetByOrg); + Assert.Equal(0, orgUserGetByEmail); + } + + /// + /// PM-24579: Permanent test, remove the True in PreventNonCompliantTrue and remove the configure for the feature + /// flag. + /// + [Theory, BitAutoData] + public async Task ExternalCallback_PreventNonCompliantTrue_JitProvision_MeasureLookups( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-measure-jit"; + var email = "jit.measure@example.com"; + var organization = new Organization { Id = orgId, Name = "Org", Seats = null }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, email); + SetupHttpContextWithAuth(sutProvider, authResult); + + var ssoConfigRepository = sutProvider.GetDependency(); + var userRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + + var ssoConfig = new SsoConfig { OrganizationId = orgId, Enabled = true }; + var ssoData = new SsoConfigurationData(); + ssoConfig.SetData(ssoData); + ssoConfigRepository.GetByOrganizationIdAsync(orgId).Returns(ssoConfig); + + // JIT (no existing user or sso link) + userRepository.GetBySsoUserAsync(providerUserId, orgId).Returns((User?)null); + userRepository.GetByEmailAsync(email).Returns((User?)null); + organizationRepository.GetByIdAsync(orgId).Returns(organization); + organizationUserRepository.GetByOrganizationEmailAsync(orgId, email).Returns((OrganizationUser?)null); + + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(true); + sutProvider.GetDependency() + .GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + // Act + try + { + _ = await sutProvider.Sut.ExternalCallback(); + } + catch + { + // JIT path may throw due to Invited status under enforcement; ignore for measurement + } + + // Assert (measurement only - no asserts on counts) + var userGetBySso = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetBySsoUserAsync)); + var userGetByEmail = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetByEmailAsync)); + var orgGet = organizationRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationRepository.GetByIdAsync)); + var orgUserGetByOrg = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationAsync)); + var orgUserGetByEmail = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationEmailAsync)); + + _output.WriteLine($"GetBySsoUserAsync: {userGetBySso}"); + _output.WriteLine($"GetByEmailAsync: {userGetByEmail}"); + _output.WriteLine($"GetByIdAsync (Org): {orgGet}"); + _output.WriteLine($"GetByOrganizationAsync (OrgUser): {orgUserGetByOrg}"); + _output.WriteLine($"GetByOrganizationEmailAsync (OrgUser): {orgUserGetByEmail}"); + + // Snapshot assertions + Assert.Equal(1, userGetBySso); + Assert.Equal(1, userGetByEmail); + Assert.Equal(1, orgGet); + Assert.Equal(0, orgUserGetByOrg); + Assert.Equal(1, orgUserGetByEmail); + } + + /// + /// PM-24579: Permanent test, remove the True in PreventNonCompliantTrue and remove the configure for the feature + /// flag. + /// + /// This test will trigger both the GetByOrganizationAsync and the fallback attempt to get by email + /// GetByOrganizationEmailAsync. + /// + [Theory, BitAutoData] + public async Task ExternalCallback_PreventNonCompliantTrue_ExistingUser_NoOrgUser_MeasureLookups( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-measure-existing-no-orguser"; + var user = new User { Id = Guid.NewGuid(), Email = "existing2@example.com" }; + var organization = new Organization { Id = orgId, Name = "Org" }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, user.Email!); + SetupHttpContextWithAuth(sutProvider, authResult); + + ConfigureSsoAndUser( + sutProvider, + orgId, + providerUserId, + user, + organization, + orgUser: null); + + // Ensure orgUser lookup returns null + sutProvider.GetDependency() + .GetByOrganizationAsync(organization.Id, user.Id).Returns((OrganizationUser?)null); + + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(true); + sutProvider.GetDependency() + .GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + // Act + try + { + _ = await sutProvider.Sut.ExternalCallback(); + } + catch + { + // ignore for measurement only + } + + // Assert (measurement only - no asserts on counts) + var userRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + + var userGetBySso = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetBySsoUserAsync)); + var userGetByEmail = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetByEmailAsync)); + var orgGet = organizationRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationRepository.GetByIdAsync)); + var orgUserGetByOrg = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationAsync)) + + organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetManyByUserAsync)); + var orgUserGetByEmail = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationEmailAsync)); + + _output.WriteLine($"GetBySsoUserAsync: {userGetBySso}"); + _output.WriteLine($"GetByEmailAsync: {userGetByEmail}"); + _output.WriteLine($"GetByIdAsync (Org): {orgGet}"); + _output.WriteLine($"GetByOrganizationAsync (OrgUser): {orgUserGetByOrg}"); + _output.WriteLine($"GetByOrganizationEmailAsync (OrgUser): {orgUserGetByEmail}"); + + // Snapshot assertions + Assert.Equal(1, userGetBySso); + Assert.Equal(0, userGetByEmail); + Assert.Equal(1, orgGet); + Assert.Equal(1, orgUserGetByOrg); + Assert.Equal(1, orgUserGetByEmail); + } + + /// + /// PM-24579: Temporary test, remove with feature flag. + /// + [Theory, BitAutoData] + public async Task ExternalCallback_PreventNonCompliantFalse_ExistingSsoLinkedAccepted_MeasureLookups( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-measure-existing-flagoff"; + var user = new User { Id = Guid.NewGuid(), Email = "existing.flagoff@example.com" }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, user.Email!); + SetupHttpContextWithAuth(sutProvider, authResult); + + var ssoConfig = new SsoConfig { OrganizationId = orgId, Enabled = true }; + var ssoData = new SsoConfigurationData(); + ssoConfig.SetData(ssoData); + sutProvider.GetDependency().GetByOrganizationIdAsync(orgId).Returns(ssoConfig); + sutProvider.GetDependency().GetBySsoUserAsync(providerUserId, orgId).Returns(user); + + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(false); + sutProvider.GetDependency() + .GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + // Act + try { _ = await sutProvider.Sut.ExternalCallback(); } catch { } + + // Assert (measurement) + var userRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + + var userGetBySso = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetBySsoUserAsync)); + var userGetByEmail = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetByEmailAsync)); + var orgGet = organizationRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationRepository.GetByIdAsync)); + var orgUserGetByOrg = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationAsync)); + var orgUserGetByEmail = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationEmailAsync)); + + _output.WriteLine($"[flag off] GetBySsoUserAsync: {userGetBySso}"); + _output.WriteLine($"[flag off] GetByEmailAsync: {userGetByEmail}"); + _output.WriteLine($"[flag off] GetByIdAsync (Org): {orgGet}"); + _output.WriteLine($"[flag off] GetByOrganizationAsync (OrgUser): {orgUserGetByOrg}"); + _output.WriteLine($"[flag off] GetByOrganizationEmailAsync (OrgUser): {orgUserGetByEmail}"); + } + + /// + /// PM-24579: Temporary test, remove with feature flag. + /// + [Theory, BitAutoData] + public async Task ExternalCallback_PreventNonCompliantFalse_ExistingUser_NoOrgUser_MeasureLookups( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-measure-existing-no-orguser-flagoff"; + var user = new User { Id = Guid.NewGuid(), Email = "existing2.flagoff@example.com" }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, user.Email!); + SetupHttpContextWithAuth(sutProvider, authResult); + + var ssoConfig = new SsoConfig { OrganizationId = orgId, Enabled = true }; + var ssoData = new SsoConfigurationData(); + ssoConfig.SetData(ssoData); + sutProvider.GetDependency().GetByOrganizationIdAsync(orgId).Returns(ssoConfig); + sutProvider.GetDependency().GetBySsoUserAsync(providerUserId, orgId).Returns(user); + + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(false); + sutProvider.GetDependency() + .GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + // Act + try { _ = await sutProvider.Sut.ExternalCallback(); } catch { } + + // Assert (measurement) + var userRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + + var userGetBySso = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetBySsoUserAsync)); + var userGetByEmail = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetByEmailAsync)); + var orgGet = organizationRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationRepository.GetByIdAsync)); + var orgUserGetByOrg = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationAsync)); + var orgUserGetByEmail = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationEmailAsync)); + + _output.WriteLine($"[flag off] GetBySsoUserAsync: {userGetBySso}"); + _output.WriteLine($"[flag off] GetByEmailAsync: {userGetByEmail}"); + _output.WriteLine($"[flag off] GetByIdAsync (Org): {orgGet}"); + _output.WriteLine($"[flag off] GetByOrganizationAsync (OrgUser): {orgUserGetByOrg}"); + _output.WriteLine($"[flag off] GetByOrganizationEmailAsync (OrgUser): {orgUserGetByEmail}"); + } + + /// + /// PM-24579: Temporary test, remove with feature flag. + /// + [Theory, BitAutoData] + public async Task ExternalCallback_PreventNonCompliantFalse_JitProvision_MeasureLookups( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-measure-jit-flagoff"; + var email = "jit.flagoff@example.com"; + var organization = new Organization { Id = orgId, Name = "Org", Seats = null }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, email); + SetupHttpContextWithAuth(sutProvider, authResult); + + var ssoConfig = new SsoConfig { OrganizationId = orgId, Enabled = true }; + var ssoData = new SsoConfigurationData(); + ssoConfig.SetData(ssoData); + sutProvider.GetDependency().GetByOrganizationIdAsync(orgId).Returns(ssoConfig); + + // JIT (no existing user or sso link) + sutProvider.GetDependency().GetBySsoUserAsync(providerUserId, orgId).Returns((User?)null); + sutProvider.GetDependency().GetByEmailAsync(email).Returns((User?)null); + sutProvider.GetDependency().GetByIdAsync(orgId).Returns(organization); + sutProvider.GetDependency().GetByOrganizationEmailAsync(orgId, email).Returns((OrganizationUser?)null); + + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(false); + sutProvider.GetDependency() + .GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + // Act + try { _ = await sutProvider.Sut.ExternalCallback(); } catch { } + + // Assert (measurement) + var userRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + + var userGetBySso = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetBySsoUserAsync)); + var userGetByEmail = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetByEmailAsync)); + var orgGet = organizationRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationRepository.GetByIdAsync)); + var orgUserGetByOrg = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationAsync)); + var orgUserGetByEmail = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationEmailAsync)); + + _output.WriteLine($"[flag off] GetBySsoUserAsync: {userGetBySso}"); + _output.WriteLine($"[flag off] GetByEmailAsync: {userGetByEmail}"); + _output.WriteLine($"[flag off] GetByIdAsync (Org): {orgGet}"); + _output.WriteLine($"[flag off] GetByOrganizationAsync (OrgUser): {orgUserGetByOrg}"); + _output.WriteLine($"[flag off] GetByOrganizationEmailAsync (OrgUser): {orgUserGetByEmail}"); + } + + [Theory, BitAutoData] + public async Task CreateUserAndOrgUserConditionallyAsync_WithExistingAcceptedUser_CreatesSsoLinkAndReturnsUser( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "provider-user-id"; + var email = "user@example.com"; + var existingUser = new User { Id = Guid.NewGuid(), Email = email }; + var organization = new Organization { Id = orgId, Name = "Org" }; + var orgUser = new OrganizationUser + { + OrganizationId = orgId, + UserId = existingUser.Id, + Status = OrganizationUserStatusType.Accepted, + Type = OrganizationUserType.User + }; + + // Arrange repository expectations for the flow + sutProvider.GetDependency().GetByEmailAsync(email).Returns(existingUser); + sutProvider.GetDependency().GetByIdAsync(orgId).Returns(organization); + sutProvider.GetDependency().GetManyByUserAsync(existingUser.Id) + .Returns(new List { orgUser }); + sutProvider.GetDependency().GetByOrganizationEmailAsync(orgId, email).Returns(orgUser); + + // No existing SSO link so first SSO login event is logged + sutProvider.GetDependency().GetByUserIdOrganizationIdAsync(orgId, existingUser.Id).Returns((SsoUser?)null); + + var claims = new[] + { + new Claim(JwtClaimTypes.Email, email), + new Claim(JwtClaimTypes.Name, "Jit User") + } as IEnumerable; + var config = new SsoConfigurationData(); + + var method = typeof(AccountController).GetMethod( + "CreateUserAndOrgUserConditionallyAsync", + BindingFlags.Instance | BindingFlags.NonPublic); + Assert.NotNull(method); + + // Act + var task = (Task<(User user, Organization organization, OrganizationUser orgUser)>)method.Invoke(sutProvider.Sut, new object[] + { + orgId.ToString(), + providerUserId, + claims, + null!, + config + })!; + + var returned = await task; + + // Assert + Assert.Equal(existingUser.Id, returned.user.Id); + + await sutProvider.GetDependency().Received().CreateAsync(Arg.Is(s => + s.OrganizationId == orgId && s.UserId == existingUser.Id && s.ExternalId == providerUserId)); + + await sutProvider.GetDependency().Received().LogOrganizationUserEventAsync( + orgUser, + EventType.OrganizationUser_FirstSsoLogin); + } + + [Theory, BitAutoData] + public async Task CreateUserAndOrgUserConditionallyAsync_WithExistingInvitedUser_ThrowsAcceptInviteBeforeUsingSSO( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "provider-user-id"; + var email = "user@example.com"; + var existingUser = new User { Id = Guid.NewGuid(), Email = email, UsesKeyConnector = false }; + var organization = new Organization { Id = orgId, Name = "Org" }; + var orgUser = new OrganizationUser + { + OrganizationId = orgId, + UserId = existingUser.Id, + Status = OrganizationUserStatusType.Invited, + Type = OrganizationUserType.User + }; + + // i18n returns the key so we can assert on message contents + sutProvider.GetDependency() + .T(Arg.Any(), Arg.Any()) + .Returns(ci => (string)ci[0]!); + + // Arrange repository expectations for the flow + sutProvider.GetDependency().GetByEmailAsync(email).Returns(existingUser); + sutProvider.GetDependency().GetByIdAsync(orgId).Returns(organization); + sutProvider.GetDependency().GetManyByUserAsync(existingUser.Id) + .Returns(new List { orgUser }); + + var claims = new[] + { + new Claim(JwtClaimTypes.Email, email), + new Claim(JwtClaimTypes.Name, "Invited User") + } as IEnumerable; + var config = new SsoConfigurationData(); + + var method = typeof(AccountController).GetMethod( + "CreateUserAndOrgUserConditionallyAsync", + BindingFlags.Instance | BindingFlags.NonPublic); + Assert.NotNull(method); + + // Act + Assert + var task = (Task<(User user, Organization organization, OrganizationUser orgUser)>)method.Invoke(sutProvider.Sut, new object[] + { + orgId.ToString(), + providerUserId, + claims, + null!, + config + })!; + + var ex = await Assert.ThrowsAsync(async () => await task); + Assert.Equal("AcceptInviteBeforeUsingSSO", ex.Message); + } + + /// + /// PM-24579: Temporary comparison test to ensure the feature flag ON does not + /// regress lookup counts compared to OFF. When removing the flag, delete this + /// comparison test and keep the specific scenario snapshot tests if desired. + /// + [Theory, BitAutoData] + public async Task ExternalCallback_Measurements_FlagOnVsOff_Comparisons( + SutProvider sutProvider) + { + // Arrange + var scenarios = new[] + { + MeasurementScenario.ExistingSsoLinkedAccepted, + MeasurementScenario.ExistingUserNoOrgUser, + MeasurementScenario.JitProvision + }; + + foreach (var scenario in scenarios) + { + // Act + var onCounts = await MeasureCountsForScenarioAsync(sutProvider, scenario, preventNonCompliant: true); + var offCounts = await MeasureCountsForScenarioAsync(sutProvider, scenario, preventNonCompliant: false); + + // Assert: off should not exceed on in any measured lookup type + Assert.True(offCounts.UserGetBySso <= onCounts.UserGetBySso, $"{scenario}: off UserGetBySso={offCounts.UserGetBySso} > on {onCounts.UserGetBySso}"); + Assert.True(offCounts.UserGetByEmail <= onCounts.UserGetByEmail, $"{scenario}: off UserGetByEmail={offCounts.UserGetByEmail} > on {onCounts.UserGetByEmail}"); + Assert.True(offCounts.OrgGetById <= onCounts.OrgGetById, $"{scenario}: off OrgGetById={offCounts.OrgGetById} > on {onCounts.OrgGetById}"); + Assert.True(offCounts.OrgUserGetByOrg <= onCounts.OrgUserGetByOrg, $"{scenario}: off OrgUserGetByOrg={offCounts.OrgUserGetByOrg} > on {onCounts.OrgUserGetByOrg}"); + Assert.True(offCounts.OrgUserGetByEmail <= onCounts.OrgUserGetByEmail, $"{scenario}: off OrgUserGetByEmail={offCounts.OrgUserGetByEmail} > on {onCounts.OrgUserGetByEmail}"); + + _output.WriteLine($"Scenario={scenario} | ON: SSO={onCounts.UserGetBySso}, Email={onCounts.UserGetByEmail}, Org={onCounts.OrgGetById}, OrgUserByOrg={onCounts.OrgUserGetByOrg}, OrgUserByEmail={onCounts.OrgUserGetByEmail}"); + _output.WriteLine($"Scenario={scenario} | OFF: SSO={offCounts.UserGetBySso}, Email={offCounts.UserGetByEmail}, Org={offCounts.OrgGetById}, OrgUserByOrg={offCounts.OrgUserGetByOrg}, OrgUserByEmail={offCounts.OrgUserGetByEmail}"); + } + } + + [Theory, BitAutoData] + public async Task AutoProvisionUserAsync_WithFeatureFlagEnabled_CallsRegisterSSOAutoProvisionedUser( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-new-user"; + var email = "newuser@example.com"; + var organization = new Organization { Id = orgId, Name = "Test Org", Seats = null }; + + // No existing user (JIT provisioning scenario) + sutProvider.GetDependency().GetByEmailAsync(email).Returns((User?)null); + sutProvider.GetDependency().GetByIdAsync(orgId).Returns(organization); + sutProvider.GetDependency().GetByOrganizationEmailAsync(orgId, email) + .Returns((OrganizationUser?)null); + + // Feature flag enabled + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates) + .Returns(true); + + // Mock the RegisterSSOAutoProvisionedUserAsync to return success + sutProvider.GetDependency() + .RegisterSSOAutoProvisionedUserAsync(Arg.Any(), Arg.Any()) + .Returns(IdentityResult.Success); + + var claims = new[] + { + new Claim(JwtClaimTypes.Email, email), + new Claim(JwtClaimTypes.Name, "New User") + } as IEnumerable; + var config = new SsoConfigurationData(); + + var method = typeof(AccountController).GetMethod( + "CreateUserAndOrgUserConditionallyAsync", + BindingFlags.Instance | BindingFlags.NonPublic); + Assert.NotNull(method); + + // Act + var task = (Task<(User user, Organization organization, OrganizationUser orgUser)>)method!.Invoke( + sutProvider.Sut, + new object[] + { + orgId.ToString(), + providerUserId, + claims, + null!, + config + })!; + + var result = await task; + + // Assert + await sutProvider.GetDependency().Received(1) + .RegisterSSOAutoProvisionedUserAsync( + Arg.Is(u => u.Email == email && u.Name == "New User"), + Arg.Is(o => o.Id == orgId && o.Name == "Test Org")); + + Assert.NotNull(result.user); + Assert.Equal(email, result.user.Email); + Assert.Equal(organization.Id, result.organization.Id); + } + + [Theory, BitAutoData] + public async Task AutoProvisionUserAsync_WithFeatureFlagDisabled_CallsRegisterUserInstead( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-legacy-user"; + var email = "legacyuser@example.com"; + var organization = new Organization { Id = orgId, Name = "Test Org", Seats = null }; + + // No existing user (JIT provisioning scenario) + sutProvider.GetDependency().GetByEmailAsync(email).Returns((User?)null); + sutProvider.GetDependency().GetByIdAsync(orgId).Returns(organization); + sutProvider.GetDependency().GetByOrganizationEmailAsync(orgId, email) + .Returns((OrganizationUser?)null); + + // Feature flag disabled + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates) + .Returns(false); + + // Mock the RegisterUser to return success + sutProvider.GetDependency() + .RegisterUser(Arg.Any()) + .Returns(IdentityResult.Success); + + var claims = new[] + { + new Claim(JwtClaimTypes.Email, email), + new Claim(JwtClaimTypes.Name, "Legacy User") + } as IEnumerable; + var config = new SsoConfigurationData(); + + var method = typeof(AccountController).GetMethod( + "CreateUserAndOrgUserConditionallyAsync", + BindingFlags.Instance | BindingFlags.NonPublic); + Assert.NotNull(method); + + // Act + var task = (Task<(User user, Organization organization, OrganizationUser orgUser)>)method!.Invoke( + sutProvider.Sut, + new object[] + { + orgId.ToString(), + providerUserId, + claims, + null!, + config + })!; + + var result = await task; + + // Assert + await sutProvider.GetDependency().Received(1) + .RegisterUser(Arg.Is(u => u.Email == email && u.Name == "Legacy User")); + + // Verify the new method was NOT called + await sutProvider.GetDependency().DidNotReceive() + .RegisterSSOAutoProvisionedUserAsync(Arg.Any(), Arg.Any()); + + Assert.NotNull(result.user); + Assert.Equal(email, result.user.Email); + } + + [Theory, BitAutoData] + public void ExternalChallenge_WithMatchingOrgId_Succeeds( + SutProvider sutProvider, + Organization organization) + { + // Arrange + var orgId = organization.Id; + var scheme = orgId.ToString(); + var returnUrl = "~/vault"; + var state = "test-state"; + var userIdentifier = "user-123"; + var ssoToken = "valid-sso-token"; + + // Mock the data protector to return a tokenable with matching org ID + var dataProtector = sutProvider.GetDependency>(); + var tokenable = new SsoTokenable(organization, 3600); + dataProtector.Unprotect(ssoToken).Returns(tokenable); + + // Mock URL helper for IsLocalUrl check + var urlHelper = Substitute.For(); + urlHelper.IsLocalUrl(returnUrl).Returns(true); + sutProvider.Sut.Url = urlHelper; + + // Mock interaction service for IsValidReturnUrl check + var interactionService = sutProvider.GetDependency(); + interactionService.IsValidReturnUrl(returnUrl).Returns(true); + + // Act + var result = sutProvider.Sut.ExternalChallenge(scheme, returnUrl, state, userIdentifier, ssoToken); + + // Assert + var challengeResult = Assert.IsType(result); + Assert.Contains(scheme, challengeResult.AuthenticationSchemes); + Assert.NotNull(challengeResult.Properties); + Assert.Equal(scheme, challengeResult.Properties.Items["scheme"]); + Assert.Equal(returnUrl, challengeResult.Properties.Items["return_url"]); + Assert.Equal(state, challengeResult.Properties.Items["state"]); + Assert.Equal(userIdentifier, challengeResult.Properties.Items["user_identifier"]); + } + + [Theory, BitAutoData] + public void ExternalChallenge_WithMismatchedOrgId_ThrowsSsoOrganizationIdMismatch( + SutProvider sutProvider, + Organization organization) + { + // Arrange + var correctOrgId = organization.Id; + var wrongOrgId = Guid.NewGuid(); + var scheme = wrongOrgId.ToString(); // Different from tokenable's org ID + var returnUrl = "~/vault"; + var state = "test-state"; + var userIdentifier = "user-123"; + var ssoToken = "valid-sso-token"; + + // Mock the data protector to return a tokenable with different org ID + var dataProtector = sutProvider.GetDependency>(); + var tokenable = new SsoTokenable(organization, 3600); // Contains correctOrgId + dataProtector.Unprotect(ssoToken).Returns(tokenable); + + // Mock i18n service to return the key + sutProvider.GetDependency() + .T(Arg.Any()) + .Returns(ci => (string)ci[0]!); + + // Act & Assert + var ex = Assert.Throws(() => + sutProvider.Sut.ExternalChallenge(scheme, returnUrl, state, userIdentifier, ssoToken)); + Assert.Equal("SsoOrganizationIdMismatch", ex.Message); + } + + [Theory, BitAutoData] + public void ExternalChallenge_WithInvalidSchemeFormat_ThrowsSsoOrganizationIdMismatch( + SutProvider sutProvider, + Organization organization) + { + // Arrange + var scheme = "not-a-valid-guid"; + var returnUrl = "~/vault"; + var state = "test-state"; + var userIdentifier = "user-123"; + var ssoToken = "valid-sso-token"; + + // Mock the data protector to return a valid tokenable + var dataProtector = sutProvider.GetDependency>(); + var tokenable = new SsoTokenable(organization, 3600); + dataProtector.Unprotect(ssoToken).Returns(tokenable); + + // Mock i18n service to return the key + sutProvider.GetDependency() + .T(Arg.Any()) + .Returns(ci => (string)ci[0]!); + + // Act & Assert + var ex = Assert.Throws(() => + sutProvider.Sut.ExternalChallenge(scheme, returnUrl, state, userIdentifier, ssoToken)); + Assert.Equal("SsoOrganizationIdMismatch", ex.Message); + } + + [Theory, BitAutoData] + public void ExternalChallenge_WithInvalidSsoToken_ThrowsInvalidSsoToken( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var scheme = orgId.ToString(); + var returnUrl = "~/vault"; + var state = "test-state"; + var userIdentifier = "user-123"; + var ssoToken = "invalid-corrupted-token"; + + // Mock the data protector to throw when trying to unprotect + var dataProtector = sutProvider.GetDependency>(); + dataProtector.Unprotect(ssoToken).Returns(_ => throw new Exception("Token validation failed")); + + // Mock i18n service to return the key + sutProvider.GetDependency() + .T(Arg.Any()) + .Returns(ci => (string)ci[0]!); + + // Act & Assert + var ex = Assert.Throws(() => + sutProvider.Sut.ExternalChallenge(scheme, returnUrl, state, userIdentifier, ssoToken)); + Assert.Equal("InvalidSsoToken", ex.Message); + } +} diff --git a/bitwarden_license/test/SSO.Test/SSO.Test.csproj b/bitwarden_license/test/SSO.Test/SSO.Test.csproj new file mode 100644 index 0000000000..4b509c9a50 --- /dev/null +++ b/bitwarden_license/test/SSO.Test/SSO.Test.csproj @@ -0,0 +1,35 @@ + + + + net8.0 + enable + enable + + false + true + + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + + + + diff --git a/bitwarden_license/test/Scim.IntegrationTest/Controllers/v2/GroupsControllerTests.cs b/bitwarden_license/test/Scim.IntegrationTest/Controllers/v2/GroupsControllerTests.cs index 5f562a30c5..9ad231a63d 100644 --- a/bitwarden_license/test/Scim.IntegrationTest/Controllers/v2/GroupsControllerTests.cs +++ b/bitwarden_license/test/Scim.IntegrationTest/Controllers/v2/GroupsControllerTests.cs @@ -200,6 +200,38 @@ public class GroupsControllerTests : IClassFixture, IAsy AssertHelper.AssertPropertyEqual(expectedResponse, responseModel); } + [Fact] + public async Task GetList_SearchDisplayNameWithoutOptionalParameters_Success() + { + string filter = "displayName eq Test Group 2"; + int? itemsPerPage = null; + int? startIndex = null; + var expectedResponse = new ScimListResponseModel + { + ItemsPerPage = 50, //default value + TotalResults = 1, + StartIndex = 1, //default value + Resources = new List + { + new ScimGroupResponseModel + { + Id = ScimApplicationFactory.TestGroupId2, + DisplayName = "Test Group 2", + ExternalId = "B", + Schemas = new List { ScimConstants.Scim2SchemaGroup } + } + }, + Schemas = new List { ScimConstants.Scim2SchemaListResponse } + }; + + var context = await _factory.GroupsGetListAsync(ScimApplicationFactory.TestOrganizationId1, filter, itemsPerPage, startIndex); + + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + var responseModel = JsonSerializer.Deserialize>(context.Response.Body, new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }); + AssertHelper.AssertPropertyEqual(expectedResponse, responseModel); + } + [Fact] public async Task Post_Success() { diff --git a/bitwarden_license/test/Scim.IntegrationTest/appsettings.Development.json b/bitwarden_license/test/Scim.IntegrationTest/appsettings.Development.json new file mode 100644 index 0000000000..496d0c075f --- /dev/null +++ b/bitwarden_license/test/Scim.IntegrationTest/appsettings.Development.json @@ -0,0 +1,36 @@ +{ + "globalSettings": { + "baseServiceUri": { + "vault": "https://localhost:8080", + "api": "http://localhost:4000", + "identity": "http://localhost:33656", + "admin": "http://localhost:62911", + "notifications": "http://localhost:61840", + "sso": "http://localhost:51822", + "internalNotifications": "http://localhost:61840", + "internalAdmin": "http://localhost:62911", + "internalIdentity": "http://localhost:33656", + "internalApi": "http://localhost:4000", + "internalVault": "https://localhost:8080", + "internalSso": "http://localhost:51822", + "internalScim": "http://localhost:44559" + }, + "mail": { + "smtp": { + "host": "localhost", + "port": 10250 + } + }, + "attachment": { + "connectionString": "UseDevelopmentStorage=true", + "baseUrl": "http://localhost:4000/attachments/" + }, + "events": { + "connectionString": "UseDevelopmentStorage=true" + }, + "storage": { + "connectionString": "UseDevelopmentStorage=true" + }, + "pricingUri": "https://billingpricing.qa.bitwarden.pw" + } +} diff --git a/bitwarden_license/test/Scim.Test/Groups/GetGroupsListQueryTests.cs b/bitwarden_license/test/Scim.Test/Groups/GetGroupsListQueryTests.cs index 1599b6e390..b835e1fe6b 100644 --- a/bitwarden_license/test/Scim.Test/Groups/GetGroupsListQueryTests.cs +++ b/bitwarden_license/test/Scim.Test/Groups/GetGroupsListQueryTests.cs @@ -1,6 +1,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Repositories; using Bit.Scim.Groups; +using Bit.Scim.Models; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.Helpers; @@ -24,7 +25,7 @@ public class GetGroupsListCommandTests .GetManyByOrganizationIdAsync(organizationId) .Returns(groups); - var result = await sutProvider.Sut.GetGroupsListAsync(organizationId, null, count, startIndex); + var result = await sutProvider.Sut.GetGroupsListAsync(organizationId, new GetGroupsQueryParamModel { Count = count, StartIndex = startIndex }); AssertHelper.AssertPropertyEqual(groups.Skip(startIndex - 1).Take(count).ToList(), result.groupList); AssertHelper.AssertPropertyEqual(groups.Count, result.totalResults); @@ -47,7 +48,7 @@ public class GetGroupsListCommandTests .GetManyByOrganizationIdAsync(organizationId) .Returns(groups); - var result = await sutProvider.Sut.GetGroupsListAsync(organizationId, filter, null, null); + var result = await sutProvider.Sut.GetGroupsListAsync(organizationId, new GetGroupsQueryParamModel { Filter = filter }); AssertHelper.AssertPropertyEqual(expectedGroupList, result.groupList); AssertHelper.AssertPropertyEqual(expectedTotalResults, result.totalResults); @@ -67,7 +68,7 @@ public class GetGroupsListCommandTests .GetManyByOrganizationIdAsync(organizationId) .Returns(groups); - var result = await sutProvider.Sut.GetGroupsListAsync(organizationId, filter, null, null); + var result = await sutProvider.Sut.GetGroupsListAsync(organizationId, new GetGroupsQueryParamModel { Filter = filter }); AssertHelper.AssertPropertyEqual(expectedGroupList, result.groupList); AssertHelper.AssertPropertyEqual(expectedTotalResults, result.totalResults); @@ -90,7 +91,7 @@ public class GetGroupsListCommandTests .GetManyByOrganizationIdAsync(organizationId) .Returns(groups); - var result = await sutProvider.Sut.GetGroupsListAsync(organizationId, filter, null, null); + var result = await sutProvider.Sut.GetGroupsListAsync(organizationId, new GetGroupsQueryParamModel { Filter = filter }); AssertHelper.AssertPropertyEqual(expectedGroupList, result.groupList); AssertHelper.AssertPropertyEqual(expectedTotalResults, result.totalResults); @@ -112,7 +113,7 @@ public class GetGroupsListCommandTests .GetManyByOrganizationIdAsync(organizationId) .Returns(groups); - var result = await sutProvider.Sut.GetGroupsListAsync(organizationId, filter, null, null); + var result = await sutProvider.Sut.GetGroupsListAsync(organizationId, new GetGroupsQueryParamModel { Filter = filter }); AssertHelper.AssertPropertyEqual(expectedGroupList, result.groupList); AssertHelper.AssertPropertyEqual(expectedTotalResults, result.totalResults); diff --git a/bitwarden_license/test/Scim.Test/Groups/PatchGroupCommandTests.cs b/bitwarden_license/test/Scim.Test/Groups/PatchGroupCommandTests.cs index 1b02e62970..8816885ea7 100644 --- a/bitwarden_license/test/Scim.Test/Groups/PatchGroupCommandTests.cs +++ b/bitwarden_license/test/Scim.Test/Groups/PatchGroupCommandTests.cs @@ -436,7 +436,7 @@ public class PatchGroupCommandTests await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().DeleteUserAsync(default, default); // Assert: logging - sutProvider.GetDependency>().ReceivedWithAnyArgs().LogWarning(default); + sutProvider.GetDependency>().ReceivedWithAnyArgs().LogWarning(""); } [Theory] diff --git a/bitwarden_license/test/Scim.Test/Users/GetUsersListQueryTests.cs b/bitwarden_license/test/Scim.Test/Users/GetUsersListQueryTests.cs index 9352e5c202..7424b50c0d 100644 --- a/bitwarden_license/test/Scim.Test/Users/GetUsersListQueryTests.cs +++ b/bitwarden_license/test/Scim.Test/Users/GetUsersListQueryTests.cs @@ -1,5 +1,6 @@ using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; +using Bit.Scim.Models; using Bit.Scim.Users; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; diff --git a/bitwarden_license/test/Scim.Test/Users/PatchUserCommandTests.cs b/bitwarden_license/test/Scim.Test/Users/PatchUserCommandTests.cs index f391c93fe3..8b6c850c6f 100644 --- a/bitwarden_license/test/Scim.Test/Users/PatchUserCommandTests.cs +++ b/bitwarden_license/test/Scim.Test/Users/PatchUserCommandTests.cs @@ -1,6 +1,6 @@ using System.Text.Json; -using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RestoreUser.v1; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v1; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; diff --git a/bitwarden_license/test/Scim.Test/Users/PostUserCommandTests.cs b/bitwarden_license/test/Scim.Test/Users/PostUserCommandTests.cs index ac23e7ecc1..eb8804cac5 100644 --- a/bitwarden_license/test/Scim.Test/Users/PostUserCommandTests.cs +++ b/bitwarden_license/test/Scim.Test/Users/PostUserCommandTests.cs @@ -1,4 +1,5 @@ using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Business; @@ -36,7 +37,7 @@ public class PostUserCommandTests sutProvider.GetDependency().GetByIdAsync(organizationId).Returns(organization); - sutProvider.GetDependency().HasSecretsManagerStandalone(organization).Returns(true); + sutProvider.GetDependency().HasSecretsManagerStandalone(organization).Returns(true); sutProvider.GetDependency() .InviteUserAsync(organizationId, diff --git a/dev/docker-compose.yml b/dev/docker-compose.yml index c5e42cf9e3..c82da051b4 100644 --- a/dev/docker-compose.yml +++ b/dev/docker-compose.yml @@ -57,7 +57,6 @@ services: mysql: image: mysql:8.0 - container_name: bw-mysql ports: - "3306:3306" command: @@ -88,7 +87,6 @@ services: idp: image: kenchan0130/simplesamlphp:1.19.8 - container_name: idp ports: - "8090:8080" environment: @@ -101,8 +99,7 @@ services: - idp rabbitmq: - image: rabbitmq:4.1.3-management - container_name: rabbitmq + image: rabbitmq:4.2.0-management ports: - "5672:5672" - "15672:15672" @@ -116,7 +113,6 @@ services: reverse-proxy: image: nginx:alpine - container_name: reverse-proxy volumes: - "./reverse-proxy.conf:/etc/nginx/conf.d/default.conf" ports: @@ -126,7 +122,6 @@ services: - proxy service-bus: - container_name: service-bus image: mcr.microsoft.com/azure-messaging/servicebus-emulator:latest pull_policy: always volumes: @@ -142,7 +137,6 @@ services: redis: image: redis:alpine - container_name: bw-redis ports: - "6379:6379" volumes: diff --git a/dev/generate_openapi_files.ps1 b/dev/generate_openapi_files.ps1 index 9eca7dc734..011319b3a3 100644 --- a/dev/generate_openapi_files.ps1 +++ b/dev/generate_openapi_files.ps1 @@ -18,11 +18,11 @@ if ($LASTEXITCODE -ne 0) { # Api internal & public Set-Location "../../src/Api" dotnet build -dotnet swagger tofile --output "../../api.json" --host "https://api.bitwarden.com" "./bin/Debug/net8.0/Api.dll" "internal" +dotnet swagger tofile --output "../../api.json" "./bin/Debug/net8.0/Api.dll" "internal" if ($LASTEXITCODE -ne 0) { exit $LASTEXITCODE } -dotnet swagger tofile --output "../../api.public.json" --host "https://api.bitwarden.com" "./bin/Debug/net8.0/Api.dll" "public" +dotnet swagger tofile --output "../../api.public.json" "./bin/Debug/net8.0/Api.dll" "public" if ($LASTEXITCODE -ne 0) { exit $LASTEXITCODE } diff --git a/dev/secrets.json.example b/dev/secrets.json.example index c6a16846e9..0d4213aec1 100644 --- a/dev/secrets.json.example +++ b/dev/secrets.json.example @@ -33,6 +33,10 @@ "id": "", "key": "" }, + "events": { + "connectionString": "", + "queueName": "event" + }, "licenseDirectory": "", "enableNewDeviceVerification": true, "enableEmailVerification": true diff --git a/dev/servicebusemulator_config.json b/dev/servicebusemulator_config.json index 294efc1897..bb50c0b1ee 100644 --- a/dev/servicebusemulator_config.json +++ b/dev/servicebusemulator_config.json @@ -3,22 +3,6 @@ "Namespaces": [ { "Name": "sbemulatorns", - "Queues": [ - { - "Name": "queue.1", - "Properties": { - "DeadLetteringOnMessageExpiration": false, - "DefaultMessageTimeToLive": "PT1H", - "DuplicateDetectionHistoryTimeWindow": "PT20S", - "ForwardDeadLetteredMessagesTo": "", - "ForwardTo": "", - "LockDuration": "PT1M", - "MaxDeliveryCount": 3, - "RequiresDuplicateDetection": false, - "RequiresSession": false - } - } - ], "Topics": [ { "Name": "event-logging", @@ -37,6 +21,9 @@ }, { "Name": "events-datadog-subscription" + }, + { + "Name": "events-teams-subscription" } ] }, @@ -98,6 +85,20 @@ } } ] + }, + { + "Name": "integration-teams-subscription", + "Rules": [ + { + "Name": "teams-integration-filter", + "Properties": { + "FilterType": "Correlation", + "CorrelationFilter": { + "Label": "teams" + } + } + } + ] } ] } diff --git a/dev/verify_migrations.ps1 b/dev/verify_migrations.ps1 new file mode 100644 index 0000000000..d63c34f2bd --- /dev/null +++ b/dev/verify_migrations.ps1 @@ -0,0 +1,132 @@ +#!/usr/bin/env pwsh + +<# +.SYNOPSIS + Validates that new database migration files follow naming conventions and chronological order. + +.DESCRIPTION + This script validates migration files in util/Migrator/DbScripts/ to ensure: + 1. New migrations follow the naming format: YYYY-MM-DD_NN_Description.sql + 2. New migrations are chronologically ordered (filename sorts after existing migrations) + 3. Dates use leading zeros (e.g., 2025-01-05, not 2025-1-5) + 4. A 2-digit sequence number is included (e.g., _00, _01) + +.PARAMETER BaseRef + The base git reference to compare against (e.g., 'main', 'HEAD~1') + +.PARAMETER CurrentRef + The current git reference (defaults to 'HEAD') + +.EXAMPLE + # For pull requests - compare against main branch + .\verify_migrations.ps1 -BaseRef main + +.EXAMPLE + # For pushes - compare against previous commit + .\verify_migrations.ps1 -BaseRef HEAD~1 +#> + +param( + [Parameter(Mandatory = $true)] + [string]$BaseRef, + + [Parameter(Mandatory = $false)] + [string]$CurrentRef = "HEAD" +) + +# Use invariant culture for consistent string comparison +[System.Threading.Thread]::CurrentThread.CurrentCulture = [System.Globalization.CultureInfo]::InvariantCulture + +$migrationPath = "util/Migrator/DbScripts" + +# Get list of migrations from base reference +try { + $baseMigrations = git ls-tree -r --name-only $BaseRef -- "$migrationPath/*.sql" 2>$null | Sort-Object + if ($LASTEXITCODE -ne 0) { + Write-Host "Warning: Could not retrieve migrations from base reference '$BaseRef'" + $baseMigrations = @() + } +} +catch { + Write-Host "Warning: Could not retrieve migrations from base reference '$BaseRef'" + $baseMigrations = @() +} + +# Get list of migrations from current reference +$currentMigrations = git ls-tree -r --name-only $CurrentRef -- "$migrationPath/*.sql" | Sort-Object + +# Find added migrations +$addedMigrations = $currentMigrations | Where-Object { $_ -notin $baseMigrations } + +if ($addedMigrations.Count -eq 0) { + Write-Host "No new migration files added." + exit 0 +} + +Write-Host "New migration files detected:" +$addedMigrations | ForEach-Object { Write-Host " $_" } +Write-Host "" + +# Get the last migration from base reference +if ($baseMigrations.Count -eq 0) { + Write-Host "No previous migrations found (initial commit?). Skipping validation." + exit 0 +} + +$lastBaseMigration = Split-Path -Leaf ($baseMigrations | Select-Object -Last 1) +Write-Host "Last migration in base reference: $lastBaseMigration" +Write-Host "" + +# Required format regex: YYYY-MM-DD_NN_Description.sql +$formatRegex = '^[0-9]{4}-[0-9]{2}-[0-9]{2}_[0-9]{2}_.+\.sql$' + +$validationFailed = $false + +foreach ($migration in $addedMigrations) { + $migrationName = Split-Path -Leaf $migration + + # Validate NEW migration filename format + if ($migrationName -notmatch $formatRegex) { + Write-Host "ERROR: Migration '$migrationName' does not match required format" + Write-Host "Required format: YYYY-MM-DD_NN_Description.sql" + Write-Host " - YYYY: 4-digit year" + Write-Host " - MM: 2-digit month with leading zero (01-12)" + Write-Host " - DD: 2-digit day with leading zero (01-31)" + Write-Host " - NN: 2-digit sequence number (00, 01, 02, etc.)" + Write-Host "Example: 2025-01-15_00_MyMigration.sql" + $validationFailed = $true + continue + } + + # Compare migration name with last base migration (using ordinal string comparison) + if ([string]::CompareOrdinal($migrationName, $lastBaseMigration) -lt 0) { + Write-Host "ERROR: New migration '$migrationName' is not chronologically after '$lastBaseMigration'" + $validationFailed = $true + } + else { + Write-Host "OK: '$migrationName' is chronologically after '$lastBaseMigration'" + } +} + +Write-Host "" + +if ($validationFailed) { + Write-Host "FAILED: One or more migrations are incorrectly named or not in chronological order" + Write-Host "" + Write-Host "All new migration files must:" + Write-Host " 1. Follow the naming format: YYYY-MM-DD_NN_Description.sql" + Write-Host " 2. Use leading zeros in dates (e.g., 2025-01-05, not 2025-1-5)" + Write-Host " 3. Include a 2-digit sequence number (e.g., _00, _01)" + Write-Host " 4. Have a filename that sorts after the last migration in base" + Write-Host "" + Write-Host "To fix this issue:" + Write-Host " 1. Locate your migration file(s) in util/Migrator/DbScripts/" + Write-Host " 2. Rename to follow format: YYYY-MM-DD_NN_Description.sql" + Write-Host " 3. Ensure the date is after $lastBaseMigration" + Write-Host "" + Write-Host "Example: 2025-01-15_00_AddNewFeature.sql" + exit 1 +} + +Write-Host "SUCCESS: All new migrations are correctly named and in chronological order" +exit 0 diff --git a/global.json b/global.json index d25197db39..4cbe3f083a 100644 --- a/global.json +++ b/global.json @@ -5,6 +5,7 @@ }, "msbuild-sdks": { "Microsoft.Build.Traversal": "4.1.0", - "Microsoft.Build.Sql": "1.0.0" + "Microsoft.Build.Sql": "1.0.0", + "Bitwarden.Server.Sdk": "1.2.0" } } diff --git a/src/Admin/AdminConsole/Controllers/OrganizationsController.cs b/src/Admin/AdminConsole/Controllers/OrganizationsController.cs index 2417bf610d..cd370e3898 100644 --- a/src/Admin/AdminConsole/Controllers/OrganizationsController.cs +++ b/src/Admin/AdminConsole/Controllers/OrganizationsController.cs @@ -14,8 +14,10 @@ using Bit.Core.AdminConsole.Providers.Interfaces; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Organizations.Services; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Providers.Services; +using Bit.Core.Billing.Services; using Bit.Core.Enums; using Bit.Core.Models.OrganizationConnectionConfigs; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; @@ -41,7 +43,7 @@ public class OrganizationsController : Controller private readonly ICollectionRepository _collectionRepository; private readonly IGroupRepository _groupRepository; private readonly IPolicyRepository _policyRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IApplicationCacheService _applicationCacheService; private readonly GlobalSettings _globalSettings; private readonly IProviderRepository _providerRepository; @@ -56,6 +58,7 @@ public class OrganizationsController : Controller private readonly IOrganizationInitiateDeleteCommand _organizationInitiateDeleteCommand; private readonly IPricingClient _pricingClient; private readonly IResendOrganizationInviteCommand _resendOrganizationInviteCommand; + private readonly IOrganizationBillingService _organizationBillingService; public OrganizationsController( IOrganizationRepository organizationRepository, @@ -66,7 +69,7 @@ public class OrganizationsController : Controller ICollectionRepository collectionRepository, IGroupRepository groupRepository, IPolicyRepository policyRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, IApplicationCacheService applicationCacheService, GlobalSettings globalSettings, IProviderRepository providerRepository, @@ -80,7 +83,8 @@ public class OrganizationsController : Controller IProviderBillingService providerBillingService, IOrganizationInitiateDeleteCommand organizationInitiateDeleteCommand, IPricingClient pricingClient, - IResendOrganizationInviteCommand resendOrganizationInviteCommand) + IResendOrganizationInviteCommand resendOrganizationInviteCommand, + IOrganizationBillingService organizationBillingService) { _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; @@ -105,6 +109,7 @@ public class OrganizationsController : Controller _organizationInitiateDeleteCommand = organizationInitiateDeleteCommand; _pricingClient = pricingClient; _resendOrganizationInviteCommand = resendOrganizationInviteCommand; + _organizationBillingService = organizationBillingService; } [RequirePermission(Permission.Org_List_View)] @@ -241,6 +246,8 @@ public class OrganizationsController : Controller var existingOrganizationData = new Organization { Id = organization.Id, + Name = organization.Name, + BillingEmail = organization.BillingEmail, Status = organization.Status, PlanType = organization.PlanType, Seats = organization.Seats @@ -286,6 +293,22 @@ public class OrganizationsController : Controller await _applicationCacheService.UpsertOrganizationAbilityAsync(organization); + // Sync name/email changes to Stripe + if (existingOrganizationData.Name != organization.Name || existingOrganizationData.BillingEmail != organization.BillingEmail) + { + try + { + await _organizationBillingService.UpdateOrganizationNameAndEmail(organization); + } + catch (Exception ex) + { + _logger.LogError(ex, + "Failed to update Stripe customer for organization {OrganizationId}. Database was updated successfully.", + organization.Id); + TempData["Warning"] = "Organization updated successfully, but Stripe customer name/email synchronization failed."; + } + } + return RedirectToAction("Edit", new { id }); } @@ -472,6 +495,8 @@ public class OrganizationsController : Controller organization.UseRiskInsights = model.UseRiskInsights; organization.UseOrganizationDomains = model.UseOrganizationDomains; organization.UseAdminSponsoredFamilies = model.UseAdminSponsoredFamilies; + organization.UseAutomaticUserConfirmation = model.UseAutomaticUserConfirmation; + organization.UsePhishingBlocker = model.UsePhishingBlocker; //secrets organization.SmSeats = model.SmSeats; diff --git a/src/Admin/AdminConsole/Controllers/ProvidersController.cs b/src/Admin/AdminConsole/Controllers/ProvidersController.cs index 9344179a77..d9135e1d1c 100644 --- a/src/Admin/AdminConsole/Controllers/ProvidersController.cs +++ b/src/Admin/AdminConsole/Controllers/ProvidersController.cs @@ -56,6 +56,7 @@ public class ProvidersController : Controller private readonly IStripeAdapter _stripeAdapter; private readonly IAccessControlService _accessControlService; private readonly ISubscriberService _subscriberService; + private readonly ILogger _logger; public ProvidersController(IOrganizationRepository organizationRepository, IResellerClientOrganizationSignUpCommand resellerClientOrganizationSignUpCommand, @@ -72,7 +73,8 @@ public class ProvidersController : Controller IPricingClient pricingClient, IStripeAdapter stripeAdapter, IAccessControlService accessControlService, - ISubscriberService subscriberService) + ISubscriberService subscriberService, + ILogger logger) { _organizationRepository = organizationRepository; _resellerClientOrganizationSignUpCommand = resellerClientOrganizationSignUpCommand; @@ -92,6 +94,7 @@ public class ProvidersController : Controller _braintreeMerchantUrl = webHostEnvironment.GetBraintreeMerchantUrl(); _braintreeMerchantId = globalSettings.Braintree.MerchantId; _subscriberService = subscriberService; + _logger = logger; } [RequirePermission(Permission.Provider_List_View)] @@ -296,6 +299,9 @@ public class ProvidersController : Controller var originalProviderStatus = provider.Enabled; + // Capture original billing email before modifications for Stripe sync + var originalBillingEmail = provider.BillingEmail; + model.ToProvider(provider); // validate the stripe ids to prevent saving a bad one @@ -321,6 +327,22 @@ public class ProvidersController : Controller await _providerService.UpdateAsync(provider); await _applicationCacheService.UpsertProviderAbilityAsync(provider); + // Sync billing email changes to Stripe + if (!string.IsNullOrEmpty(provider.GatewayCustomerId) && originalBillingEmail != provider.BillingEmail) + { + try + { + await _providerBillingService.UpdateProviderNameAndEmail(provider); + } + catch (Exception ex) + { + _logger.LogError(ex, + "Failed to update Stripe customer for provider {ProviderId}. Database was updated successfully.", + provider.Id); + TempData["Warning"] = "Provider updated successfully, but Stripe customer email synchronization failed."; + } + } + if (!provider.IsBillable()) { return RedirectToAction("Edit", new { id }); @@ -339,11 +361,11 @@ public class ProvidersController : Controller ]); await _providerBillingService.UpdateSeatMinimums(updateMspSeatMinimumsCommand); - var customer = await _stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId); + var customer = await _stripeAdapter.GetCustomerAsync(provider.GatewayCustomerId); if (model.PayByInvoice != customer.ApprovedToPayByInvoice()) { var approvedToPayByInvoice = model.PayByInvoice ? "1" : "0"; - await _stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions + await _stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Metadata = new Dictionary { diff --git a/src/Admin/AdminConsole/Models/OrganizationEditModel.cs b/src/Admin/AdminConsole/Models/OrganizationEditModel.cs index b64af3135f..4fff85e1e8 100644 --- a/src/Admin/AdminConsole/Models/OrganizationEditModel.cs +++ b/src/Admin/AdminConsole/Models/OrganizationEditModel.cs @@ -106,6 +106,9 @@ public class OrganizationEditModel : OrganizationViewModel SmServiceAccounts = org.SmServiceAccounts; MaxAutoscaleSmServiceAccounts = org.MaxAutoscaleSmServiceAccounts; UseOrganizationDomains = org.UseOrganizationDomains; + UseAutomaticUserConfirmation = org.UseAutomaticUserConfirmation; + UsePhishingBlocker = org.UsePhishingBlocker; + _plans = plans; } @@ -158,6 +161,8 @@ public class OrganizationEditModel : OrganizationViewModel public new bool UseSecretsManager { get; set; } [Display(Name = "Risk Insights")] public new bool UseRiskInsights { get; set; } + [Display(Name = "Phishing Blocker")] + public new bool UsePhishingBlocker { get; set; } [Display(Name = "Admin Sponsored Families")] public bool UseAdminSponsoredFamilies { get; set; } [Display(Name = "Self Host")] @@ -192,6 +197,8 @@ public class OrganizationEditModel : OrganizationViewModel [Display(Name = "Use Organization Domains")] public bool UseOrganizationDomains { get; set; } + [Display(Name = "Automatic User Confirmation")] + public bool UseAutomaticUserConfirmation { get; set; } /** * Creates a Plan[] object for use in Javascript * This is mapped manually below to provide some type safety in case the plan objects change @@ -231,6 +238,7 @@ public class OrganizationEditModel : OrganizationViewModel LegacyYear = p.LegacyYear, Disabled = p.Disabled, SupportsSecretsManager = p.SupportsSecretsManager, + AutomaticUserConfirmation = p.AutomaticUserConfirmation, PasswordManager = new { @@ -322,6 +330,7 @@ public class OrganizationEditModel : OrganizationViewModel existingOrganization.SmServiceAccounts = SmServiceAccounts; existingOrganization.MaxAutoscaleSmServiceAccounts = MaxAutoscaleSmServiceAccounts; existingOrganization.UseOrganizationDomains = UseOrganizationDomains; + existingOrganization.UsePhishingBlocker = UsePhishingBlocker; return existingOrganization; } } diff --git a/src/Admin/AdminConsole/Models/OrganizationViewModel.cs b/src/Admin/AdminConsole/Models/OrganizationViewModel.cs index 2c126ecd8e..457686be53 100644 --- a/src/Admin/AdminConsole/Models/OrganizationViewModel.cs +++ b/src/Admin/AdminConsole/Models/OrganizationViewModel.cs @@ -75,6 +75,7 @@ public class OrganizationViewModel public int OccupiedSmSeatsCount { get; set; } public bool UseSecretsManager => Organization.UseSecretsManager; public bool UseRiskInsights => Organization.UseRiskInsights; + public bool UsePhishingBlocker => Organization.UsePhishingBlocker; public IEnumerable OwnersDetails { get; set; } public IEnumerable AdminsDetails { get; set; } } diff --git a/src/Admin/AdminConsole/Views/Shared/_OrganizationForm.cshtml b/src/Admin/AdminConsole/Views/Shared/_OrganizationForm.cshtml index 267264a38f..b22859ed60 100644 --- a/src/Admin/AdminConsole/Views/Shared/_OrganizationForm.cshtml +++ b/src/Admin/AdminConsole/Views/Shared/_OrganizationForm.cshtml @@ -152,11 +152,19 @@ - @if(FeatureService.IsEnabled(FeatureFlagKeys.PM17772_AdminInitiatedSponsorships)) +
+ + +
+
+ + +
+ @if(FeatureService.IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers)) {
- - + +
} diff --git a/src/Admin/Billing/Controllers/MigrateProvidersController.cs b/src/Admin/Billing/Controllers/MigrateProvidersController.cs deleted file mode 100644 index ef5ea2312e..0000000000 --- a/src/Admin/Billing/Controllers/MigrateProvidersController.cs +++ /dev/null @@ -1,83 +0,0 @@ -using Bit.Admin.Billing.Models; -using Bit.Admin.Enums; -using Bit.Admin.Utilities; -using Bit.Core.Billing.Providers.Migration.Models; -using Bit.Core.Billing.Providers.Migration.Services; -using Bit.Core.Utilities; -using Microsoft.AspNetCore.Authorization; -using Microsoft.AspNetCore.Mvc; - -namespace Bit.Admin.Billing.Controllers; - -[Authorize] -[Route("migrate-providers")] -[SelfHosted(NotSelfHostedOnly = true)] -public class MigrateProvidersController( - IProviderMigrator providerMigrator) : Controller -{ - [HttpGet] - [RequirePermission(Permission.Tools_MigrateProviders)] - public IActionResult Index() - { - return View(new MigrateProvidersRequestModel()); - } - - [HttpPost] - [RequirePermission(Permission.Tools_MigrateProviders)] - [ValidateAntiForgeryToken] - public async Task PostAsync(MigrateProvidersRequestModel request) - { - var providerIds = GetProviderIdsFromInput(request.ProviderIds); - - if (providerIds.Count == 0) - { - return RedirectToAction("Index"); - } - - foreach (var providerId in providerIds) - { - await providerMigrator.Migrate(providerId); - } - - return RedirectToAction("Results", new { ProviderIds = string.Join("\r\n", providerIds) }); - } - - [HttpGet("results")] - [RequirePermission(Permission.Tools_MigrateProviders)] - public async Task ResultsAsync(MigrateProvidersRequestModel request) - { - var providerIds = GetProviderIdsFromInput(request.ProviderIds); - - if (providerIds.Count == 0) - { - return View(Array.Empty()); - } - - var results = await Task.WhenAll(providerIds.Select(providerMigrator.GetResult)); - - return View(results); - } - - [HttpGet("results/{providerId:guid}")] - [RequirePermission(Permission.Tools_MigrateProviders)] - public async Task DetailsAsync([FromRoute] Guid providerId) - { - var result = await providerMigrator.GetResult(providerId); - - if (result == null) - { - return RedirectToAction("Index"); - } - - return View(result); - } - - private static List GetProviderIdsFromInput(string text) => !string.IsNullOrEmpty(text) - ? text.Split( - ["\r\n", "\r", "\n"], - StringSplitOptions.TrimEntries - ) - .Select(id => new Guid(id)) - .ToList() - : []; -} diff --git a/src/Admin/Billing/Models/MigrateProvidersRequestModel.cs b/src/Admin/Billing/Models/MigrateProvidersRequestModel.cs deleted file mode 100644 index 273f934eba..0000000000 --- a/src/Admin/Billing/Models/MigrateProvidersRequestModel.cs +++ /dev/null @@ -1,13 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.ComponentModel.DataAnnotations; - -namespace Bit.Admin.Billing.Models; - -public class MigrateProvidersRequestModel -{ - [Required] - [Display(Name = "Provider IDs")] - public string ProviderIds { get; set; } -} diff --git a/src/Admin/Billing/Views/MigrateProviders/Details.cshtml b/src/Admin/Billing/Views/MigrateProviders/Details.cshtml deleted file mode 100644 index 6ee0344057..0000000000 --- a/src/Admin/Billing/Views/MigrateProviders/Details.cshtml +++ /dev/null @@ -1,39 +0,0 @@ -@using System.Text.Json -@model Bit.Core.Billing.Providers.Migration.Models.ProviderMigrationResult -@{ - ViewData["Title"] = "Results"; -} - -

Migrate Providers

-

Migration Details: @Model.ProviderName

-
-
Id
-
@Model.ProviderId
- -
Result
-
@Model.Result
-
-

Client Organizations

-
- - - - - - - - - - - @foreach (var clientResult in Model.Clients) - { - - - - - - - } - -
IDNameResultPrevious State
@clientResult.OrganizationId@clientResult.OrganizationName@clientResult.Result
@Html.Raw(JsonSerializer.Serialize(clientResult.PreviousState))
-
diff --git a/src/Admin/Billing/Views/MigrateProviders/Index.cshtml b/src/Admin/Billing/Views/MigrateProviders/Index.cshtml deleted file mode 100644 index 0aed94c25d..0000000000 --- a/src/Admin/Billing/Views/MigrateProviders/Index.cshtml +++ /dev/null @@ -1,46 +0,0 @@ -@model Bit.Admin.Billing.Models.MigrateProvidersRequestModel; -@{ - ViewData["Title"] = "Migrate Providers"; -} - -

Migrate Providers

-

Bulk Consolidated Billing Migration Tool

-
-

- This tool allows you to provide a list of IDs for Providers that you would like to migrate to Consolidated Billing. - Because of the expensive nature of the operation, you can only migrate 10 Providers at a time. -

-

- Updates made through this tool are irreversible without manual intervention. -

-

Example Input (Please enter each Provider ID separated by a new line):

-
-
-
f513affc-2290-4336-879e-21ec3ecf3e78
-f7a5cb0d-4b74-445c-8d8c-232d1d32bbe2
-bf82d3cf-0e21-4f39-b81b-ef52b2fc6a3a
-174e82fc-70c3-448d-9fe7-00bad2a3ab00
-22a4bbbf-58e3-4e4c-a86a-a0d7caf4ff14
-
-
-
-
-
- - -
-
- -
-
-
-
-
- - -
-
- -
-
-
diff --git a/src/Admin/Billing/Views/MigrateProviders/Results.cshtml b/src/Admin/Billing/Views/MigrateProviders/Results.cshtml deleted file mode 100644 index 94db08db3d..0000000000 --- a/src/Admin/Billing/Views/MigrateProviders/Results.cshtml +++ /dev/null @@ -1,28 +0,0 @@ -@model Bit.Core.Billing.Providers.Migration.Models.ProviderMigrationResult[] -@{ - ViewData["Title"] = "Results"; -} - -

Migrate Providers

-

Results

-
- - - - - - - - - - @foreach (var result in Model) - { - - - - - - } - -
IDNameResult
@result.ProviderId@result.ProviderName@result.Result
-
diff --git a/src/Admin/Controllers/HomeController.cs b/src/Admin/Controllers/HomeController.cs index debe5979f5..5b36032ec9 100644 --- a/src/Admin/Controllers/HomeController.cs +++ b/src/Admin/Controllers/HomeController.cs @@ -61,7 +61,7 @@ public class HomeController : Controller } catch (HttpRequestException e) { - _logger.LogError(e, $"Error encountered while sending GET request to {requestUri}"); + _logger.LogError(e, "Error encountered while sending GET request to {RequestUri}", requestUri); return new JsonResult("Unable to fetch latest version") { StatusCode = StatusCodes.Status500InternalServerError }; } @@ -83,7 +83,7 @@ public class HomeController : Controller } catch (HttpRequestException e) { - _logger.LogError(e, $"Error encountered while sending GET request to {requestUri}"); + _logger.LogError(e, "Error encountered while sending GET request to {RequestUri}", requestUri); return new JsonResult("Unable to fetch installed version") { StatusCode = StatusCodes.Status500InternalServerError }; } diff --git a/src/Admin/Controllers/ToolsController.cs b/src/Admin/Controllers/ToolsController.cs index b754b1f968..2dd6de89a0 100644 --- a/src/Admin/Controllers/ToolsController.cs +++ b/src/Admin/Controllers/ToolsController.cs @@ -1,7 +1,6 @@ // FIXME: Update this file to be null safe and then delete the line below #nullable disable -using System.Text; using System.Text.Json; using Bit.Admin.Enums; using Bit.Admin.Models; @@ -9,8 +8,8 @@ using Bit.Admin.Utilities; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Organizations.Queries; +using Bit.Core.Billing.Services; using Bit.Core.Entities; -using Bit.Core.Models.BitStripe; using Bit.Core.Platform.Installations; using Bit.Core.Repositories; using Bit.Core.Services; @@ -33,7 +32,6 @@ public class ToolsController : Controller private readonly IInstallationRepository _installationRepository; private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IProviderUserRepository _providerUserRepository; - private readonly IPaymentService _paymentService; private readonly IStripeAdapter _stripeAdapter; private readonly IWebHostEnvironment _environment; @@ -46,7 +44,6 @@ public class ToolsController : Controller IInstallationRepository installationRepository, IOrganizationUserRepository organizationUserRepository, IProviderUserRepository providerUserRepository, - IPaymentService paymentService, IStripeAdapter stripeAdapter, IWebHostEnvironment environment) { @@ -58,7 +55,6 @@ public class ToolsController : Controller _installationRepository = installationRepository; _organizationUserRepository = organizationUserRepository; _providerUserRepository = providerUserRepository; - _paymentService = paymentService; _stripeAdapter = stripeAdapter; _environment = environment; } @@ -341,138 +337,4 @@ public class ToolsController : Controller throw new Exception("No license to generate."); } } - - [RequirePermission(Permission.Tools_ManageStripeSubscriptions)] - public async Task StripeSubscriptions(StripeSubscriptionListOptions options) - { - options = options ?? new StripeSubscriptionListOptions(); - options.Limit = 10; - options.Expand = new List() { "data.customer", "data.latest_invoice" }; - options.SelectAll = false; - - var subscriptions = await _stripeAdapter.SubscriptionListAsync(options); - - options.StartingAfter = subscriptions.LastOrDefault()?.Id; - options.EndingBefore = await StripeSubscriptionsGetHasPreviousPage(subscriptions, options) ? - subscriptions.FirstOrDefault()?.Id : - null; - - var isProduction = _environment.IsProduction(); - var model = new StripeSubscriptionsModel() - { - Items = subscriptions.Select(s => new StripeSubscriptionRowModel(s)).ToList(), - Prices = (await _stripeAdapter.PriceListAsync(new Stripe.PriceListOptions() { Limit = 100 })).Data, - TestClocks = isProduction ? new List() : await _stripeAdapter.TestClockListAsync(), - Filter = options - }; - return View(model); - } - - [HttpPost] - [RequirePermission(Permission.Tools_ManageStripeSubscriptions)] - public async Task StripeSubscriptions([FromForm] StripeSubscriptionsModel model) - { - if (!ModelState.IsValid) - { - var isProduction = _environment.IsProduction(); - model.Prices = (await _stripeAdapter.PriceListAsync(new Stripe.PriceListOptions() { Limit = 100 })).Data; - model.TestClocks = isProduction ? new List() : await _stripeAdapter.TestClockListAsync(); - return View(model); - } - - if (model.Action == StripeSubscriptionsAction.Export || model.Action == StripeSubscriptionsAction.BulkCancel) - { - var subscriptions = model.Filter.SelectAll ? - await _stripeAdapter.SubscriptionListAsync(model.Filter) : - model.Items.Where(x => x.Selected).Select(x => x.Subscription); - - if (model.Action == StripeSubscriptionsAction.Export) - { - return StripeSubscriptionsExport(subscriptions); - } - - if (model.Action == StripeSubscriptionsAction.BulkCancel) - { - await StripeSubscriptionsCancel(subscriptions); - } - } - else - { - if (model.Action == StripeSubscriptionsAction.PreviousPage || model.Action == StripeSubscriptionsAction.Search) - { - model.Filter.StartingAfter = null; - } - - if (model.Action == StripeSubscriptionsAction.NextPage || model.Action == StripeSubscriptionsAction.Search) - { - if (!string.IsNullOrEmpty(model.Filter.StartingAfter)) - { - var subscription = await _stripeAdapter.SubscriptionGetAsync(model.Filter.StartingAfter); - if (subscription.Status == "canceled") - { - model.Filter.StartingAfter = null; - } - } - model.Filter.EndingBefore = null; - } - } - - - return RedirectToAction("StripeSubscriptions", model.Filter); - } - - // This requires a redundant API call to Stripe because of the way they handle pagination. - // The StartingBefore value has to be inferred from the list we get, and isn't supplied by Stripe. - private async Task StripeSubscriptionsGetHasPreviousPage(List subscriptions, StripeSubscriptionListOptions options) - { - var hasPreviousPage = false; - if (subscriptions.FirstOrDefault()?.Id != null) - { - var previousPageSearchOptions = new StripeSubscriptionListOptions() - { - EndingBefore = subscriptions.FirstOrDefault().Id, - Limit = 1, - Status = options.Status, - CurrentPeriodEndDate = options.CurrentPeriodEndDate, - CurrentPeriodEndRange = options.CurrentPeriodEndRange, - Price = options.Price - }; - hasPreviousPage = (await _stripeAdapter.SubscriptionListAsync(previousPageSearchOptions)).Count > 0; - } - return hasPreviousPage; - } - - private async Task StripeSubscriptionsCancel(IEnumerable subscriptions) - { - foreach (var s in subscriptions) - { - await _stripeAdapter.SubscriptionCancelAsync(s.Id); - if (s.LatestInvoice?.Status == "open") - { - await _stripeAdapter.InvoiceVoidInvoiceAsync(s.LatestInvoiceId); - } - } - } - - private FileResult StripeSubscriptionsExport(IEnumerable subscriptions) - { - var fieldsToExport = subscriptions.Select(s => new - { - StripeId = s.Id, - CustomerEmail = s.Customer?.Email, - SubscriptionStatus = s.Status, - InvoiceDueDate = s.CurrentPeriodEnd, - SubscriptionProducts = s.Items?.Data.Select(p => p.Plan.Id) - }); - - var options = new JsonSerializerOptions - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - WriteIndented = true - }; - - var result = System.Text.Json.JsonSerializer.Serialize(fieldsToExport, options); - var bytes = Encoding.UTF8.GetBytes(result); - return File(bytes, "application/json", "StripeSubscriptionsSearch.json"); - } } diff --git a/src/Admin/Controllers/UsersController.cs b/src/Admin/Controllers/UsersController.cs index b85a91719c..f42b22b098 100644 --- a/src/Admin/Controllers/UsersController.cs +++ b/src/Admin/Controllers/UsersController.cs @@ -5,6 +5,7 @@ using Bit.Admin.Models; using Bit.Admin.Services; using Bit.Admin.Utilities; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; +using Bit.Core.Billing.Services; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; @@ -20,7 +21,7 @@ public class UsersController : Controller { private readonly IUserRepository _userRepository; private readonly ICipherRepository _cipherRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly GlobalSettings _globalSettings; private readonly IAccessControlService _accessControlService; private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; @@ -30,7 +31,7 @@ public class UsersController : Controller public UsersController( IUserRepository userRepository, ICipherRepository cipherRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, GlobalSettings globalSettings, IAccessControlService accessControlService, ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, diff --git a/src/Admin/Dockerfile b/src/Admin/Dockerfile index 648ff1be91..84248639cf 100644 --- a/src/Admin/Dockerfile +++ b/src/Admin/Dockerfile @@ -1,7 +1,7 @@ ############################################### # Node.js build stage # ############################################### -FROM node:20-alpine3.21 AS node-build +FROM --platform=$BUILDPLATFORM node:20-alpine3.21 AS node-build WORKDIR /app COPY src/Admin/package*.json ./ diff --git a/src/Admin/Enums/Permissions.cs b/src/Admin/Enums/Permissions.cs index 14b255b2b6..34d975226e 100644 --- a/src/Admin/Enums/Permissions.cs +++ b/src/Admin/Enums/Permissions.cs @@ -52,8 +52,6 @@ public enum Permission Tools_PromoteProviderServiceUser, Tools_GenerateLicenseFile, Tools_ManageTaxRates, - Tools_ManageStripeSubscriptions, Tools_CreateEditTransaction, - Tools_ProcessStripeEvents, - Tools_MigrateProviders + Tools_ProcessStripeEvents } diff --git a/src/Admin/HostedServices/DatabaseMigrationHostedService.cs b/src/Admin/HostedServices/DatabaseMigrationHostedService.cs index 434c265f26..219e6846bd 100644 --- a/src/Admin/HostedServices/DatabaseMigrationHostedService.cs +++ b/src/Admin/HostedServices/DatabaseMigrationHostedService.cs @@ -19,7 +19,7 @@ public class DatabaseMigrationHostedService : IHostedService, IDisposable public virtual async Task StartAsync(CancellationToken cancellationToken) { // Wait 20 seconds to allow database to come online - await Task.Delay(20000); + await Task.Delay(20000, cancellationToken); var maxMigrationAttempts = 10; for (var i = 1; i <= maxMigrationAttempts; i++) @@ -41,7 +41,7 @@ public class DatabaseMigrationHostedService : IHostedService, IDisposable { _logger.LogError(e, "Database unavailable for migration. Trying again (attempt #{0})...", i + 1); - await Task.Delay(20000); + await Task.Delay(20000, cancellationToken); } } } diff --git a/src/Admin/Jobs/AliveJob.cs b/src/Admin/Jobs/AliveJob.cs index b97d597e58..d62f4cc2cc 100644 --- a/src/Admin/Jobs/AliveJob.cs +++ b/src/Admin/Jobs/AliveJob.cs @@ -22,7 +22,7 @@ public class AliveJob : BaseJob { _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: Keep alive"); var response = await _httpClient.GetAsync(_globalSettings.BaseServiceUri.Admin); - _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: Keep alive, " + + _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: Keep alive, {StatusCode}", response.StatusCode); } } diff --git a/src/Admin/Models/StripeSubscriptionsModel.cs b/src/Admin/Models/StripeSubscriptionsModel.cs deleted file mode 100644 index 36e1f099e1..0000000000 --- a/src/Admin/Models/StripeSubscriptionsModel.cs +++ /dev/null @@ -1,45 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.ComponentModel.DataAnnotations; -using Bit.Core.Models.BitStripe; - -namespace Bit.Admin.Models; - -public class StripeSubscriptionRowModel -{ - public Stripe.Subscription Subscription { get; set; } - public bool Selected { get; set; } - - public StripeSubscriptionRowModel() { } - public StripeSubscriptionRowModel(Stripe.Subscription subscription) - { - Subscription = subscription; - } -} - -public enum StripeSubscriptionsAction -{ - Search, - PreviousPage, - NextPage, - Export, - BulkCancel -} - -public class StripeSubscriptionsModel : IValidatableObject -{ - public List Items { get; set; } - public StripeSubscriptionsAction Action { get; set; } = StripeSubscriptionsAction.Search; - public string Message { get; set; } - public List Prices { get; set; } - public List TestClocks { get; set; } - public StripeSubscriptionListOptions Filter { get; set; } = new StripeSubscriptionListOptions(); - public IEnumerable Validate(ValidationContext validationContext) - { - if (Action == StripeSubscriptionsAction.BulkCancel && Filter.Status != "unpaid") - { - yield return new ValidationResult("Bulk cancel is currently only supported for unpaid subscriptions"); - } - } -} diff --git a/src/Admin/Program.cs b/src/Admin/Program.cs index 05bf35d41d..006a8223b2 100644 --- a/src/Admin/Program.cs +++ b/src/Admin/Program.cs @@ -16,19 +16,8 @@ public class Program o.Limits.MaxRequestLineSize = 20_000; }); webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, (e, globalSettings) => - { - var context = e.Properties["SourceContext"].ToString(); - if (e.Properties.TryGetValue("RequestPath", out var requestPath) && - !string.IsNullOrWhiteSpace(requestPath?.ToString()) && - (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) - { - return false; - } - return e.Level >= globalSettings.MinLogLevel.AdminSettings.Default; - })); }) + .AddSerilogFileLogging() .Build() .Run(); } diff --git a/src/Admin/Startup.cs b/src/Admin/Startup.cs index 5b34e13f6c..87d68a7ac6 100644 --- a/src/Admin/Startup.cs +++ b/src/Admin/Startup.cs @@ -10,7 +10,6 @@ using Microsoft.AspNetCore.Mvc.Razor; using Microsoft.Extensions.DependencyInjection.Extensions; using Bit.Admin.Services; using Bit.Core.Billing.Extensions; -using Bit.Core.Billing.Providers.Migration; #if !OSS using Bit.Commercial.Core.Utilities; @@ -92,7 +91,6 @@ public class Startup services.AddDistributedCache(globalSettings); services.AddBillingOperations(); services.AddHttpClient(); - services.AddProviderMigration(); #if OSS services.AddOosServices(); @@ -134,11 +132,8 @@ public class Startup public void Configure( IApplicationBuilder app, IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, GlobalSettings globalSettings) { - app.UseSerilog(env, appLifetime, globalSettings); - // Add general security headers app.UseMiddleware(); diff --git a/src/Admin/Utilities/RolePermissionMapping.cs b/src/Admin/Utilities/RolePermissionMapping.cs index b60cf895a1..6dddc4ffeb 100644 --- a/src/Admin/Utilities/RolePermissionMapping.cs +++ b/src/Admin/Utilities/RolePermissionMapping.cs @@ -52,8 +52,7 @@ public static class RolePermissionMapping Permission.Tools_PromoteAdmin, Permission.Tools_PromoteProviderServiceUser, Permission.Tools_GenerateLicenseFile, - Permission.Tools_ManageTaxRates, - Permission.Tools_ManageStripeSubscriptions + Permission.Tools_ManageTaxRates } }, { "admin", new List @@ -105,7 +104,6 @@ public static class RolePermissionMapping Permission.Tools_PromoteProviderServiceUser, Permission.Tools_GenerateLicenseFile, Permission.Tools_ManageTaxRates, - Permission.Tools_ManageStripeSubscriptions, Permission.Tools_CreateEditTransaction } }, @@ -180,10 +178,8 @@ public static class RolePermissionMapping Permission.Tools_ChargeBrainTreeCustomer, Permission.Tools_GenerateLicenseFile, Permission.Tools_ManageTaxRates, - Permission.Tools_ManageStripeSubscriptions, Permission.Tools_CreateEditTransaction, - Permission.Tools_ProcessStripeEvents, - Permission.Tools_MigrateProviders + Permission.Tools_ProcessStripeEvents } }, { "sales", new List diff --git a/src/Admin/Views/Shared/_Layout.cshtml b/src/Admin/Views/Shared/_Layout.cshtml index 1661a8bbc3..c13be428b4 100644 --- a/src/Admin/Views/Shared/_Layout.cshtml +++ b/src/Admin/Views/Shared/_Layout.cshtml @@ -13,12 +13,10 @@ var canPromoteAdmin = AccessControlService.UserHasPermission(Permission.Tools_PromoteAdmin); var canPromoteProviderServiceUser = AccessControlService.UserHasPermission(Permission.Tools_PromoteProviderServiceUser); var canGenerateLicense = AccessControlService.UserHasPermission(Permission.Tools_GenerateLicenseFile); - var canManageStripeSubscriptions = AccessControlService.UserHasPermission(Permission.Tools_ManageStripeSubscriptions); var canProcessStripeEvents = AccessControlService.UserHasPermission(Permission.Tools_ProcessStripeEvents); - var canMigrateProviders = AccessControlService.UserHasPermission(Permission.Tools_MigrateProviders); var canViewTools = canChargeBraintree || canCreateTransaction || canPromoteAdmin || canPromoteProviderServiceUser || - canGenerateLicense || canManageStripeSubscriptions; + canGenerateLicense; } @@ -102,12 +100,6 @@ Generate License - } - @if (canManageStripeSubscriptions) - { - - Manage Stripe Subscriptions - } @if (canProcessStripeEvents) { @@ -115,12 +107,6 @@ Process Stripe Events } - @if (canMigrateProviders) - { - - Migrate Providers - - } } diff --git a/src/Admin/Views/Tools/StripeSubscriptions.cshtml b/src/Admin/Views/Tools/StripeSubscriptions.cshtml deleted file mode 100644 index d8c168b3b0..0000000000 --- a/src/Admin/Views/Tools/StripeSubscriptions.cshtml +++ /dev/null @@ -1,277 +0,0 @@ -@model StripeSubscriptionsModel -@{ - ViewData["Title"] = "Stripe Subscriptions"; -} - -@section Scripts { - -} - -

Manage Stripe Subscriptions

-@if (!string.IsNullOrWhiteSpace(Model.Message)) -{ -
-} -
-
-
-
- - -
-
- -
-
-
- - -
-
- - -
-
- @{ - var date = @Model.Filter.CurrentPeriodEndDate.HasValue ? @Model.Filter.CurrentPeriodEndDate.Value.ToString("yyyy-MM-dd") : string.Empty; - } - -
-
-
- - -
-
- - -
-
- -
-
-
- -
-
- All @Model.Items.Count subscriptions on this page are selected.
- - - All subscriptions for this search are selected. - - -
-
-
- - - - - - - - - - - - - @if (!Model.Items.Any()) - { - - - - } - else - { - @for (var i = 0; i < Model.Items.Count; i++) - { - - - - - - - - - } - } - -
-
- -
-
IdCustomer EmailStatusProduct TierCurrent Period End
No results to list.
- - @{ - var i0 = i; - } - - - - - - - - @for (var j = 0; j < Model.Items[i].Subscription.Items.Data.Count; j++) - { - var i1 = i; - var j1 = j; - - } -
- - @{ - var i2 = i; - } - -
-
- @Model.Items[i].Subscription.Id - - @Model.Items[i].Subscription.Customer?.Email - - @Model.Items[i].Subscription.Status - - @string.Join(",", Model.Items[i].Subscription.Items.Data.Select(product => product.Plan.Id).ToArray()) - - @Model.Items[i].Subscription.CurrentPeriodEnd.ToShortDateString() -
-
- -
diff --git a/src/Admin/appsettings.Development.json b/src/Admin/appsettings.Development.json index 861f9be98d..15d61f493f 100644 --- a/src/Admin/appsettings.Development.json +++ b/src/Admin/appsettings.Development.json @@ -27,6 +27,7 @@ }, "storage": { "connectionString": "UseDevelopmentStorage=true" - } + }, + "pricingUri": "https://billingpricing.qa.bitwarden.pw" } } diff --git a/src/Admin/package-lock.json b/src/Admin/package-lock.json index 2e3a335598..6e0f78e1e6 100644 --- a/src/Admin/package-lock.json +++ b/src/Admin/package-lock.json @@ -18,9 +18,9 @@ "css-loader": "7.1.2", "expose-loader": "5.0.1", "mini-css-extract-plugin": "2.9.2", - "sass": "1.91.0", + "sass": "1.93.2", "sass-loader": "16.0.5", - "webpack": "5.101.3", + "webpack": "5.102.1", "webpack-cli": "5.1.4" } }, @@ -679,6 +679,7 @@ "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", "dev": true, "license": "MIT", + "peer": true, "bin": { "acorn": "bin/acorn" }, @@ -705,6 +706,7 @@ "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "fast-deep-equal": "^3.1.3", "fast-uri": "^3.0.1", @@ -747,6 +749,16 @@ "ajv": "^8.8.2" } }, + "node_modules/baseline-browser-mapping": { + "version": "2.8.18", + "resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.8.18.tgz", + "integrity": "sha512-UYmTpOBwgPScZpS4A+YbapwWuBwasxvO/2IOHArSsAhL/+ZdmATBXTex3t+l2hXwLVYK382ibr/nKoY9GKe86w==", + "dev": true, + "license": "Apache-2.0", + "bin": { + "baseline-browser-mapping": "dist/cli.js" + } + }, "node_modules/bootstrap": { "version": "5.3.6", "resolved": "https://registry.npmjs.org/bootstrap/-/bootstrap-5.3.6.tgz", @@ -781,9 +793,9 @@ } }, "node_modules/browserslist": { - "version": "4.25.4", - "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.25.4.tgz", - "integrity": "sha512-4jYpcjabC606xJ3kw2QwGEZKX0Aw7sgQdZCvIK9dhVSPh76BKo+C+btT1RRofH7B+8iNpEbgGNVWiLki5q93yg==", + "version": "4.26.3", + "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.26.3.tgz", + "integrity": "sha512-lAUU+02RFBuCKQPj/P6NgjlbCnLBMp4UtgTx7vNHd3XSIJF87s9a5rA3aH2yw3GS9DqZAUbOtZdCCiZeVRqt0w==", "dev": true, "funding": [ { @@ -800,10 +812,12 @@ } ], "license": "MIT", + "peer": true, "dependencies": { - "caniuse-lite": "^1.0.30001737", - "electron-to-chromium": "^1.5.211", - "node-releases": "^2.0.19", + "baseline-browser-mapping": "^2.8.9", + "caniuse-lite": "^1.0.30001746", + "electron-to-chromium": "^1.5.227", + "node-releases": "^2.0.21", "update-browserslist-db": "^1.1.3" }, "bin": { @@ -821,9 +835,9 @@ "license": "MIT" }, "node_modules/caniuse-lite": { - "version": "1.0.30001741", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001741.tgz", - "integrity": "sha512-QGUGitqsc8ARjLdgAfxETDhRbJ0REsP6O3I96TAth/mVjh2cYzN2u+3AzPP3aVSm2FehEItaJw1xd+IGBXWeSw==", + "version": "1.0.30001751", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001751.tgz", + "integrity": "sha512-A0QJhug0Ly64Ii3eIqHu5X51ebln3k4yTUkY1j8drqpWHVreg/VLijN48cZ1bYPiqOQuqpkIKnzr/Ul8V+p6Cw==", "dev": true, "funding": [ { @@ -975,9 +989,9 @@ } }, "node_modules/electron-to-chromium": { - "version": "1.5.215", - "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.215.tgz", - "integrity": "sha512-TIvGp57UpeNetj/wV/xpFNpWGb0b/ROw372lHPx5Aafx02gjTBtWnEEcaSX3W2dLM3OSdGGyHX/cHl01JQsLaQ==", + "version": "1.5.237", + "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.237.tgz", + "integrity": "sha512-icUt1NvfhGLar5lSWH3tHNzablaA5js3HVHacQimfP8ViEBOQv+L7DKEuHdbTZ0SKCO1ogTJTIL1Gwk9S6Qvcg==", "dev": true, "license": "ISC" }, @@ -1528,9 +1542,9 @@ "optional": true }, "node_modules/node-releases": { - "version": "2.0.20", - "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.20.tgz", - "integrity": "sha512-7gK6zSXEH6neM212JgfYFXe+GmZQM+fia5SsusuBIUgnPheLFBmIPhtFoAQRj8/7wASYQnbDlHPVwY0BefoFgA==", + "version": "2.0.26", + "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.26.tgz", + "integrity": "sha512-S2M9YimhSjBSvYnlr5/+umAnPHE++ODwt5e2Ij6FoX45HA/s4vHdkDx1eax2pAPeAOqu4s9b7ppahsyEFdVqQA==", "dev": true, "license": "MIT" }, @@ -1654,6 +1668,7 @@ } ], "license": "MIT", + "peer": true, "dependencies": { "nanoid": "^3.3.11", "picocolors": "^1.1.1", @@ -1860,11 +1875,12 @@ "license": "MIT" }, "node_modules/sass": { - "version": "1.91.0", - "resolved": "https://registry.npmjs.org/sass/-/sass-1.91.0.tgz", - "integrity": "sha512-aFOZHGf+ur+bp1bCHZ+u8otKGh77ZtmFyXDo4tlYvT7PWql41Kwd8wdkPqhhT+h2879IVblcHFglIMofsFd1EA==", + "version": "1.93.2", + "resolved": "https://registry.npmjs.org/sass/-/sass-1.93.2.tgz", + "integrity": "sha512-t+YPtOQHpGW1QWsh1CHQ5cPIr9lbbGZLZnbihP/D/qZj/yuV68m8qarcV17nvkOX81BCrvzAlq2klCQFZghyTg==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "chokidar": "^4.0.0", "immutable": "^5.0.2", @@ -1922,9 +1938,9 @@ } }, "node_modules/schema-utils": { - "version": "4.3.2", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.2.tgz", - "integrity": "sha512-Gn/JaSk/Mt9gYubxTtSn/QCV4em9mpAPiR1rqy/Ocu19u/G9J5WWdNoUT4SiV6mFC3y6cxyFcFwdzPM3FgxGAQ==", + "version": "4.3.3", + "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.3.tgz", + "integrity": "sha512-eflK8wEtyOE6+hsaRVPxvUKYCpRgzLqDTb8krvAsRIwOGlHoSgYLgBXoubGgLd2fT41/OUYdb48v4k4WWHQurA==", "dev": true, "license": "MIT", "dependencies": { @@ -2061,9 +2077,9 @@ } }, "node_modules/tapable": { - "version": "2.2.3", - "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.2.3.tgz", - "integrity": "sha512-ZL6DDuAlRlLGghwcfmSn9sK3Hr6ArtyudlSAiCqQ6IfE+b+HHbydbYDIG15IfS5do+7XQQBdBiubF/cV2dnDzg==", + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.3.0.tgz", + "integrity": "sha512-g9ljZiwki/LfxmQADO3dEY1CbpmXT5Hm2fJ+QaGKwSXUylMybePR7/67YW7jOrrvjEgL1Fmz5kzyAjWVWLlucg==", "dev": true, "license": "MIT", "engines": { @@ -2210,11 +2226,12 @@ } }, "node_modules/webpack": { - "version": "5.101.3", - "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.101.3.tgz", - "integrity": "sha512-7b0dTKR3Ed//AD/6kkx/o7duS8H3f1a4w3BYpIriX4BzIhjkn4teo05cptsxvLesHFKK5KObnadmCHBwGc+51A==", + "version": "5.102.1", + "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.102.1.tgz", + "integrity": "sha512-7h/weGm9d/ywQ6qzJ+Xy+r9n/3qgp/thalBbpOi5i223dPXKi04IBtqPN9nTd+jBc7QKfvDbaBnFipYp4sJAUQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@types/eslint-scope": "^3.7.7", "@types/estree": "^1.0.8", @@ -2224,7 +2241,7 @@ "@webassemblyjs/wasm-parser": "^1.14.1", "acorn": "^8.15.0", "acorn-import-phases": "^1.0.3", - "browserslist": "^4.24.0", + "browserslist": "^4.26.3", "chrome-trace-event": "^1.0.2", "enhanced-resolve": "^5.17.3", "es-module-lexer": "^1.2.1", @@ -2236,10 +2253,10 @@ "loader-runner": "^4.2.0", "mime-types": "^2.1.27", "neo-async": "^2.6.2", - "schema-utils": "^4.3.2", - "tapable": "^2.1.1", + "schema-utils": "^4.3.3", + "tapable": "^2.3.0", "terser-webpack-plugin": "^5.3.11", - "watchpack": "^2.4.1", + "watchpack": "^2.4.4", "webpack-sources": "^3.3.3" }, "bin": { @@ -2264,6 +2281,7 @@ "integrity": "sha512-pIDJHIEI9LR0yxHXQ+Qh95k2EvXpWzZ5l+d+jIo+RdSm9MiHfzazIxwwni/p7+x4eJZuvG1AJwgC4TNQ7NRgsg==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@discoveryjs/json-ext": "^0.5.0", "@webpack-cli/configtest": "^2.1.1", diff --git a/src/Admin/package.json b/src/Admin/package.json index 89ee1c5358..f6f21e2cf9 100644 --- a/src/Admin/package.json +++ b/src/Admin/package.json @@ -17,9 +17,9 @@ "css-loader": "7.1.2", "expose-loader": "5.0.1", "mini-css-extract-plugin": "2.9.2", - "sass": "1.91.0", + "sass": "1.93.2", "sass-loader": "16.0.5", - "webpack": "5.101.3", + "webpack": "5.102.1", "webpack-cli": "5.1.4" } } diff --git a/src/Api/AdminConsole/Authorization/AuthorizationHandlerCollectionExtensions.cs b/src/Api/AdminConsole/Authorization/AuthorizationHandlerCollectionExtensions.cs index ed628105e0..a3234f61d7 100644 --- a/src/Api/AdminConsole/Authorization/AuthorizationHandlerCollectionExtensions.cs +++ b/src/Api/AdminConsole/Authorization/AuthorizationHandlerCollectionExtensions.cs @@ -12,10 +12,11 @@ public static class AuthorizationHandlerCollectionExtensions services.TryAddScoped(); services.TryAddEnumerable([ - ServiceDescriptor.Scoped(), + ServiceDescriptor.Scoped(), ServiceDescriptor.Scoped(), ServiceDescriptor.Scoped(), ServiceDescriptor.Scoped(), + ServiceDescriptor.Scoped(), ]); } } diff --git a/src/Api/AdminConsole/Authorization/RecoverAccountAuthorizationHandler.cs b/src/Api/AdminConsole/Authorization/RecoverAccountAuthorizationHandler.cs new file mode 100644 index 0000000000..239148ab25 --- /dev/null +++ b/src/Api/AdminConsole/Authorization/RecoverAccountAuthorizationHandler.cs @@ -0,0 +1,110 @@ +using System.Security.Claims; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Context; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Microsoft.AspNetCore.Authorization; + +namespace Bit.Api.AdminConsole.Authorization; + +/// +/// An authorization requirement for recovering an organization member's account. +/// +/// +/// Note: this is different to simply being able to manage account recovery. The user must be recovering +/// a member who has equal or lesser permissions than them. +/// +public class RecoverAccountAuthorizationRequirement : IAuthorizationRequirement; + +/// +/// Authorizes members and providers to recover a target OrganizationUser's account. +/// +/// +/// This prevents privilege escalation by ensuring that a user cannot recover the account of +/// another user with a higher role or with provider membership. +/// +public class RecoverAccountAuthorizationHandler( + IOrganizationContext organizationContext, + ICurrentContext currentContext, + IProviderUserRepository providerUserRepository) + : AuthorizationHandler +{ + public const string FailureReason = "You are not permitted to recover this user's account."; + public const string ProviderFailureReason = "You are not permitted to recover a Provider member's account."; + + protected override async Task HandleRequirementAsync(AuthorizationHandlerContext context, + RecoverAccountAuthorizationRequirement requirement, + OrganizationUser targetOrganizationUser) + { + // Step 1: check that the User has permissions with respect to the organization. + // This may come from their role in the organization or their provider relationship. + var canRecoverOrganizationMember = + AuthorizeMember(context.User, targetOrganizationUser) || + await AuthorizeProviderAsync(context.User, targetOrganizationUser); + + if (!canRecoverOrganizationMember) + { + context.Fail(new AuthorizationFailureReason(this, FailureReason)); + return; + } + + // Step 2: check that the User has permissions with respect to any provider the target user is a member of. + // This prevents an organization admin performing privilege escalation into an unrelated provider. + var canRecoverProviderMember = await CanRecoverProviderAsync(targetOrganizationUser); + if (!canRecoverProviderMember) + { + context.Fail(new AuthorizationFailureReason(this, ProviderFailureReason)); + return; + } + + context.Succeed(requirement); + } + + private async Task AuthorizeProviderAsync(ClaimsPrincipal currentUser, OrganizationUser targetOrganizationUser) + { + return await organizationContext.IsProviderUserForOrganization(currentUser, targetOrganizationUser.OrganizationId); + } + + private bool AuthorizeMember(ClaimsPrincipal currentUser, OrganizationUser targetOrganizationUser) + { + var currentContextOrganization = organizationContext.GetOrganizationClaims(currentUser, targetOrganizationUser.OrganizationId); + if (currentContextOrganization == null) + { + return false; + } + + // Current user must have equal or greater permissions than the user account being recovered + var authorized = targetOrganizationUser.Type switch + { + OrganizationUserType.Owner => currentContextOrganization.Type is OrganizationUserType.Owner, + OrganizationUserType.Admin => currentContextOrganization.Type is OrganizationUserType.Owner or OrganizationUserType.Admin, + _ => currentContextOrganization is + { Type: OrganizationUserType.Owner or OrganizationUserType.Admin } + or { Type: OrganizationUserType.Custom, Permissions.ManageResetPassword: true } + }; + + return authorized; + } + + private async Task CanRecoverProviderAsync(OrganizationUser targetOrganizationUser) + { + if (!targetOrganizationUser.UserId.HasValue) + { + // If an OrganizationUser is not linked to a User then it can't be linked to a Provider either. + // This is invalid but does not pose a privilege escalation risk. Return early and let the command + // handle the invalid input. + return true; + } + + var targetUserProviderUsers = + await providerUserRepository.GetManyByUserAsync(targetOrganizationUser.UserId.Value); + + // If the target user belongs to any provider that the current user is not a member of, + // deny the action to prevent privilege escalation from organization to provider. + // Note: we do not expect that a user is a member of more than 1 provider, but there is also no guarantee + // against it; this returns a sequence, so we handle the possibility. + var authorized = targetUserProviderUsers.All(providerUser => currentContext.ProviderUser(providerUser.ProviderId)); + return authorized; + } +} + diff --git a/src/Api/AdminConsole/Controllers/BaseAdminConsoleController.cs b/src/Api/AdminConsole/Controllers/BaseAdminConsoleController.cs new file mode 100644 index 0000000000..9b147c3c54 --- /dev/null +++ b/src/Api/AdminConsole/Controllers/BaseAdminConsoleController.cs @@ -0,0 +1,26 @@ +using Bit.Core.AdminConsole.Utilities.v2; +using Bit.Core.AdminConsole.Utilities.v2.Results; +using Bit.Core.Models.Api; +using Microsoft.AspNetCore.Mvc; + +namespace Bit.Api.AdminConsole.Controllers; + +public abstract class BaseAdminConsoleController : Controller +{ + protected static IResult Handle(CommandResult commandResult) => + commandResult.Match( + error => error switch + { + BadRequestError badRequest => TypedResults.BadRequest(new ErrorResponseModel(badRequest.Message)), + NotFoundError notFound => TypedResults.NotFound(new ErrorResponseModel(notFound.Message)), + InternalError internalError => TypedResults.Json( + new ErrorResponseModel(internalError.Message), + statusCode: StatusCodes.Status500InternalServerError), + _ => TypedResults.Json( + new ErrorResponseModel(error.Message), + statusCode: StatusCodes.Status500InternalServerError + ) + }, + _ => TypedResults.NoContent() + ); +} diff --git a/src/Api/AdminConsole/Controllers/OrganizationIntegrationConfigurationController.cs b/src/Api/AdminConsole/Controllers/OrganizationIntegrationConfigurationController.cs deleted file mode 100644 index ae0f91d355..0000000000 --- a/src/Api/AdminConsole/Controllers/OrganizationIntegrationConfigurationController.cs +++ /dev/null @@ -1,133 +0,0 @@ -using Bit.Api.AdminConsole.Models.Request.Organizations; -using Bit.Api.AdminConsole.Models.Response.Organizations; -using Bit.Core; -using Bit.Core.Context; -using Bit.Core.Exceptions; -using Bit.Core.Repositories; -using Bit.Core.Utilities; -using Microsoft.AspNetCore.Authorization; -using Microsoft.AspNetCore.Mvc; - -namespace Bit.Api.AdminConsole.Controllers; - -[RequireFeature(FeatureFlagKeys.EventBasedOrganizationIntegrations)] -[Route("organizations/{organizationId:guid}/integrations/{integrationId:guid}/configurations")] -[Authorize("Application")] -public class OrganizationIntegrationConfigurationController( - ICurrentContext currentContext, - IOrganizationIntegrationRepository integrationRepository, - IOrganizationIntegrationConfigurationRepository integrationConfigurationRepository) : Controller -{ - [HttpGet("")] - public async Task> GetAsync( - Guid organizationId, - Guid integrationId) - { - if (!await HasPermission(organizationId)) - { - throw new NotFoundException(); - } - var integration = await integrationRepository.GetByIdAsync(integrationId); - if (integration == null || integration.OrganizationId != organizationId) - { - throw new NotFoundException(); - } - - var configurations = await integrationConfigurationRepository.GetManyByIntegrationAsync(integrationId); - return configurations - .Select(configuration => new OrganizationIntegrationConfigurationResponseModel(configuration)) - .ToList(); - } - - [HttpPost("")] - public async Task CreateAsync( - Guid organizationId, - Guid integrationId, - [FromBody] OrganizationIntegrationConfigurationRequestModel model) - { - if (!await HasPermission(organizationId)) - { - throw new NotFoundException(); - } - var integration = await integrationRepository.GetByIdAsync(integrationId); - if (integration == null || integration.OrganizationId != organizationId) - { - throw new NotFoundException(); - } - if (!model.IsValidForType(integration.Type)) - { - throw new BadRequestException($"Invalid Configuration and/or Template for integration type {integration.Type}"); - } - - var organizationIntegrationConfiguration = model.ToOrganizationIntegrationConfiguration(integrationId); - var configuration = await integrationConfigurationRepository.CreateAsync(organizationIntegrationConfiguration); - return new OrganizationIntegrationConfigurationResponseModel(configuration); - } - - [HttpPut("{configurationId:guid}")] - public async Task UpdateAsync( - Guid organizationId, - Guid integrationId, - Guid configurationId, - [FromBody] OrganizationIntegrationConfigurationRequestModel model) - { - if (!await HasPermission(organizationId)) - { - throw new NotFoundException(); - } - var integration = await integrationRepository.GetByIdAsync(integrationId); - if (integration == null || integration.OrganizationId != organizationId) - { - throw new NotFoundException(); - } - if (!model.IsValidForType(integration.Type)) - { - throw new BadRequestException($"Invalid Configuration and/or Template for integration type {integration.Type}"); - } - - var configuration = await integrationConfigurationRepository.GetByIdAsync(configurationId); - if (configuration is null || configuration.OrganizationIntegrationId != integrationId) - { - throw new NotFoundException(); - } - - var newConfiguration = model.ToOrganizationIntegrationConfiguration(configuration); - await integrationConfigurationRepository.ReplaceAsync(newConfiguration); - - return new OrganizationIntegrationConfigurationResponseModel(newConfiguration); - } - - [HttpDelete("{configurationId:guid}")] - public async Task DeleteAsync(Guid organizationId, Guid integrationId, Guid configurationId) - { - if (!await HasPermission(organizationId)) - { - throw new NotFoundException(); - } - var integration = await integrationRepository.GetByIdAsync(integrationId); - if (integration == null || integration.OrganizationId != organizationId) - { - throw new NotFoundException(); - } - - var configuration = await integrationConfigurationRepository.GetByIdAsync(configurationId); - if (configuration is null || configuration.OrganizationIntegrationId != integrationId) - { - throw new NotFoundException(); - } - - await integrationConfigurationRepository.DeleteAsync(configuration); - } - - [HttpPost("{configurationId:guid}/delete")] - [Obsolete("This endpoint is deprecated. Use DELETE method instead")] - public async Task PostDeleteAsync(Guid organizationId, Guid integrationId, Guid configurationId) - { - await DeleteAsync(organizationId, integrationId, configurationId); - } - - private async Task HasPermission(Guid organizationId) - { - return await currentContext.OrganizationOwner(organizationId); - } -} diff --git a/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs b/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs index 74ac9b1255..a380d2f0d9 100644 --- a/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs +++ b/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs @@ -1,4 +1,5 @@ // FIXME: Update this file to be null safe and then delete the line below +// NOTE: This file is partially migrated to nullable reference types. Remove inline #nullable directives when addressing the FIXME. #nullable disable using Bit.Api.AdminConsole.Authorization; @@ -10,7 +11,10 @@ using Bit.Api.Models.Response; using Bit.Api.Vault.AuthorizationHandlers.Collections; using Bit.Core; using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; @@ -18,6 +22,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RestoreUser.v using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Utilities.v2; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Repositories; using Bit.Core.Billing.Pricing; @@ -36,12 +41,14 @@ using Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; +using V1_RevokeOrganizationUserCommand = Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v1.IRevokeOrganizationUserCommand; +using V2_RevokeOrganizationUserCommand = Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; namespace Bit.Api.AdminConsole.Controllers; [Route("organizations/{orgId}/users")] [Authorize("Application")] -public class OrganizationUsersController : Controller +public class OrganizationUsersController : BaseAdminConsoleController { private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationUserRepository _organizationUserRepository; @@ -66,10 +73,14 @@ public class OrganizationUsersController : Controller private readonly IFeatureService _featureService; private readonly IPricingClient _pricingClient; private readonly IResendOrganizationInviteCommand _resendOrganizationInviteCommand; + private readonly IBulkResendOrganizationInvitesCommand _bulkResendOrganizationInvitesCommand; + private readonly IAutomaticallyConfirmOrganizationUserCommand _automaticallyConfirmOrganizationUserCommand; + private readonly V2_RevokeOrganizationUserCommand.IRevokeOrganizationUserCommand _revokeOrganizationUserCommandVNext; private readonly IConfirmOrganizationUserCommand _confirmOrganizationUserCommand; private readonly IRestoreOrganizationUserCommand _restoreOrganizationUserCommand; private readonly IInitPendingOrganizationCommand _initPendingOrganizationCommand; - private readonly IRevokeOrganizationUserCommand _revokeOrganizationUserCommand; + private readonly V1_RevokeOrganizationUserCommand _revokeOrganizationUserCommand; + private readonly IAdminRecoverAccountCommand _adminRecoverAccountCommand; public OrganizationUsersController(IOrganizationRepository organizationRepository, IOrganizationUserRepository organizationUserRepository, @@ -96,8 +107,12 @@ public class OrganizationUsersController : Controller IConfirmOrganizationUserCommand confirmOrganizationUserCommand, IRestoreOrganizationUserCommand restoreOrganizationUserCommand, IInitPendingOrganizationCommand initPendingOrganizationCommand, - IRevokeOrganizationUserCommand revokeOrganizationUserCommand, - IResendOrganizationInviteCommand resendOrganizationInviteCommand) + V1_RevokeOrganizationUserCommand revokeOrganizationUserCommand, + IResendOrganizationInviteCommand resendOrganizationInviteCommand, + IBulkResendOrganizationInvitesCommand bulkResendOrganizationInvitesCommand, + IAdminRecoverAccountCommand adminRecoverAccountCommand, + IAutomaticallyConfirmOrganizationUserCommand automaticallyConfirmOrganizationUserCommand, + V2_RevokeOrganizationUserCommand.IRevokeOrganizationUserCommand revokeOrganizationUserCommandVNext) { _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; @@ -122,10 +137,14 @@ public class OrganizationUsersController : Controller _featureService = featureService; _pricingClient = pricingClient; _resendOrganizationInviteCommand = resendOrganizationInviteCommand; + _bulkResendOrganizationInvitesCommand = bulkResendOrganizationInvitesCommand; + _automaticallyConfirmOrganizationUserCommand = automaticallyConfirmOrganizationUserCommand; + _revokeOrganizationUserCommandVNext = revokeOrganizationUserCommandVNext; _confirmOrganizationUserCommand = confirmOrganizationUserCommand; _restoreOrganizationUserCommand = restoreOrganizationUserCommand; _initPendingOrganizationCommand = initPendingOrganizationCommand; _revokeOrganizationUserCommand = revokeOrganizationUserCommand; + _adminRecoverAccountCommand = adminRecoverAccountCommand; } [HttpGet("{id}")] @@ -262,7 +281,17 @@ public class OrganizationUsersController : Controller public async Task> BulkReinvite(Guid orgId, [FromBody] OrganizationUserBulkRequestModel model) { var userId = _userService.GetProperUserId(User); - var result = await _organizationService.ResendInvitesAsync(orgId, userId.Value, model.Ids); + + IEnumerable> result; + if (_featureService.IsEnabled(FeatureFlagKeys.IncreaseBulkReinviteLimitForCloud)) + { + result = await _bulkResendOrganizationInvitesCommand.BulkResendInvitesAsync(orgId, userId.Value, model.Ids); + } + else + { + result = await _organizationService.ResendInvitesAsync(orgId, userId.Value, model.Ids); + } + return new ListResponseModel( result.Select(t => new OrganizationUserBulkResponseModel(t.Item1.Id, t.Item2))); } @@ -472,23 +501,31 @@ public class OrganizationUsersController : Controller } } +#nullable enable [HttpPut("{id}/reset-password")] [Authorize] - public async Task PutResetPassword(Guid orgId, Guid id, [FromBody] OrganizationUserResetPasswordRequestModel model) + public async Task PutResetPassword(Guid orgId, Guid id, [FromBody] OrganizationUserResetPasswordRequestModel model) { - // Get the users role, since provider users aren't a member of the organization we use the owner check - var orgUserType = await _currentContext.OrganizationOwner(orgId) - ? OrganizationUserType.Owner - : _currentContext.Organizations?.FirstOrDefault(o => o.Id == orgId)?.Type; - if (orgUserType == null) + var targetOrganizationUser = await _organizationUserRepository.GetByIdAsync(id); + if (targetOrganizationUser == null || targetOrganizationUser.OrganizationId != orgId) { - throw new NotFoundException(); + return TypedResults.NotFound(); } - var result = await _userService.AdminResetPasswordAsync(orgUserType.Value, orgId, id, model.NewMasterPasswordHash, model.Key); + var authorizationResult = await _authorizationService.AuthorizeAsync(User, targetOrganizationUser, new RecoverAccountAuthorizationRequirement()); + if (!authorizationResult.Succeeded) + { + // Return an informative error to show in the UI. + // The Authorize attribute already prevents enumeration by users outside the organization, so this can be specific. + var failureReason = authorizationResult.Failure?.FailureReasons.FirstOrDefault()?.Message ?? RecoverAccountAuthorizationHandler.FailureReason; + // This should be a 403 Forbidden, but that causes a logout on our client apps so we're using 400 Bad Request instead + return TypedResults.BadRequest(new ErrorResponseModel(failureReason)); + } + + var result = await _adminRecoverAccountCommand.RecoverAccountAsync(orgId, targetOrganizationUser, model.NewMasterPasswordHash, model.Key); if (result.Succeeded) { - return; + return TypedResults.Ok(); } foreach (var error in result.Errors) @@ -497,8 +534,9 @@ public class OrganizationUsersController : Controller } await Task.Delay(2000); - throw new BadRequestException(ModelState); + return TypedResults.BadRequest(ModelState); } +#nullable disable [HttpDelete("{id}")] [Authorize] @@ -609,7 +647,29 @@ public class OrganizationUsersController : Controller [Authorize] public async Task> BulkRevokeAsync(Guid orgId, [FromBody] OrganizationUserBulkRequestModel model) { - return await RestoreOrRevokeUsersAsync(orgId, model, _revokeOrganizationUserCommand.RevokeUsersAsync); + if (!_featureService.IsEnabled(FeatureFlagKeys.BulkRevokeUsersV2)) + { + return await RestoreOrRevokeUsersAsync(orgId, model, _revokeOrganizationUserCommand.RevokeUsersAsync); + } + + var currentUserId = _userService.GetProperUserId(User); + if (currentUserId == null) + { + throw new UnauthorizedAccessException(); + } + + var results = await _revokeOrganizationUserCommandVNext.RevokeUsersAsync( + new V2_RevokeOrganizationUserCommand.RevokeOrganizationUsersRequest( + orgId, + model.Ids.ToArray(), + new StandardUser(currentUserId.Value, await _currentContext.OrganizationOwner(orgId)))); + + return new ListResponseModel(results + .Select(result => new OrganizationUserBulkResponseModel(result.Id, + result.Result.Match( + error => error.Message, + _ => string.Empty + )))); } [HttpPatch("revoke")] @@ -691,6 +751,31 @@ public class OrganizationUsersController : Controller await BulkEnableSecretsManagerAsync(orgId, model); } + [HttpPost("{id}/auto-confirm")] + [Authorize] + [RequireFeature(FeatureFlagKeys.AutomaticConfirmUsers)] + public async Task AutomaticallyConfirmOrganizationUserAsync([FromRoute] Guid orgId, + [FromRoute] Guid id, + [FromBody] OrganizationUserConfirmRequestModel model) + { + var userId = _userService.GetProperUserId(User); + + if (userId is null || userId.Value == Guid.Empty) + { + return TypedResults.Unauthorized(); + } + + return Handle(await _automaticallyConfirmOrganizationUserCommand.AutomaticallyConfirmOrganizationUserAsync( + new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationId = orgId, + OrganizationUserId = id, + Key = model.Key, + DefaultUserCollectionName = model.DefaultUserCollectionName, + PerformedBy = new StandardUser(userId.Value, await _currentContext.OrganizationOwner(orgId)), + })); + } + private async Task RestoreOrRevokeUserAsync( Guid orgId, Guid id, diff --git a/src/Api/AdminConsole/Controllers/OrganizationsController.cs b/src/Api/AdminConsole/Controllers/OrganizationsController.cs index 590895665d..100cd7caf6 100644 --- a/src/Api/AdminConsole/Controllers/OrganizationsController.cs +++ b/src/Api/AdminConsole/Controllers/OrganizationsController.cs @@ -12,7 +12,6 @@ using Bit.Api.Models.Request.Accounts; using Bit.Api.Models.Request.Organizations; using Bit.Api.Models.Response; using Bit.Core; -using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Business.Tokenables; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; @@ -70,6 +69,7 @@ public class OrganizationsController : Controller private readonly IPolicyRequirementQuery _policyRequirementQuery; private readonly IPricingClient _pricingClient; private readonly IOrganizationUpdateKeysCommand _organizationUpdateKeysCommand; + private readonly IOrganizationUpdateCommand _organizationUpdateCommand; public OrganizationsController( IOrganizationRepository organizationRepository, @@ -94,7 +94,8 @@ public class OrganizationsController : Controller IOrganizationDeleteCommand organizationDeleteCommand, IPolicyRequirementQuery policyRequirementQuery, IPricingClient pricingClient, - IOrganizationUpdateKeysCommand organizationUpdateKeysCommand) + IOrganizationUpdateKeysCommand organizationUpdateKeysCommand, + IOrganizationUpdateCommand organizationUpdateCommand) { _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; @@ -119,6 +120,7 @@ public class OrganizationsController : Controller _policyRequirementQuery = policyRequirementQuery; _pricingClient = pricingClient; _organizationUpdateKeysCommand = organizationUpdateKeysCommand; + _organizationUpdateCommand = organizationUpdateCommand; } [HttpGet("{id}")] @@ -224,36 +226,31 @@ public class OrganizationsController : Controller return new OrganizationResponseModel(result.Organization, plan); } - [HttpPut("{id}")] - public async Task Put(string id, [FromBody] OrganizationUpdateRequestModel model) + [HttpPut("{organizationId:guid}")] + public async Task Put(Guid organizationId, [FromBody] OrganizationUpdateRequestModel model) { - var orgIdGuid = new Guid(id); + // If billing email is being changed, require subscription editing permissions. + // Otherwise, organization owner permissions are sufficient. + var requiresBillingPermission = model.BillingEmail is not null; + var authorized = requiresBillingPermission + ? await _currentContext.EditSubscription(organizationId) + : await _currentContext.OrganizationOwner(organizationId); - var organization = await _organizationRepository.GetByIdAsync(orgIdGuid); - if (organization == null) + if (!authorized) { - throw new NotFoundException(); + return TypedResults.Unauthorized(); } - var updateBilling = ShouldUpdateBilling(model, organization); + var commandRequest = model.ToCommandRequest(organizationId); + var updatedOrganization = await _organizationUpdateCommand.UpdateAsync(commandRequest); - var hasRequiredPermissions = updateBilling - ? await _currentContext.EditSubscription(orgIdGuid) - : await _currentContext.OrganizationOwner(orgIdGuid); - - if (!hasRequiredPermissions) - { - throw new NotFoundException(); - } - - await _organizationService.UpdateAsync(model.ToOrganization(organization, _globalSettings), updateBilling); - var plan = await _pricingClient.GetPlan(organization.PlanType); - return new OrganizationResponseModel(organization, plan); + var plan = await _pricingClient.GetPlan(updatedOrganization.PlanType); + return TypedResults.Ok(new OrganizationResponseModel(updatedOrganization, plan)); } [HttpPost("{id}")] [Obsolete("This endpoint is deprecated. Use PUT method instead")] - public async Task PostPut(string id, [FromBody] OrganizationUpdateRequestModel model) + public async Task PostPut(Guid id, [FromBody] OrganizationUpdateRequestModel model) { return await Put(id, model); } @@ -588,11 +585,4 @@ public class OrganizationsController : Controller return organization.PlanType; } - - private bool ShouldUpdateBilling(OrganizationUpdateRequestModel model, Organization organization) - { - var organizationNameChanged = model.Name != organization.Name; - var billingEmailChanged = model.BillingEmail != organization.BillingEmail; - return !_globalSettings.SelfHosted && (organizationNameChanged || billingEmailChanged); - } } diff --git a/src/Api/AdminConsole/Controllers/PoliciesController.cs b/src/Api/AdminConsole/Controllers/PoliciesController.cs index ce92321833..ae1d12e887 100644 --- a/src/Api/AdminConsole/Controllers/PoliciesController.cs +++ b/src/Api/AdminConsole/Controllers/PoliciesController.cs @@ -12,6 +12,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Context; @@ -41,8 +42,8 @@ public class PoliciesController : Controller private readonly IDataProtectorTokenFactory _orgUserInviteTokenDataFactory; private readonly IPolicyRepository _policyRepository; private readonly IUserService _userService; - private readonly ISavePolicyCommand _savePolicyCommand; + private readonly IVNextSavePolicyCommand _vNextSavePolicyCommand; public PoliciesController(IPolicyRepository policyRepository, IOrganizationUserRepository organizationUserRepository, @@ -53,7 +54,8 @@ public class PoliciesController : Controller IDataProtectorTokenFactory orgUserInviteTokenDataFactory, IOrganizationHasVerifiedDomainsQuery organizationHasVerifiedDomainsQuery, IOrganizationRepository organizationRepository, - ISavePolicyCommand savePolicyCommand) + ISavePolicyCommand savePolicyCommand, + IVNextSavePolicyCommand vNextSavePolicyCommand) { _policyRepository = policyRepository; _organizationUserRepository = organizationUserRepository; @@ -66,6 +68,7 @@ public class PoliciesController : Controller _orgUserInviteTokenDataFactory = orgUserInviteTokenDataFactory; _organizationHasVerifiedDomainsQuery = organizationHasVerifiedDomainsQuery; _savePolicyCommand = savePolicyCommand; + _vNextSavePolicyCommand = vNextSavePolicyCommand; } [HttpGet("{type}")] @@ -203,27 +206,20 @@ public class PoliciesController : Controller throw new NotFoundException(); } - if (type != model.Type) - { - throw new BadRequestException("Mismatched policy type"); - } - - var policyUpdate = await model.ToPolicyUpdateAsync(orgId, _currentContext); + var policyUpdate = await model.ToPolicyUpdateAsync(orgId, type, _currentContext); var policy = await _savePolicyCommand.SaveAsync(policyUpdate); return new PolicyResponseModel(policy); } - [HttpPut("{type}/vnext")] [RequireFeatureAttribute(FeatureFlagKeys.CreateDefaultLocation)] [Authorize] - public async Task PutVNext(Guid orgId, [FromBody] SavePolicyRequest model) + public async Task PutVNext(Guid orgId, PolicyType type, [FromBody] SavePolicyRequest model) { - var savePolicyRequest = await model.ToSavePolicyModelAsync(orgId, _currentContext); + var savePolicyRequest = await model.ToSavePolicyModelAsync(orgId, type, _currentContext); - var policy = await _savePolicyCommand.VNextSaveAsync(savePolicyRequest); + var policy = await _vNextSavePolicyCommand.SaveAsync(savePolicyRequest); return new PolicyResponseModel(policy); } - } diff --git a/src/Api/AdminConsole/Controllers/ProviderClientsController.cs b/src/Api/AdminConsole/Controllers/ProviderClientsController.cs index caf2651e16..dfa6984826 100644 --- a/src/Api/AdminConsole/Controllers/ProviderClientsController.cs +++ b/src/Api/AdminConsole/Controllers/ProviderClientsController.cs @@ -57,8 +57,7 @@ public class ProviderClientsController( Owner = user, BillingEmail = provider.BillingEmail, OwnerKey = requestBody.Key, - PublicKey = requestBody.KeyPair.PublicKey, - PrivateKey = requestBody.KeyPair.EncryptedPrivateKey, + Keys = requestBody.KeyPair.ToPublicKeyEncryptionKeyPairData(), CollectionName = requestBody.CollectionName, IsFromProvider = true }; diff --git a/src/Api/AdminConsole/Controllers/ProvidersController.cs b/src/Api/AdminConsole/Controllers/ProvidersController.cs index aa87bf9c74..515404e8a9 100644 --- a/src/Api/AdminConsole/Controllers/ProvidersController.cs +++ b/src/Api/AdminConsole/Controllers/ProvidersController.cs @@ -5,6 +5,7 @@ using Bit.Api.AdminConsole.Models.Request.Providers; using Bit.Api.AdminConsole.Models.Response.Providers; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; +using Bit.Core.Billing.Providers.Services; using Bit.Core.Context; using Bit.Core.Exceptions; using Bit.Core.Services; @@ -23,15 +24,20 @@ public class ProvidersController : Controller private readonly IProviderService _providerService; private readonly ICurrentContext _currentContext; private readonly GlobalSettings _globalSettings; + private readonly IProviderBillingService _providerBillingService; + private readonly ILogger _logger; public ProvidersController(IUserService userService, IProviderRepository providerRepository, - IProviderService providerService, ICurrentContext currentContext, GlobalSettings globalSettings) + IProviderService providerService, ICurrentContext currentContext, GlobalSettings globalSettings, + IProviderBillingService providerBillingService, ILogger logger) { _userService = userService; _providerRepository = providerRepository; _providerService = providerService; _currentContext = currentContext; _globalSettings = globalSettings; + _providerBillingService = providerBillingService; + _logger = logger; } [HttpGet("{id:guid}")] @@ -65,7 +71,27 @@ public class ProvidersController : Controller throw new NotFoundException(); } + // Capture original values before modifications for Stripe sync + var originalName = provider.Name; + var originalBillingEmail = provider.BillingEmail; + await _providerService.UpdateAsync(model.ToProvider(provider, _globalSettings)); + + // Sync name/email changes to Stripe + if (originalName != provider.Name || originalBillingEmail != provider.BillingEmail) + { + try + { + await _providerBillingService.UpdateProviderNameAndEmail(provider); + } + catch (Exception ex) + { + _logger.LogError(ex, + "Failed to update Stripe customer for provider {ProviderId}. Database was updated successfully.", + provider.Id); + } + } + return new ProviderResponseModel(provider); } diff --git a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationCreateRequestModel.cs b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationCreateRequestModel.cs index 7754c44c8c..464ba0c2fd 100644 --- a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationCreateRequestModel.cs +++ b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationCreateRequestModel.cs @@ -113,11 +113,10 @@ public class OrganizationCreateRequestModel : IValidatableObject BillingAddressCountry = BillingAddressCountry, }, InitiationPath = InitiationPath, - SkipTrial = SkipTrial + SkipTrial = SkipTrial, + Keys = Keys?.ToPublicKeyEncryptionKeyPairData() }; - Keys?.ToOrganizationSignup(orgSignup); - return orgSignup; } diff --git a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModel.cs b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModel.cs deleted file mode 100644 index 7d1efe2315..0000000000 --- a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModel.cs +++ /dev/null @@ -1,104 +0,0 @@ -using System.Text.Json; -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; -using Bit.Core.Enums; - - -namespace Bit.Api.AdminConsole.Models.Request.Organizations; - -public class OrganizationIntegrationConfigurationRequestModel -{ - public string? Configuration { get; set; } - - public EventType? EventType { get; set; } - - public string? Filters { get; set; } - - public string? Template { get; set; } - - public bool IsValidForType(IntegrationType integrationType) - { - switch (integrationType) - { - case IntegrationType.CloudBillingSync or IntegrationType.Scim: - return false; - case IntegrationType.Slack: - return !string.IsNullOrWhiteSpace(Template) && - IsConfigurationValid() && - IsFiltersValid(); - case IntegrationType.Webhook: - return !string.IsNullOrWhiteSpace(Template) && - IsConfigurationValid() && - IsFiltersValid(); - case IntegrationType.Hec: - return !string.IsNullOrWhiteSpace(Template) && - Configuration is null && - IsFiltersValid(); - case IntegrationType.Datadog: - return !string.IsNullOrWhiteSpace(Template) && - Configuration is null && - IsFiltersValid(); - default: - return false; - - } - } - - public OrganizationIntegrationConfiguration ToOrganizationIntegrationConfiguration(Guid organizationIntegrationId) - { - return new OrganizationIntegrationConfiguration() - { - OrganizationIntegrationId = organizationIntegrationId, - Configuration = Configuration, - Filters = Filters, - EventType = EventType, - Template = Template - }; - } - - public OrganizationIntegrationConfiguration ToOrganizationIntegrationConfiguration(OrganizationIntegrationConfiguration currentConfiguration) - { - currentConfiguration.Configuration = Configuration; - currentConfiguration.EventType = EventType; - currentConfiguration.Filters = Filters; - currentConfiguration.Template = Template; - - return currentConfiguration; - } - - private bool IsConfigurationValid() - { - if (string.IsNullOrWhiteSpace(Configuration)) - { - return false; - } - - try - { - var config = JsonSerializer.Deserialize(Configuration); - return config is not null; - } - catch - { - return false; - } - } - - private bool IsFiltersValid() - { - if (Filters is null) - { - return true; - } - - try - { - var filters = JsonSerializer.Deserialize(Filters); - return filters is not null; - } - catch - { - return false; - } - } -} diff --git a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationKeysRequestModel.cs b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationKeysRequestModel.cs index 22b225a689..ef2fb0f07b 100644 --- a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationKeysRequestModel.cs +++ b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationKeysRequestModel.cs @@ -2,8 +2,7 @@ #nullable disable using System.ComponentModel.DataAnnotations; -using Bit.Core.AdminConsole.Entities; -using Bit.Core.Models.Business; +using Bit.Core.KeyManagement.Models.Data; namespace Bit.Api.AdminConsole.Models.Request.Organizations; @@ -14,48 +13,10 @@ public class OrganizationKeysRequestModel [Required] public string EncryptedPrivateKey { get; set; } - public OrganizationSignup ToOrganizationSignup(OrganizationSignup existingSignup) + public PublicKeyEncryptionKeyPairData ToPublicKeyEncryptionKeyPairData() { - if (string.IsNullOrWhiteSpace(existingSignup.PublicKey)) - { - existingSignup.PublicKey = PublicKey; - } - - if (string.IsNullOrWhiteSpace(existingSignup.PrivateKey)) - { - existingSignup.PrivateKey = EncryptedPrivateKey; - } - - return existingSignup; - } - - public OrganizationUpgrade ToOrganizationUpgrade(OrganizationUpgrade existingUpgrade) - { - if (string.IsNullOrWhiteSpace(existingUpgrade.PublicKey)) - { - existingUpgrade.PublicKey = PublicKey; - } - - if (string.IsNullOrWhiteSpace(existingUpgrade.PrivateKey)) - { - existingUpgrade.PrivateKey = EncryptedPrivateKey; - } - - return existingUpgrade; - } - - public Organization ToOrganization(Organization existingOrg) - { - if (string.IsNullOrWhiteSpace(existingOrg.PublicKey)) - { - existingOrg.PublicKey = PublicKey; - } - - if (string.IsNullOrWhiteSpace(existingOrg.PrivateKey)) - { - existingOrg.PrivateKey = EncryptedPrivateKey; - } - - return existingOrg; + return new PublicKeyEncryptionKeyPairData( + wrappedPrivateKey: EncryptedPrivateKey, + publicKey: PublicKey); } } diff --git a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationNoPaymentCreateRequest.cs b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationNoPaymentCreateRequest.cs index 0c62b23518..81d7c413eb 100644 --- a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationNoPaymentCreateRequest.cs +++ b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationNoPaymentCreateRequest.cs @@ -110,10 +110,9 @@ public class OrganizationNoPaymentCreateRequest BillingAddressCountry = BillingAddressCountry, }, InitiationPath = InitiationPath, + Keys = Keys?.ToPublicKeyEncryptionKeyPairData() }; - Keys?.ToOrganizationSignup(orgSignup); - return orgSignup; } } diff --git a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUpdateRequestModel.cs b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUpdateRequestModel.cs index 5a3192c121..a0b1247ae1 100644 --- a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUpdateRequestModel.cs +++ b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUpdateRequestModel.cs @@ -1,41 +1,27 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations; using System.Text.Json.Serialization; -using Bit.Core.AdminConsole.Entities; -using Bit.Core.Models.Data; -using Bit.Core.Settings; +using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Update; using Bit.Core.Utilities; namespace Bit.Api.AdminConsole.Models.Request.Organizations; public class OrganizationUpdateRequestModel { - [Required] [StringLength(50, ErrorMessage = "The field Name exceeds the maximum length.")] [JsonConverter(typeof(HtmlEncodingStringConverter))] - public string Name { get; set; } - [StringLength(50, ErrorMessage = "The field Business Name exceeds the maximum length.")] - [JsonConverter(typeof(HtmlEncodingStringConverter))] - public string BusinessName { get; set; } - [EmailAddress] - [Required] - [StringLength(256)] - public string BillingEmail { get; set; } - public Permissions Permissions { get; set; } - public OrganizationKeysRequestModel Keys { get; set; } + public string? Name { get; set; } - public virtual Organization ToOrganization(Organization existingOrganization, GlobalSettings globalSettings) + [EmailAddress] + [StringLength(256)] + public string? BillingEmail { get; set; } + + public OrganizationKeysRequestModel? Keys { get; set; } + + public OrganizationUpdateRequest ToCommandRequest(Guid organizationId) => new() { - if (!globalSettings.SelfHosted) - { - // These items come from the license file - existingOrganization.Name = Name; - existingOrganization.BusinessName = BusinessName; - existingOrganization.BillingEmail = BillingEmail?.ToLowerInvariant()?.Trim(); - } - Keys?.ToOrganization(existingOrganization); - return existingOrganization; - } + OrganizationId = organizationId, + Name = Name, + BillingEmail = BillingEmail, + Keys = Keys?.ToPublicKeyEncryptionKeyPairData() + }; } diff --git a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUpgradeRequestModel.cs b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUpgradeRequestModel.cs index a5dec192b9..7d5a9e56c7 100644 --- a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUpgradeRequestModel.cs +++ b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUpgradeRequestModel.cs @@ -43,11 +43,10 @@ public class OrganizationUpgradeRequestModel { BillingAddressCountry = BillingAddressCountry, BillingAddressPostalCode = BillingAddressPostalCode - } + }, + Keys = Keys?.ToPublicKeyEncryptionKeyPairData() }; - Keys?.ToOrganizationUpgrade(orgUpgrade); - return orgUpgrade; } } diff --git a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUserRequestModels.cs b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUserRequestModels.cs index 4e0accb9e8..b7a4db3acd 100644 --- a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUserRequestModels.cs +++ b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUserRequestModels.cs @@ -119,7 +119,7 @@ public class OrganizationUserResetPasswordEnrollmentRequestModel public class OrganizationUserBulkRequestModel { - [Required] + [Required, MinLength(1)] public IEnumerable Ids { get; set; } } diff --git a/src/Api/AdminConsole/Models/Request/PolicyRequestModel.cs b/src/Api/AdminConsole/Models/Request/PolicyRequestModel.cs index 0e31deacd1..2dc7dfa7cd 100644 --- a/src/Api/AdminConsole/Models/Request/PolicyRequestModel.cs +++ b/src/Api/AdminConsole/Models/Request/PolicyRequestModel.cs @@ -1,29 +1,30 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.ComponentModel.DataAnnotations; -using System.Text.Json; +using System.ComponentModel.DataAnnotations; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.Utilities; using Bit.Core.Context; namespace Bit.Api.AdminConsole.Models.Request; public class PolicyRequestModel { - [Required] - public PolicyType? Type { get; set; } [Required] public bool? Enabled { get; set; } - public Dictionary Data { get; set; } + public Dictionary? Data { get; set; } - public async Task ToPolicyUpdateAsync(Guid organizationId, ICurrentContext currentContext) => new() + public async Task ToPolicyUpdateAsync(Guid organizationId, PolicyType type, ICurrentContext currentContext) { - Type = Type!.Value, - OrganizationId = organizationId, - Data = Data != null ? JsonSerializer.Serialize(Data) : null, - Enabled = Enabled.GetValueOrDefault(), - PerformedBy = new StandardUser(currentContext.UserId!.Value, await currentContext.OrganizationOwner(organizationId)) - }; + var serializedData = PolicyDataValidator.ValidateAndSerialize(Data, type); + var performedBy = new StandardUser(currentContext.UserId!.Value, await currentContext.OrganizationOwner(organizationId)); + + return new() + { + Type = type, + OrganizationId = organizationId, + Data = serializedData, + Enabled = Enabled.GetValueOrDefault(), + PerformedBy = performedBy + }; + } } diff --git a/src/Api/AdminConsole/Models/Request/SavePolicyRequest.cs b/src/Api/AdminConsole/Models/Request/SavePolicyRequest.cs index fcdc49882b..2e2868a78a 100644 --- a/src/Api/AdminConsole/Models/Request/SavePolicyRequest.cs +++ b/src/Api/AdminConsole/Models/Request/SavePolicyRequest.cs @@ -1,10 +1,9 @@ using System.ComponentModel.DataAnnotations; -using System.Text.Json; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.Utilities; using Bit.Core.Context; -using Bit.Core.Utilities; namespace Bit.Api.AdminConsole.Models.Request; @@ -15,47 +14,12 @@ public class SavePolicyRequest public Dictionary? Metadata { get; set; } - public async Task ToSavePolicyModelAsync(Guid organizationId, ICurrentContext currentContext) + public async Task ToSavePolicyModelAsync(Guid organizationId, PolicyType type, ICurrentContext currentContext) { + var policyUpdate = await Policy.ToPolicyUpdateAsync(organizationId, type, currentContext); + var metadata = PolicyDataValidator.ValidateAndDeserializeMetadata(Metadata, type); var performedBy = new StandardUser(currentContext.UserId!.Value, await currentContext.OrganizationOwner(organizationId)); - var updatedPolicy = new PolicyUpdate() - { - Type = Policy.Type!.Value, - OrganizationId = organizationId, - Data = Policy.Data != null ? JsonSerializer.Serialize(Policy.Data) : null, - Enabled = Policy.Enabled.GetValueOrDefault(), - }; - - var metadata = MapToPolicyMetadata(); - - return new SavePolicyModel(updatedPolicy, performedBy, metadata); - } - - private IPolicyMetadataModel MapToPolicyMetadata() - { - if (Metadata == null) - { - return new EmptyMetadataModel(); - } - - return Policy?.Type switch - { - PolicyType.OrganizationDataOwnership => MapToPolicyMetadata(), - _ => new EmptyMetadataModel() - }; - } - - private IPolicyMetadataModel MapToPolicyMetadata() where T : IPolicyMetadataModel, new() - { - try - { - var json = JsonSerializer.Serialize(Metadata); - return CoreHelpers.LoadClassFromJsonData(json); - } - catch - { - return new EmptyMetadataModel(); - } + return new SavePolicyModel(policyUpdate, performedBy, metadata); } } diff --git a/src/Api/AdminConsole/Models/Response/BaseProfileOrganizationResponseModel.cs b/src/Api/AdminConsole/Models/Response/BaseProfileOrganizationResponseModel.cs new file mode 100644 index 0000000000..f5ef468b4e --- /dev/null +++ b/src/Api/AdminConsole/Models/Response/BaseProfileOrganizationResponseModel.cs @@ -0,0 +1,129 @@ +using System.Text.Json.Serialization; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.Auth.Enums; +using Bit.Core.Auth.Models.Data; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Extensions; +using Bit.Core.Enums; +using Bit.Core.Models.Api; +using Bit.Core.Models.Data; +using Bit.Core.Utilities; + +namespace Bit.Api.AdminConsole.Models.Response; + +/// +/// Contains organization properties for both OrganizationUsers and ProviderUsers. +/// Any organization properties in sync data should be added to this class so they are populated for both +/// members and providers. +/// +public abstract class BaseProfileOrganizationResponseModel : ResponseModel +{ + protected BaseProfileOrganizationResponseModel( + string type, IProfileOrganizationDetails organizationDetails) : base(type) + { + Id = organizationDetails.OrganizationId; + UserId = organizationDetails.UserId; + Name = organizationDetails.Name; + Enabled = organizationDetails.Enabled; + Identifier = organizationDetails.Identifier; + ProductTierType = organizationDetails.PlanType.GetProductTier(); + UsePolicies = organizationDetails.UsePolicies; + UseSso = organizationDetails.UseSso; + UseKeyConnector = organizationDetails.UseKeyConnector; + UseScim = organizationDetails.UseScim; + UseGroups = organizationDetails.UseGroups; + UseDirectory = organizationDetails.UseDirectory; + UseEvents = organizationDetails.UseEvents; + UseTotp = organizationDetails.UseTotp; + Use2fa = organizationDetails.Use2fa; + UseApi = organizationDetails.UseApi; + UseResetPassword = organizationDetails.UseResetPassword; + UsersGetPremium = organizationDetails.UsersGetPremium; + UseCustomPermissions = organizationDetails.UseCustomPermissions; + UseActivateAutofillPolicy = organizationDetails.PlanType.GetProductTier() == ProductTierType.Enterprise; + UseRiskInsights = organizationDetails.UseRiskInsights; + UseOrganizationDomains = organizationDetails.UseOrganizationDomains; + UseAdminSponsoredFamilies = organizationDetails.UseAdminSponsoredFamilies; + UseAutomaticUserConfirmation = organizationDetails.UseAutomaticUserConfirmation; + UseSecretsManager = organizationDetails.UseSecretsManager; + UsePhishingBlocker = organizationDetails.UsePhishingBlocker; + UsePasswordManager = organizationDetails.UsePasswordManager; + SelfHost = organizationDetails.SelfHost; + Seats = organizationDetails.Seats; + MaxCollections = organizationDetails.MaxCollections; + MaxStorageGb = organizationDetails.MaxStorageGb; + Key = organizationDetails.Key; + HasPublicAndPrivateKeys = organizationDetails.PublicKey != null && organizationDetails.PrivateKey != null; + SsoBound = !string.IsNullOrWhiteSpace(organizationDetails.SsoExternalId); + ResetPasswordEnrolled = !string.IsNullOrWhiteSpace(organizationDetails.ResetPasswordKey); + ProviderId = organizationDetails.ProviderId; + ProviderName = organizationDetails.ProviderName; + ProviderType = organizationDetails.ProviderType; + LimitCollectionCreation = organizationDetails.LimitCollectionCreation; + LimitCollectionDeletion = organizationDetails.LimitCollectionDeletion; + LimitItemDeletion = organizationDetails.LimitItemDeletion; + AllowAdminAccessToAllCollectionItems = organizationDetails.AllowAdminAccessToAllCollectionItems; + SsoEnabled = organizationDetails.SsoEnabled ?? false; + if (organizationDetails.SsoConfig != null) + { + var ssoConfigData = SsoConfigurationData.Deserialize(organizationDetails.SsoConfig); + KeyConnectorEnabled = ssoConfigData.MemberDecryptionType == MemberDecryptionType.KeyConnector && !string.IsNullOrEmpty(ssoConfigData.KeyConnectorUrl); + KeyConnectorUrl = ssoConfigData.KeyConnectorUrl; + SsoMemberDecryptionType = ssoConfigData.MemberDecryptionType; + } + } + + public Guid Id { get; set; } + [JsonConverter(typeof(HtmlEncodingStringConverter))] + public string Name { get; set; } = null!; + public bool Enabled { get; set; } + public string? Identifier { get; set; } + public ProductTierType ProductTierType { get; set; } + public bool UsePolicies { get; set; } + public bool UseSso { get; set; } + public bool UseKeyConnector { get; set; } + public bool UseScim { get; set; } + public bool UseGroups { get; set; } + public bool UseDirectory { get; set; } + public bool UseEvents { get; set; } + public bool UseTotp { get; set; } + public bool Use2fa { get; set; } + public bool UseApi { get; set; } + public bool UseResetPassword { get; set; } + public bool UseSecretsManager { get; set; } + public bool UsePasswordManager { get; set; } + public bool UsersGetPremium { get; set; } + public bool UseCustomPermissions { get; set; } + public bool UseActivateAutofillPolicy { get; set; } + public bool UseRiskInsights { get; set; } + public bool UseOrganizationDomains { get; set; } + public bool UseAdminSponsoredFamilies { get; set; } + public bool UseAutomaticUserConfirmation { get; set; } + public bool UsePhishingBlocker { get; set; } + public bool SelfHost { get; set; } + public int? Seats { get; set; } + public short? MaxCollections { get; set; } + public short? MaxStorageGb { get; set; } + public string? Key { get; set; } + public bool HasPublicAndPrivateKeys { get; set; } + public bool SsoBound { get; set; } + public bool ResetPasswordEnrolled { get; set; } + public bool LimitCollectionCreation { get; set; } + public bool LimitCollectionDeletion { get; set; } + public bool LimitItemDeletion { get; set; } + public bool AllowAdminAccessToAllCollectionItems { get; set; } + public Guid? ProviderId { get; set; } + [JsonConverter(typeof(HtmlEncodingStringConverter))] + public string? ProviderName { get; set; } + public ProviderType? ProviderType { get; set; } + public bool SsoEnabled { get; set; } + public bool KeyConnectorEnabled { get; set; } + public string? KeyConnectorUrl { get; set; } + public MemberDecryptionType? SsoMemberDecryptionType { get; set; } + public bool AccessSecretsManager { get; set; } + public Guid? UserId { get; set; } + public OrganizationUserStatusType Status { get; set; } + public OrganizationUserType Type { get; set; } + public Permissions? Permissions { get; set; } +} diff --git a/src/Api/AdminConsole/Models/Response/Organizations/OrganizationResponseModel.cs b/src/Api/AdminConsole/Models/Response/Organizations/OrganizationResponseModel.cs index b34765fb19..9a3543f4bb 100644 --- a/src/Api/AdminConsole/Models/Response/Organizations/OrganizationResponseModel.cs +++ b/src/Api/AdminConsole/Models/Response/Organizations/OrganizationResponseModel.cs @@ -1,10 +1,13 @@ // FIXME: Update this file to be null safe and then delete the line below #nullable disable +using System.Security.Claims; using System.Text.Json.Serialization; using Bit.Api.Models.Response; using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Licenses; +using Bit.Core.Billing.Licenses.Extensions; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Models.Api; using Bit.Core.Models.Business; @@ -70,6 +73,8 @@ public class OrganizationResponseModel : ResponseModel UseRiskInsights = organization.UseRiskInsights; UseOrganizationDomains = organization.UseOrganizationDomains; UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies; + UseAutomaticUserConfirmation = organization.UseAutomaticUserConfirmation; + UsePhishingBlocker = organization.UsePhishingBlocker; } public Guid Id { get; set; } @@ -118,6 +123,8 @@ public class OrganizationResponseModel : ResponseModel public bool UseRiskInsights { get; set; } public bool UseOrganizationDomains { get; set; } public bool UseAdminSponsoredFamilies { get; set; } + public bool UseAutomaticUserConfirmation { get; set; } + public bool UsePhishingBlocker { get; set; } } public class OrganizationSubscriptionResponseModel : OrganizationResponseModel @@ -173,6 +180,30 @@ public class OrganizationSubscriptionResponseModel : OrganizationResponseModel } } + public OrganizationSubscriptionResponseModel(Organization organization, OrganizationLicense license, ClaimsPrincipal claimsPrincipal) : + this(organization, (Plan)null) + { + if (license != null) + { + // CRITICAL: When a license has a Token (JWT), ALWAYS use the expiration from the token claim + // The token's expiration is cryptographically secured and cannot be tampered with + // The file's Expires property can be manually edited and should NOT be trusted for display + if (claimsPrincipal != null) + { + Expiration = claimsPrincipal.GetValue(OrganizationLicenseConstants.Expires); + ExpirationWithoutGracePeriod = claimsPrincipal.GetValue(OrganizationLicenseConstants.ExpirationWithoutGracePeriod); + } + else + { + // No token - use the license file expiration (for older licenses without tokens) + Expiration = license.Expires; + ExpirationWithoutGracePeriod = license.ExpirationWithoutGracePeriod ?? (license.Trial + ? license.Expires + : license.Expires?.AddDays(-Constants.OrganizationSelfHostSubscriptionGracePeriodDays)); + } + } + } + public string StorageName { get; set; } public double? StorageGb { get; set; } public BillingCustomerDiscount CustomerDiscount { get; set; } diff --git a/src/Api/AdminConsole/Models/Response/Organizations/PolicyResponseModel.cs b/src/Api/AdminConsole/Models/Response/Organizations/PolicyResponseModel.cs index 81ca801308..0507de7a55 100644 --- a/src/Api/AdminConsole/Models/Response/Organizations/PolicyResponseModel.cs +++ b/src/Api/AdminConsole/Models/Response/Organizations/PolicyResponseModel.cs @@ -30,6 +30,7 @@ public class PolicyResponseModel : ResponseModel { Data = JsonSerializer.Deserialize>(policy.Data); } + RevisionDate = policy.RevisionDate; } public Guid Id { get; set; } @@ -37,4 +38,5 @@ public class PolicyResponseModel : ResponseModel public PolicyType Type { get; set; } public Dictionary Data { get; set; } public bool Enabled { get; set; } + public DateTime RevisionDate { get; set; } } diff --git a/src/Api/AdminConsole/Models/Response/ProfileOrganizationResponseModel.cs b/src/Api/AdminConsole/Models/Response/ProfileOrganizationResponseModel.cs index fd2bfe06dc..8c52092dae 100644 --- a/src/Api/AdminConsole/Models/Response/ProfileOrganizationResponseModel.cs +++ b/src/Api/AdminConsole/Models/Response/ProfileOrganizationResponseModel.cs @@ -1,148 +1,48 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Text.Json.Serialization; -using Bit.Core.AdminConsole.Enums.Provider; -using Bit.Core.Auth.Enums; -using Bit.Core.Auth.Models.Data; -using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Models; using Bit.Core.Enums; -using Bit.Core.Models.Api; using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Utilities; namespace Bit.Api.AdminConsole.Models.Response; -public class ProfileOrganizationResponseModel : ResponseModel +/// +/// Sync data for organization members and their organization. +/// Note: see for organization sync data received by provider users. +/// +public class ProfileOrganizationResponseModel : BaseProfileOrganizationResponseModel { - public ProfileOrganizationResponseModel(string str) : base(str) { } - public ProfileOrganizationResponseModel( - OrganizationUserOrganizationDetails organization, + OrganizationUserOrganizationDetails organizationDetails, IEnumerable organizationIdsClaimingUser) - : this("profileOrganization") + : base("profileOrganization", organizationDetails) { - Id = organization.OrganizationId; - Name = organization.Name; - UsePolicies = organization.UsePolicies; - UseSso = organization.UseSso; - UseKeyConnector = organization.UseKeyConnector; - UseScim = organization.UseScim; - UseGroups = organization.UseGroups; - UseDirectory = organization.UseDirectory; - UseEvents = organization.UseEvents; - UseTotp = organization.UseTotp; - Use2fa = organization.Use2fa; - UseApi = organization.UseApi; - UseResetPassword = organization.UseResetPassword; - UseSecretsManager = organization.UseSecretsManager; - UsePasswordManager = organization.UsePasswordManager; - UsersGetPremium = organization.UsersGetPremium; - UseCustomPermissions = organization.UseCustomPermissions; - UseActivateAutofillPolicy = organization.PlanType.GetProductTier() == ProductTierType.Enterprise; - SelfHost = organization.SelfHost; - Seats = organization.Seats; - MaxCollections = organization.MaxCollections; - MaxStorageGb = organization.MaxStorageGb; - Key = organization.Key; - HasPublicAndPrivateKeys = organization.PublicKey != null && organization.PrivateKey != null; - Status = organization.Status; - Type = organization.Type; - Enabled = organization.Enabled; - SsoBound = !string.IsNullOrWhiteSpace(organization.SsoExternalId); - Identifier = organization.Identifier; - Permissions = CoreHelpers.LoadClassFromJsonData(organization.Permissions); - ResetPasswordEnrolled = !string.IsNullOrWhiteSpace(organization.ResetPasswordKey); - UserId = organization.UserId; - OrganizationUserId = organization.OrganizationUserId; - ProviderId = organization.ProviderId; - ProviderName = organization.ProviderName; - ProviderType = organization.ProviderType; - FamilySponsorshipFriendlyName = organization.FamilySponsorshipFriendlyName; - IsAdminInitiated = organization.IsAdminInitiated ?? false; - FamilySponsorshipAvailable = (FamilySponsorshipFriendlyName == null || IsAdminInitiated) && - StaticStore.GetSponsoredPlan(PlanSponsorshipType.FamiliesForEnterprise) - .UsersCanSponsor(organization); - ProductTierType = organization.PlanType.GetProductTier(); - FamilySponsorshipLastSyncDate = organization.FamilySponsorshipLastSyncDate; - FamilySponsorshipToDelete = organization.FamilySponsorshipToDelete; - FamilySponsorshipValidUntil = organization.FamilySponsorshipValidUntil; - AccessSecretsManager = organization.AccessSecretsManager; - LimitCollectionCreation = organization.LimitCollectionCreation; - LimitCollectionDeletion = organization.LimitCollectionDeletion; - LimitItemDeletion = organization.LimitItemDeletion; - AllowAdminAccessToAllCollectionItems = organization.AllowAdminAccessToAllCollectionItems; - UserIsClaimedByOrganization = organizationIdsClaimingUser.Contains(organization.OrganizationId); - UseRiskInsights = organization.UseRiskInsights; - UseOrganizationDomains = organization.UseOrganizationDomains; - UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies; - SsoEnabled = organization.SsoEnabled ?? false; - - if (organization.SsoConfig != null) - { - var ssoConfigData = SsoConfigurationData.Deserialize(organization.SsoConfig); - KeyConnectorEnabled = ssoConfigData.MemberDecryptionType == MemberDecryptionType.KeyConnector && !string.IsNullOrEmpty(ssoConfigData.KeyConnectorUrl); - KeyConnectorUrl = ssoConfigData.KeyConnectorUrl; - SsoMemberDecryptionType = ssoConfigData.MemberDecryptionType; - } + Status = organizationDetails.Status; + Type = organizationDetails.Type; + OrganizationUserId = organizationDetails.OrganizationUserId; + UserIsClaimedByOrganization = organizationIdsClaimingUser.Contains(organizationDetails.OrganizationId); + Permissions = CoreHelpers.LoadClassFromJsonData(organizationDetails.Permissions); + IsAdminInitiated = organizationDetails.IsAdminInitiated ?? false; + FamilySponsorshipFriendlyName = organizationDetails.FamilySponsorshipFriendlyName; + FamilySponsorshipLastSyncDate = organizationDetails.FamilySponsorshipLastSyncDate; + FamilySponsorshipToDelete = organizationDetails.FamilySponsorshipToDelete; + FamilySponsorshipValidUntil = organizationDetails.FamilySponsorshipValidUntil; + FamilySponsorshipAvailable = (organizationDetails.FamilySponsorshipFriendlyName == null || IsAdminInitiated) && + SponsoredPlans.Get(PlanSponsorshipType.FamiliesForEnterprise) + .UsersCanSponsor(organizationDetails); + AccessSecretsManager = organizationDetails.AccessSecretsManager; } - public Guid Id { get; set; } - [JsonConverter(typeof(HtmlEncodingStringConverter))] - public string Name { get; set; } - public bool UsePolicies { get; set; } - public bool UseSso { get; set; } - public bool UseKeyConnector { get; set; } - public bool UseScim { get; set; } - public bool UseGroups { get; set; } - public bool UseDirectory { get; set; } - public bool UseEvents { get; set; } - public bool UseTotp { get; set; } - public bool Use2fa { get; set; } - public bool UseApi { get; set; } - public bool UseResetPassword { get; set; } - public bool UseSecretsManager { get; set; } - public bool UsePasswordManager { get; set; } - public bool UsersGetPremium { get; set; } - public bool UseCustomPermissions { get; set; } - public bool UseActivateAutofillPolicy { get; set; } - public bool SelfHost { get; set; } - public int? Seats { get; set; } - public short? MaxCollections { get; set; } - public short? MaxStorageGb { get; set; } - public string Key { get; set; } - public OrganizationUserStatusType Status { get; set; } - public OrganizationUserType Type { get; set; } - public bool Enabled { get; set; } - public bool SsoBound { get; set; } - public string Identifier { get; set; } - public Permissions Permissions { get; set; } - public bool ResetPasswordEnrolled { get; set; } - public Guid? UserId { get; set; } public Guid OrganizationUserId { get; set; } - public bool HasPublicAndPrivateKeys { get; set; } - public Guid? ProviderId { get; set; } - [JsonConverter(typeof(HtmlEncodingStringConverter))] - public string ProviderName { get; set; } - public ProviderType? ProviderType { get; set; } - public string FamilySponsorshipFriendlyName { get; set; } + public bool UserIsClaimedByOrganization { get; set; } + public string? FamilySponsorshipFriendlyName { get; set; } public bool FamilySponsorshipAvailable { get; set; } - public ProductTierType ProductTierType { get; set; } - public bool KeyConnectorEnabled { get; set; } - public string KeyConnectorUrl { get; set; } public DateTime? FamilySponsorshipLastSyncDate { get; set; } public DateTime? FamilySponsorshipValidUntil { get; set; } public bool? FamilySponsorshipToDelete { get; set; } - public bool AccessSecretsManager { get; set; } - public bool LimitCollectionCreation { get; set; } - public bool LimitCollectionDeletion { get; set; } - public bool LimitItemDeletion { get; set; } - public bool AllowAdminAccessToAllCollectionItems { get; set; } + public bool IsAdminInitiated { get; set; } /// - /// Obsolete. - /// See + /// Obsolete property for backward compatibility /// [Obsolete("Please use UserIsClaimedByOrganization instead. This property will be removed in a future version.")] public bool UserIsManagedByOrganization @@ -150,18 +50,4 @@ public class ProfileOrganizationResponseModel : ResponseModel get => UserIsClaimedByOrganization; set => UserIsClaimedByOrganization = value; } - /// - /// Indicates if the user is claimed by the organization. - /// - /// - /// A user is claimed by an organization if the user's email domain is verified by the organization and the user is a member. - /// The organization must be enabled and able to have verified domains. - /// - public bool UserIsClaimedByOrganization { get; set; } - public bool UseRiskInsights { get; set; } - public bool UseOrganizationDomains { get; set; } - public bool UseAdminSponsoredFamilies { get; set; } - public bool IsAdminInitiated { get; set; } - public bool SsoEnabled { get; set; } - public MemberDecryptionType? SsoMemberDecryptionType { get; set; } } diff --git a/src/Api/AdminConsole/Models/Response/ProfileProviderOrganizationResponseModel.cs b/src/Api/AdminConsole/Models/Response/ProfileProviderOrganizationResponseModel.cs index 24b6fed704..fe31b8cb55 100644 --- a/src/Api/AdminConsole/Models/Response/ProfileProviderOrganizationResponseModel.cs +++ b/src/Api/AdminConsole/Models/Response/ProfileProviderOrganizationResponseModel.cs @@ -1,56 +1,24 @@ using Bit.Core.AdminConsole.Models.Data.Provider; -using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Extensions; using Bit.Core.Enums; using Bit.Core.Models.Data; namespace Bit.Api.AdminConsole.Models.Response; -public class ProfileProviderOrganizationResponseModel : ProfileOrganizationResponseModel +/// +/// Sync data for provider users and their managed organizations. +/// Note: see for organization sync data received by organization members. +/// +public class ProfileProviderOrganizationResponseModel : BaseProfileOrganizationResponseModel { - public ProfileProviderOrganizationResponseModel(ProviderUserOrganizationDetails organization) - : base("profileProviderOrganization") + public ProfileProviderOrganizationResponseModel(ProviderUserOrganizationDetails organizationDetails) + : base("profileProviderOrganization", organizationDetails) { - Id = organization.OrganizationId; - Name = organization.Name; - UsePolicies = organization.UsePolicies; - UseSso = organization.UseSso; - UseKeyConnector = organization.UseKeyConnector; - UseScim = organization.UseScim; - UseGroups = organization.UseGroups; - UseDirectory = organization.UseDirectory; - UseEvents = organization.UseEvents; - UseTotp = organization.UseTotp; - Use2fa = organization.Use2fa; - UseApi = organization.UseApi; - UseResetPassword = organization.UseResetPassword; - UsersGetPremium = organization.UsersGetPremium; - UseCustomPermissions = organization.UseCustomPermissions; - UseActivateAutofillPolicy = organization.PlanType.GetProductTier() == ProductTierType.Enterprise; - SelfHost = organization.SelfHost; - Seats = organization.Seats; - MaxCollections = organization.MaxCollections; - MaxStorageGb = organization.MaxStorageGb; - Key = organization.Key; - HasPublicAndPrivateKeys = organization.PublicKey != null && organization.PrivateKey != null; Status = OrganizationUserStatusType.Confirmed; // Provider users are always confirmed Type = OrganizationUserType.Owner; // Provider users behave like Owners - Enabled = organization.Enabled; - SsoBound = false; - Identifier = organization.Identifier; + ProviderId = organizationDetails.ProviderId; + ProviderName = organizationDetails.ProviderName; + ProviderType = organizationDetails.ProviderType; Permissions = new Permissions(); - ResetPasswordEnrolled = false; - UserId = organization.UserId; - ProviderId = organization.ProviderId; - ProviderName = organization.ProviderName; - ProviderType = organization.ProviderType; - ProductTierType = organization.PlanType.GetProductTier(); - LimitCollectionCreation = organization.LimitCollectionCreation; - LimitCollectionDeletion = organization.LimitCollectionDeletion; - LimitItemDeletion = organization.LimitItemDeletion; - AllowAdminAccessToAllCollectionItems = organization.AllowAdminAccessToAllCollectionItems; - UseRiskInsights = organization.UseRiskInsights; - UseOrganizationDomains = organization.UseOrganizationDomains; - UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies; + AccessSecretsManager = false; // Provider users cannot access Secrets Manager } } diff --git a/src/Api/AdminConsole/Public/Controllers/EventsController.cs b/src/Api/AdminConsole/Public/Controllers/EventsController.cs deleted file mode 100644 index 3dd55d51e2..0000000000 --- a/src/Api/AdminConsole/Public/Controllers/EventsController.cs +++ /dev/null @@ -1,74 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Net; -using Bit.Api.Models.Public.Request; -using Bit.Api.Models.Public.Response; -using Bit.Core.Context; -using Bit.Core.Models.Data; -using Bit.Core.Repositories; -using Bit.Core.Vault.Repositories; -using Microsoft.AspNetCore.Authorization; -using Microsoft.AspNetCore.Mvc; - -namespace Bit.Api.Public.Controllers; - -[Route("public/events")] -[Authorize("Organization")] -public class EventsController : Controller -{ - private readonly IEventRepository _eventRepository; - private readonly ICipherRepository _cipherRepository; - private readonly ICurrentContext _currentContext; - - public EventsController( - IEventRepository eventRepository, - ICipherRepository cipherRepository, - ICurrentContext currentContext) - { - _eventRepository = eventRepository; - _cipherRepository = cipherRepository; - _currentContext = currentContext; - } - - /// - /// List all events. - /// - /// - /// Returns a filtered list of your organization's event logs, paged by a continuation token. - /// If no filters are provided, it will return the last 30 days of event for the organization. - /// - [HttpGet] - [ProducesResponseType(typeof(PagedListResponseModel), (int)HttpStatusCode.OK)] - public async Task List([FromQuery] EventFilterRequestModel request) - { - var dateRange = request.ToDateRange(); - var result = new PagedResult(); - if (request.ActingUserId.HasValue) - { - result = await _eventRepository.GetManyByOrganizationActingUserAsync( - _currentContext.OrganizationId.Value, request.ActingUserId.Value, dateRange.Item1, dateRange.Item2, - new PageOptions { ContinuationToken = request.ContinuationToken }); - } - else if (request.ItemId.HasValue) - { - var cipher = await _cipherRepository.GetByIdAsync(request.ItemId.Value); - if (cipher != null && cipher.OrganizationId == _currentContext.OrganizationId.Value) - { - result = await _eventRepository.GetManyByCipherAsync( - cipher, dateRange.Item1, dateRange.Item2, - new PageOptions { ContinuationToken = request.ContinuationToken }); - } - } - else - { - result = await _eventRepository.GetManyByOrganizationAsync( - _currentContext.OrganizationId.Value, dateRange.Item1, dateRange.Item2, - new PageOptions { ContinuationToken = request.ContinuationToken }); - } - - var eventResponses = result.Data.Select(e => new EventResponseModel(e)); - var response = new PagedListResponseModel(eventResponses, result.ContinuationToken); - return new JsonResult(response); - } -} diff --git a/src/Api/AdminConsole/Public/Controllers/MembersController.cs b/src/Api/AdminConsole/Public/Controllers/MembersController.cs index 7bfe5648b6..58e5db18c2 100644 --- a/src/Api/AdminConsole/Public/Controllers/MembersController.cs +++ b/src/Api/AdminConsole/Public/Controllers/MembersController.cs @@ -1,7 +1,4 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Net; +using System.Net; using Bit.Api.AdminConsole.Public.Models.Request; using Bit.Api.AdminConsole.Public.Models.Response; using Bit.Api.Models.Public.Response; @@ -9,6 +6,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Repositories; using Bit.Core.Services; @@ -24,12 +22,10 @@ public class MembersController : Controller private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IGroupRepository _groupRepository; private readonly IOrganizationService _organizationService; - private readonly IUserService _userService; private readonly ICurrentContext _currentContext; private readonly IUpdateOrganizationUserCommand _updateOrganizationUserCommand; private readonly IUpdateOrganizationUserGroupsCommand _updateOrganizationUserGroupsCommand; - private readonly IApplicationCacheService _applicationCacheService; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IOrganizationRepository _organizationRepository; private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand; @@ -39,12 +35,10 @@ public class MembersController : Controller IOrganizationUserRepository organizationUserRepository, IGroupRepository groupRepository, IOrganizationService organizationService, - IUserService userService, ICurrentContext currentContext, IUpdateOrganizationUserCommand updateOrganizationUserCommand, IUpdateOrganizationUserGroupsCommand updateOrganizationUserGroupsCommand, - IApplicationCacheService applicationCacheService, - IPaymentService paymentService, + IStripePaymentService paymentService, IOrganizationRepository organizationRepository, ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, IRemoveOrganizationUserCommand removeOrganizationUserCommand, @@ -53,11 +47,9 @@ public class MembersController : Controller _organizationUserRepository = organizationUserRepository; _groupRepository = groupRepository; _organizationService = organizationService; - _userService = userService; _currentContext = currentContext; _updateOrganizationUserCommand = updateOrganizationUserCommand; _updateOrganizationUserGroupsCommand = updateOrganizationUserGroupsCommand; - _applicationCacheService = applicationCacheService; _paymentService = paymentService; _organizationRepository = organizationRepository; _twoFactorIsEnabledQuery = twoFactorIsEnabledQuery; @@ -115,19 +107,18 @@ public class MembersController : Controller ///
/// /// Returns a list of your organization's members. - /// Member objects listed in this call do not include information about their associated collections. + /// Member objects listed in this call include information about their associated collections. /// [HttpGet] [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] public async Task List() { - var organizationUserUserDetails = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(_currentContext.OrganizationId.Value); - // TODO: Get all CollectionUser associations for the organization and marry them up here for the response. + var organizationUserUserDetails = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(_currentContext.OrganizationId!.Value, includeCollections: true); var orgUsersTwoFactorIsEnabled = await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(organizationUserUserDetails); var memberResponses = organizationUserUserDetails.Select(u => { - return new MemberResponseModel(u, orgUsersTwoFactorIsEnabled.FirstOrDefault(tuple => tuple.user == u).twoFactorIsEnabled, null); + return new MemberResponseModel(u, orgUsersTwoFactorIsEnabled.FirstOrDefault(tuple => tuple.user == u).twoFactorIsEnabled, u.Collections); }); var response = new ListResponseModel(memberResponses); return new JsonResult(response); @@ -158,7 +149,7 @@ public class MembersController : Controller invite.AccessSecretsManager = hasStandaloneSecretsManager; - var user = await _organizationService.InviteUserAsync(_currentContext.OrganizationId.Value, null, + var user = await _organizationService.InviteUserAsync(_currentContext.OrganizationId!.Value, null, systemUser: null, invite, model.ExternalId); var response = new MemberResponseModel(user, invite.Collections); return new JsonResult(response); @@ -188,12 +179,12 @@ public class MembersController : Controller var updatedUser = model.ToOrganizationUser(existingUser); var associations = model.Collections?.Select(c => c.ToCollectionAccessSelection()).ToList(); await _updateOrganizationUserCommand.UpdateUserAsync(updatedUser, existingUserType, null, associations, model.Groups); - MemberResponseModel response = null; + MemberResponseModel response; if (existingUser.UserId.HasValue) { var existingUserDetails = await _organizationUserRepository.GetDetailsByIdAsync(id); - response = new MemberResponseModel(existingUserDetails, - await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(existingUserDetails), associations); + response = new MemberResponseModel(existingUserDetails!, + await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(existingUserDetails!), associations); } else { @@ -242,7 +233,7 @@ public class MembersController : Controller { return new NotFoundResult(); } - await _removeOrganizationUserCommand.RemoveUserAsync(_currentContext.OrganizationId.Value, id, null); + await _removeOrganizationUserCommand.RemoveUserAsync(_currentContext.OrganizationId!.Value, id, null); return new OkResult(); } @@ -264,7 +255,7 @@ public class MembersController : Controller { return new NotFoundResult(); } - await _resendOrganizationInviteCommand.ResendInviteAsync(_currentContext.OrganizationId.Value, null, id); + await _resendOrganizationInviteCommand.ResendInviteAsync(_currentContext.OrganizationId!.Value, null, id); return new OkResult(); } } diff --git a/src/Api/AdminConsole/Public/Controllers/PoliciesController.cs b/src/Api/AdminConsole/Public/Controllers/PoliciesController.cs index 1caf9cb068..cf8da813be 100644 --- a/src/Api/AdminConsole/Public/Controllers/PoliciesController.cs +++ b/src/Api/AdminConsole/Public/Controllers/PoliciesController.cs @@ -6,9 +6,8 @@ using Bit.Api.AdminConsole.Public.Models.Request; using Bit.Api.AdminConsole.Public.Models.Response; using Bit.Api.Models.Public.Response; using Bit.Core.AdminConsole.Enums; -using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.AdminConsole.Repositories; -using Bit.Core.AdminConsole.Services; using Bit.Core.Context; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; @@ -20,20 +19,17 @@ namespace Bit.Api.AdminConsole.Public.Controllers; public class PoliciesController : Controller { private readonly IPolicyRepository _policyRepository; - private readonly IPolicyService _policyService; private readonly ICurrentContext _currentContext; - private readonly ISavePolicyCommand _savePolicyCommand; + private readonly IVNextSavePolicyCommand _vNextSavePolicyCommand; public PoliciesController( IPolicyRepository policyRepository, - IPolicyService policyService, ICurrentContext currentContext, - ISavePolicyCommand savePolicyCommand) + IVNextSavePolicyCommand vNextSavePolicyCommand) { _policyRepository = policyRepository; - _policyService = policyService; _currentContext = currentContext; - _savePolicyCommand = savePolicyCommand; + _vNextSavePolicyCommand = vNextSavePolicyCommand; } /// @@ -87,8 +83,8 @@ public class PoliciesController : Controller [ProducesResponseType((int)HttpStatusCode.NotFound)] public async Task Put(PolicyType type, [FromBody] PolicyUpdateRequestModel model) { - var policyUpdate = model.ToPolicyUpdate(_currentContext.OrganizationId!.Value, type); - var policy = await _savePolicyCommand.SaveAsync(policyUpdate); + var savePolicyModel = model.ToSavePolicyModel(_currentContext.OrganizationId!.Value, type); + var policy = await _vNextSavePolicyCommand.SaveAsync(savePolicyModel); var response = new PolicyResponseModel(policy); return new JsonResult(response); diff --git a/src/Api/AdminConsole/Public/Models/Request/PolicyUpdateRequestModel.cs b/src/Api/AdminConsole/Public/Models/Request/PolicyUpdateRequestModel.cs index eb56690462..f81d9153b2 100644 --- a/src/Api/AdminConsole/Public/Models/Request/PolicyUpdateRequestModel.cs +++ b/src/Api/AdminConsole/Public/Models/Request/PolicyUpdateRequestModel.cs @@ -1,19 +1,44 @@ -using System.Text.Json; -using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.Utilities; using Bit.Core.Enums; namespace Bit.Api.AdminConsole.Public.Models.Request; public class PolicyUpdateRequestModel : PolicyBaseModel { - public PolicyUpdate ToPolicyUpdate(Guid organizationId, PolicyType type) => new() + public Dictionary? Metadata { get; set; } + + public PolicyUpdate ToPolicyUpdate(Guid organizationId, PolicyType type) { - Type = type, - OrganizationId = organizationId, - Data = Data != null ? JsonSerializer.Serialize(Data) : null, - Enabled = Enabled.GetValueOrDefault(), - PerformedBy = new SystemUser(EventSystemUser.PublicApi) - }; + var serializedData = PolicyDataValidator.ValidateAndSerialize(Data, type); + + return new() + { + Type = type, + OrganizationId = organizationId, + Data = serializedData, + Enabled = Enabled.GetValueOrDefault(), + PerformedBy = new SystemUser(EventSystemUser.PublicApi) + }; + } + + public SavePolicyModel ToSavePolicyModel(Guid organizationId, PolicyType type) + { + var serializedData = PolicyDataValidator.ValidateAndSerialize(Data, type); + + var policyUpdate = new PolicyUpdate + { + Type = type, + OrganizationId = organizationId, + Data = serializedData, + Enabled = Enabled.GetValueOrDefault() + }; + + var performedBy = new SystemUser(EventSystemUser.PublicApi); + var metadata = PolicyDataValidator.ValidateAndDeserializeMetadata(Metadata, type); + + return new SavePolicyModel(policyUpdate, performedBy, metadata); + } } diff --git a/src/Api/AdminConsole/Public/Models/Response/AssociationWithPermissionsResponseModel.cs b/src/Api/AdminConsole/Public/Models/Response/AssociationWithPermissionsResponseModel.cs index e319ead8a4..5ff12a2201 100644 --- a/src/Api/AdminConsole/Public/Models/Response/AssociationWithPermissionsResponseModel.cs +++ b/src/Api/AdminConsole/Public/Models/Response/AssociationWithPermissionsResponseModel.cs @@ -1,9 +1,15 @@ -using Bit.Core.Models.Data; +using System.Text.Json.Serialization; +using Bit.Core.Models.Data; namespace Bit.Api.AdminConsole.Public.Models.Response; public class AssociationWithPermissionsResponseModel : AssociationWithPermissionsBaseModel { + [JsonConstructor] + public AssociationWithPermissionsResponseModel() : base() + { + } + public AssociationWithPermissionsResponseModel(CollectionAccessSelection selection) { if (selection == null) diff --git a/src/Api/AdminConsole/Public/Models/Response/GroupResponseModel.cs b/src/Api/AdminConsole/Public/Models/Response/GroupResponseModel.cs index c12616b4cc..e164f3c4ea 100644 --- a/src/Api/AdminConsole/Public/Models/Response/GroupResponseModel.cs +++ b/src/Api/AdminConsole/Public/Models/Response/GroupResponseModel.cs @@ -2,6 +2,7 @@ #nullable disable using System.ComponentModel.DataAnnotations; +using System.Text.Json.Serialization; using Bit.Api.Models.Public.Response; using Bit.Core.AdminConsole.Entities; using Bit.Core.Models.Data; @@ -13,6 +14,12 @@ namespace Bit.Api.AdminConsole.Public.Models.Response; /// public class GroupResponseModel : GroupBaseModel, IResponseModel { + [JsonConstructor] + public GroupResponseModel() + { + + } + public GroupResponseModel(Group group, IEnumerable collections) { if (group == null) diff --git a/src/Api/Api.csproj b/src/Api/Api.csproj index 138549e92d..dd27de2e63 100644 --- a/src/Api/Api.csproj +++ b/src/Api/Api.csproj @@ -33,7 +33,7 @@ - + diff --git a/src/Api/Auth/Controllers/AccountsController.cs b/src/Api/Auth/Controllers/AccountsController.cs index 19165a5a1c..839d00f7a1 100644 --- a/src/Api/Auth/Controllers/AccountsController.cs +++ b/src/Api/Auth/Controllers/AccountsController.cs @@ -18,6 +18,8 @@ using Bit.Core.Auth.UserFeatures.UserMasterPassword.Interfaces; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.KeyManagement.Kdf; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Queries.Interfaces; using Bit.Core.Models.Api.Response; using Bit.Core.Repositories; using Bit.Core.Services; @@ -40,8 +42,10 @@ public class AccountsController : Controller private readonly ITdeOffboardingPasswordCommand _tdeOffboardingPasswordCommand; private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; private readonly IFeatureService _featureService; + private readonly IUserAccountKeysQuery _userAccountKeysQuery; private readonly ITwoFactorEmailService _twoFactorEmailService; private readonly IChangeKdfCommand _changeKdfCommand; + private readonly IUserRepository _userRepository; public AccountsController( IOrganizationService organizationService, @@ -53,8 +57,10 @@ public class AccountsController : Controller ITdeOffboardingPasswordCommand tdeOffboardingPasswordCommand, ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, IFeatureService featureService, + IUserAccountKeysQuery userAccountKeysQuery, ITwoFactorEmailService twoFactorEmailService, - IChangeKdfCommand changeKdfCommand + IChangeKdfCommand changeKdfCommand, + IUserRepository userRepository ) { _organizationService = organizationService; @@ -66,8 +72,10 @@ public class AccountsController : Controller _tdeOffboardingPasswordCommand = tdeOffboardingPasswordCommand; _twoFactorIsEnabledQuery = twoFactorIsEnabledQuery; _featureService = featureService; + _userAccountKeysQuery = userAccountKeysQuery; _twoFactorEmailService = twoFactorEmailService; _changeKdfCommand = changeKdfCommand; + _userRepository = userRepository; } @@ -332,7 +340,9 @@ public class AccountsController : Controller var hasPremiumFromOrg = await _userService.HasPremiumFromOrganization(user); var organizationIdsClaimingActiveUser = await GetOrganizationIdsClaimingUserAsync(user.Id); - var response = new ProfileResponseModel(user, organizationUserDetails, providerUserDetails, + var accountKeys = await _userAccountKeysQuery.Run(user); + + var response = new ProfileResponseModel(user, accountKeys, organizationUserDetails, providerUserDetails, providerUserOrganizationDetails, twoFactorEnabled, hasPremiumFromOrg, organizationIdsClaimingActiveUser); return response; @@ -364,8 +374,9 @@ public class AccountsController : Controller var twoFactorEnabled = await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(user); var hasPremiumFromOrg = await _userService.HasPremiumFromOrganization(user); var organizationIdsClaimingActiveUser = await GetOrganizationIdsClaimingUserAsync(user.Id); + var userAccountKeys = await _userAccountKeysQuery.Run(user); - var response = new ProfileResponseModel(user, null, null, null, twoFactorEnabled, hasPremiumFromOrg, organizationIdsClaimingActiveUser); + var response = new ProfileResponseModel(user, userAccountKeys, null, null, null, twoFactorEnabled, hasPremiumFromOrg, organizationIdsClaimingActiveUser); return response; } @@ -389,8 +400,9 @@ public class AccountsController : Controller var userTwoFactorEnabled = await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(user); var userHasPremiumFromOrganization = await _userService.HasPremiumFromOrganization(user); var organizationIdsClaimingActiveUser = await GetOrganizationIdsClaimingUserAsync(user.Id); + var accountKeys = await _userAccountKeysQuery.Run(user); - var response = new ProfileResponseModel(user, null, null, null, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationIdsClaimingActiveUser); + var response = new ProfileResponseModel(user, accountKeys, null, null, null, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationIdsClaimingActiveUser); return response; } @@ -424,16 +436,36 @@ public class AccountsController : Controller throw new UnauthorizedAccessException(); } - if (_featureService.IsEnabled(FeatureFlagKeys.ReturnErrorOnExistingKeypair)) + if (!string.IsNullOrWhiteSpace(user.PrivateKey) || !string.IsNullOrWhiteSpace(user.PublicKey)) { - if (!string.IsNullOrWhiteSpace(user.PrivateKey) || !string.IsNullOrWhiteSpace(user.PublicKey)) - { - throw new BadRequestException("User has existing keypair"); - } + throw new BadRequestException("User has existing keypair"); + } + + if (model.AccountKeys != null) + { + var accountKeysData = model.AccountKeys.ToAccountKeysData(); + if (!accountKeysData.IsV2Encryption()) + { + throw new BadRequestException("AccountKeys are only supported for V2 encryption."); + } + await _userRepository.SetV2AccountCryptographicStateAsync(user.Id, accountKeysData); + return new KeysResponseModel(accountKeysData, user.Key); + } + else + { + // Todo: Drop this after a transition period. This will drop no-account-keys requests. + // The V1 check in the other branch should persist + // https://bitwarden.atlassian.net/browse/PM-27329 + await _userService.SaveUserAsync(model.ToUser(user)); + return new KeysResponseModel(new UserAccountKeysData + { + PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData( + user.PrivateKey, + user.PublicKey + ) + }, user.Key); } - await _userService.SaveUserAsync(model.ToUser(user)); - return new KeysResponseModel(user); } [HttpGet("keys")] @@ -445,7 +477,8 @@ public class AccountsController : Controller throw new UnauthorizedAccessException(); } - return new KeysResponseModel(user); + var accountKeys = await _userAccountKeysQuery.Run(user); + return new KeysResponseModel(accountKeys, user.Key); } [HttpDelete] diff --git a/src/Api/Auth/Controllers/TwoFactorController.cs b/src/Api/Auth/Controllers/TwoFactorController.cs index 0af46fb57c..ba6cf66859 100644 --- a/src/Api/Auth/Controllers/TwoFactorController.cs +++ b/src/Api/Auth/Controllers/TwoFactorController.cs @@ -9,7 +9,6 @@ using Bit.Api.Models.Response; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Identity; using Bit.Core.Auth.Identity.TokenProviders; -using Bit.Core.Auth.LoginFeatures.PasswordlessLogin.Interfaces; using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Auth.Services; using Bit.Core.Context; @@ -35,7 +34,7 @@ public class TwoFactorController : Controller private readonly IOrganizationService _organizationService; private readonly UserManager _userManager; private readonly ICurrentContext _currentContext; - private readonly IVerifyAuthRequestCommand _verifyAuthRequestCommand; + private readonly IAuthRequestRepository _authRequestRepository; private readonly IDuoUniversalTokenService _duoUniversalTokenService; private readonly IDataProtectorTokenFactory _twoFactorAuthenticatorDataProtector; private readonly IDataProtectorTokenFactory _ssoEmailTwoFactorSessionDataProtector; @@ -47,7 +46,7 @@ public class TwoFactorController : Controller IOrganizationService organizationService, UserManager userManager, ICurrentContext currentContext, - IVerifyAuthRequestCommand verifyAuthRequestCommand, + IAuthRequestRepository authRequestRepository, IDuoUniversalTokenService duoUniversalConfigService, IDataProtectorTokenFactory twoFactorAuthenticatorDataProtector, IDataProtectorTokenFactory ssoEmailTwoFactorSessionDataProtector, @@ -58,7 +57,7 @@ public class TwoFactorController : Controller _organizationService = organizationService; _userManager = userManager; _currentContext = currentContext; - _verifyAuthRequestCommand = verifyAuthRequestCommand; + _authRequestRepository = authRequestRepository; _duoUniversalTokenService = duoUniversalConfigService; _twoFactorAuthenticatorDataProtector = twoFactorAuthenticatorDataProtector; _ssoEmailTwoFactorSessionDataProtector = ssoEmailTwoFactorSessionDataProtector; @@ -350,14 +349,15 @@ public class TwoFactorController : Controller if (user != null) { - // Check if 2FA email is from Passwordless. + // Check if 2FA email is from a device approval ("Log in with device") scenario. if (!string.IsNullOrEmpty(requestModel.AuthRequestAccessCode)) { - if (await _verifyAuthRequestCommand - .VerifyAuthRequestAsync(new Guid(requestModel.AuthRequestId), - requestModel.AuthRequestAccessCode)) + var authRequest = await _authRequestRepository.GetByIdAsync(new Guid(requestModel.AuthRequestId)); + if (authRequest != null && + authRequest.IsValidForAuthentication(user.Id, requestModel.AuthRequestAccessCode)) { await _twoFactorEmailService.SendTwoFactorEmailAsync(user); + return; } } else if (!string.IsNullOrEmpty(requestModel.SsoEmail2FaSessionToken)) diff --git a/src/Api/Auth/Controllers/WebAuthnController.cs b/src/Api/Auth/Controllers/WebAuthnController.cs index 60b8621c5e..833087e99c 100644 --- a/src/Api/Auth/Controllers/WebAuthnController.cs +++ b/src/Api/Auth/Controllers/WebAuthnController.cs @@ -21,7 +21,6 @@ using Microsoft.AspNetCore.Mvc; namespace Bit.Api.Auth.Controllers; [Route("webauthn")] -[Authorize(Policies.Web)] public class WebAuthnController : Controller { private readonly IUserService _userService; @@ -62,6 +61,7 @@ public class WebAuthnController : Controller _featureService = featureService; } + [Authorize(Policies.Web)] [HttpGet("")] public async Task> Get() { @@ -71,6 +71,7 @@ public class WebAuthnController : Controller return new ListResponseModel(credentials.Select(c => new WebAuthnCredentialResponseModel(c))); } + [Authorize(Policies.Application)] [HttpPost("attestation-options")] public async Task AttestationOptions([FromBody] SecretVerificationRequestModel model) { @@ -88,6 +89,7 @@ public class WebAuthnController : Controller }; } + [Authorize(Policies.Web)] [HttpPost("assertion-options")] public async Task AssertionOptions([FromBody] SecretVerificationRequestModel model) { @@ -104,6 +106,7 @@ public class WebAuthnController : Controller }; } + [Authorize(Policies.Application)] [HttpPost("")] public async Task Post([FromBody] WebAuthnLoginCredentialCreateRequestModel model) { @@ -149,6 +152,7 @@ public class WebAuthnController : Controller } } + [Authorize(Policies.Application)] [HttpPut()] public async Task UpdateCredential([FromBody] WebAuthnLoginCredentialUpdateRequestModel model) { @@ -172,6 +176,7 @@ public class WebAuthnController : Controller await _credentialRepository.UpdateAsync(credential); } + [Authorize(Policies.Web)] [HttpPost("{id}/delete")] public async Task Delete(Guid id, [FromBody] SecretVerificationRequestModel model) { diff --git a/src/Api/Auth/Models/Request/TwoFactorRequestModels.cs b/src/Api/Auth/Models/Request/TwoFactorRequestModels.cs index 79df29c928..6173de81d9 100644 --- a/src/Api/Auth/Models/Request/TwoFactorRequestModels.cs +++ b/src/Api/Auth/Models/Request/TwoFactorRequestModels.cs @@ -273,7 +273,7 @@ public class TwoFactorWebAuthnDeleteRequestModel : SecretVerificationRequestMode yield return validationResult; } - if (!Id.HasValue || Id < 0 || Id > 5) + if (!Id.HasValue) { yield return new ValidationResult("Invalid Key Id", new string[] { nameof(Id) }); } diff --git a/src/Api/Billing/Attributes/NonTokenizedPaymentMethodTypeValidationAttribute.cs b/src/Api/Billing/Attributes/NonTokenizedPaymentMethodTypeValidationAttribute.cs new file mode 100644 index 0000000000..7a906d4838 --- /dev/null +++ b/src/Api/Billing/Attributes/NonTokenizedPaymentMethodTypeValidationAttribute.cs @@ -0,0 +1,13 @@ +using Bit.Api.Utilities; + +namespace Bit.Api.Billing.Attributes; + +public class NonTokenizedPaymentMethodTypeValidationAttribute : StringMatchesAttribute +{ + private static readonly string[] _acceptedValues = ["accountCredit"]; + + public NonTokenizedPaymentMethodTypeValidationAttribute() : base(_acceptedValues) + { + ErrorMessage = $"Payment method type must be one of: {string.Join(", ", _acceptedValues)}"; + } +} diff --git a/src/Api/Billing/Attributes/PaymentMethodTypeValidationAttribute.cs b/src/Api/Billing/Attributes/TokenizedPaymentMethodTypeValidationAttribute.cs similarity index 62% rename from src/Api/Billing/Attributes/PaymentMethodTypeValidationAttribute.cs rename to src/Api/Billing/Attributes/TokenizedPaymentMethodTypeValidationAttribute.cs index 227b454f9f..51e40e9999 100644 --- a/src/Api/Billing/Attributes/PaymentMethodTypeValidationAttribute.cs +++ b/src/Api/Billing/Attributes/TokenizedPaymentMethodTypeValidationAttribute.cs @@ -2,11 +2,11 @@ namespace Bit.Api.Billing.Attributes; -public class PaymentMethodTypeValidationAttribute : StringMatchesAttribute +public class TokenizedPaymentMethodTypeValidationAttribute : StringMatchesAttribute { private static readonly string[] _acceptedValues = ["bankAccount", "card", "payPal"]; - public PaymentMethodTypeValidationAttribute() : base(_acceptedValues) + public TokenizedPaymentMethodTypeValidationAttribute() : base(_acceptedValues) { ErrorMessage = $"Payment method type must be one of: {string.Join(", ", _acceptedValues)}"; } diff --git a/src/Api/Billing/Controllers/AccountsBillingController.cs b/src/Api/Billing/Controllers/AccountsBillingController.cs index 7abcf8c357..243f4d3c53 100644 --- a/src/Api/Billing/Controllers/AccountsBillingController.cs +++ b/src/Api/Billing/Controllers/AccountsBillingController.cs @@ -1,7 +1,5 @@ -#nullable enable -using Bit.Api.Billing.Models.Responses; +using Bit.Api.Billing.Models.Responses; using Bit.Core.Billing.Services; -using Bit.Core.Billing.Tax.Requests; using Bit.Core.Services; using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; @@ -12,10 +10,11 @@ namespace Bit.Api.Billing.Controllers; [Route("accounts/billing")] [Authorize("Application")] public class AccountsBillingController( - IPaymentService paymentService, + IStripePaymentService paymentService, IUserService userService, IPaymentHistoryService paymentHistoryService) : Controller { + // TODO: Migrate to Query / AccountBillingVNextController [HttpGet("history")] [SelfHosted(NotSelfHostedOnly = true)] public async Task GetBillingHistoryAsync() @@ -30,20 +29,7 @@ public class AccountsBillingController( return new BillingHistoryResponseModel(billingInfo); } - [HttpGet("payment-method")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task GetPaymentMethodAsync() - { - var user = await userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var billingInfo = await paymentService.GetBillingAsync(user); - return new BillingPaymentResponseModel(billingInfo); - } - + // TODO: Migrate to Query / AccountBillingVNextController [HttpGet("invoices")] public async Task GetInvoicesAsync([FromQuery] string? status = null, [FromQuery] string? startAfter = null) { @@ -62,6 +48,7 @@ public class AccountsBillingController( return TypedResults.Ok(invoices); } + // TODO: Migrate to Query / AccountBillingVNextController [HttpGet("transactions")] public async Task GetTransactionsAsync([FromQuery] DateTime? startAfter = null) { @@ -78,18 +65,4 @@ public class AccountsBillingController( return TypedResults.Ok(transactions); } - - [HttpPost("preview-invoice")] - public async Task PreviewInvoiceAsync([FromBody] PreviewIndividualInvoiceRequestBody model) - { - var user = await userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var invoice = await paymentService.PreviewInvoiceAsync(model, user.GatewayCustomerId, user.GatewaySubscriptionId); - - return TypedResults.Ok(invoice); - } } diff --git a/src/Api/Billing/Controllers/AccountsController.cs b/src/Api/Billing/Controllers/AccountsController.cs index 9411d454aa..5d3e095fdd 100644 --- a/src/Api/Billing/Controllers/AccountsController.cs +++ b/src/Api/Billing/Controllers/AccountsController.cs @@ -1,13 +1,14 @@ -#nullable enable -using Bit.Api.Models.Request; +using Bit.Api.Models.Request; using Bit.Api.Models.Request.Accounts; using Bit.Api.Models.Response; using Bit.Api.Utilities; +using Bit.Core; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Business; using Bit.Core.Billing.Services; using Bit.Core.Exceptions; +using Bit.Core.KeyManagement.Queries.Interfaces; using Bit.Core.Models.Business; using Bit.Core.Services; using Bit.Core.Settings; @@ -21,8 +22,12 @@ namespace Bit.Api.Billing.Controllers; [Authorize("Application")] public class AccountsController( IUserService userService, - ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery) : Controller + ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, + IUserAccountKeysQuery userAccountKeysQuery, + IFeatureService featureService, + ILicensingService licensingService) : Controller { + // TODO: Remove when pm-24996-implement-upgrade-from-free-dialog is removed [HttpPost("premium")] public async Task PostPremiumAsync( PremiumRequestModel model, @@ -58,8 +63,9 @@ public class AccountsController( var userTwoFactorEnabled = await twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(user); var userHasPremiumFromOrganization = await userService.HasPremiumFromOrganization(user); var organizationIdsClaimingActiveUser = await GetOrganizationIdsClaimingUserAsync(user.Id); + var accountKeys = await userAccountKeysQuery.Run(user); - var profile = new ProfileResponseModel(user, null, null, null, userTwoFactorEnabled, + var profile = new ProfileResponseModel(user, accountKeys, null, null, null, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationIdsClaimingActiveUser); return new PaymentResponseModel { @@ -69,10 +75,11 @@ public class AccountsController( }; } + // TODO: Migrate to Query / AccountBillingVNextController as part of Premium -> Organization upgrade work. [HttpGet("subscription")] public async Task GetSubscriptionAsync( [FromServices] GlobalSettings globalSettings, - [FromServices] IPaymentService paymentService) + [FromServices] IStripePaymentService paymentService) { var user = await userService.GetUserByPrincipalAsync(User); if (user == null) @@ -80,16 +87,26 @@ public class AccountsController( throw new UnauthorizedAccessException(); } - if (!globalSettings.SelfHosted && user.Gateway != null) + // Only cloud-hosted users with payment gateways have subscription and discount information + if (!globalSettings.SelfHosted) { - var subscriptionInfo = await paymentService.GetSubscriptionAsync(user); - var license = await userService.GenerateLicenseAsync(user, subscriptionInfo); - return new SubscriptionResponseModel(user, subscriptionInfo, license); - } - else if (!globalSettings.SelfHosted) - { - var license = await userService.GenerateLicenseAsync(user); - return new SubscriptionResponseModel(user, license); + if (user.Gateway != null) + { + // Note: PM23341_Milestone_2 is the feature flag for the overall Milestone 2 initiative (PM-23341). + // This specific implementation (PM-26682) adds discount display functionality as part of that initiative. + // The feature flag controls the broader Milestone 2 feature set, not just this specific task. + var includeMilestone2Discount = featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2); + var subscriptionInfo = await paymentService.GetSubscriptionAsync(user); + var license = await userService.GenerateLicenseAsync(user, subscriptionInfo); + var claimsPrincipal = licensingService.GetClaimsPrincipalFromLicense(license); + return new SubscriptionResponseModel(user, subscriptionInfo, license, claimsPrincipal, includeMilestone2Discount); + } + else + { + var license = await userService.GenerateLicenseAsync(user); + var claimsPrincipal = licensingService.GetClaimsPrincipalFromLicense(license); + return new SubscriptionResponseModel(user, null, license, claimsPrincipal); + } } else { @@ -97,29 +114,7 @@ public class AccountsController( } } - [HttpPost("payment")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostPaymentAsync([FromBody] PaymentRequestModel model) - { - var user = await userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - await userService.ReplacePaymentMethodAsync(user, model.PaymentToken, model.PaymentMethodType!.Value, - new TaxInfo - { - BillingAddressLine1 = model.Line1, - BillingAddressLine2 = model.Line2, - BillingAddressCity = model.City, - BillingAddressState = model.State, - BillingAddressCountry = model.Country, - BillingAddressPostalCode = model.PostalCode, - TaxIdNumber = model.TaxId - }); - } - + // TODO: Migrate to Command / AccountBillingVNextController as PUT /account/billing/vnext/subscription [HttpPost("storage")] [SelfHosted(NotSelfHostedOnly = true)] public async Task PostStorageAsync([FromBody] StorageRequestModel model) @@ -134,8 +129,11 @@ public class AccountsController( return new PaymentResponseModel { Success = true, PaymentIntentClientSecret = result }; } - - + /* + * TODO: A new version of this exists in the AccountBillingVNextController. + * The individual-self-hosting-license-uploader.component needs to be updated to use it. + * Then, this can be removed. + */ [HttpPost("license")] [SelfHosted(SelfHostedOnly = true)] public async Task PostLicenseAsync(LicenseRequestModel model) @@ -155,6 +153,7 @@ public class AccountsController( await userService.UpdateLicenseAsync(user, license); } + // TODO: Migrate to Command / AccountBillingVNextController as DELETE /account/billing/vnext/subscription [HttpPost("cancel")] public async Task PostCancelAsync( [FromBody] SubscriptionCancellationRequestModel request, @@ -172,6 +171,7 @@ public class AccountsController( user.IsExpired()); } + // TODO: Migrate to Command / AccountBillingVNextController as POST /account/billing/vnext/subscription/reinstate [HttpPost("reinstate-premium")] [SelfHosted(NotSelfHostedOnly = true)] public async Task PostReinstateAsync() @@ -185,41 +185,6 @@ public class AccountsController( await userService.ReinstatePremiumAsync(user); } - [HttpGet("tax")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task GetTaxInfoAsync( - [FromServices] IPaymentService paymentService) - { - var user = await userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var taxInfo = await paymentService.GetTaxInfoAsync(user); - return new TaxInfoResponseModel(taxInfo); - } - - [HttpPut("tax")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PutTaxInfoAsync( - [FromBody] TaxInfoUpdateRequestModel model, - [FromServices] IPaymentService paymentService) - { - var user = await userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var taxInfo = new TaxInfo - { - BillingAddressPostalCode = model.PostalCode, - BillingAddressCountry = model.Country, - }; - await paymentService.SaveTaxInfoAsync(user, taxInfo); - } - private async Task> GetOrganizationIdsClaimingUserAsync(Guid userId) { var organizationsClaimingUser = await userService.GetOrganizationsClaimingUserAsync(userId); diff --git a/src/Api/Billing/Controllers/InvoicesController.cs b/src/Api/Billing/Controllers/InvoicesController.cs deleted file mode 100644 index 30ea975e09..0000000000 --- a/src/Api/Billing/Controllers/InvoicesController.cs +++ /dev/null @@ -1,45 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.AdminConsole.Entities; -using Bit.Core.Billing.Tax.Requests; -using Bit.Core.Context; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Microsoft.AspNetCore.Authorization; -using Microsoft.AspNetCore.Mvc; - -namespace Bit.Api.Billing.Controllers; - -[Route("invoices")] -[Authorize("Application")] -public class InvoicesController : BaseBillingController -{ - [HttpPost("preview-organization")] - public async Task PreviewInvoiceAsync( - [FromBody] PreviewOrganizationInvoiceRequestBody model, - [FromServices] ICurrentContext currentContext, - [FromServices] IOrganizationRepository organizationRepository, - [FromServices] IPaymentService paymentService) - { - Organization organization = null; - if (model.OrganizationId != default) - { - if (!await currentContext.EditPaymentMethods(model.OrganizationId)) - { - return Error.Unauthorized(); - } - - organization = await organizationRepository.GetByIdAsync(model.OrganizationId); - if (organization == null) - { - return Error.NotFound(); - } - } - - var invoice = await paymentService.PreviewInvoiceAsync(model, organization?.GatewayCustomerId, - organization?.GatewaySubscriptionId); - - return TypedResults.Ok(invoice); - } -} diff --git a/src/Api/Billing/Controllers/OrganizationBillingController.cs b/src/Api/Billing/Controllers/OrganizationBillingController.cs index 1d6bf51661..e06d946ea0 100644 --- a/src/Api/Billing/Controllers/OrganizationBillingController.cs +++ b/src/Api/Billing/Controllers/OrganizationBillingController.cs @@ -5,7 +5,6 @@ using Bit.Core.Billing.Providers.Services; using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; @@ -19,10 +18,10 @@ public class OrganizationBillingController( ICurrentContext currentContext, IOrganizationBillingService organizationBillingService, IOrganizationRepository organizationRepository, - IPaymentService paymentService, - ISubscriberService subscriberService, + IStripePaymentService paymentService, IPaymentHistoryService paymentHistoryService) : BaseBillingController { + // TODO: Remove when pm-25379-use-new-organization-metadata-structure is removed. [HttpGet("metadata")] public async Task GetMetadataAsync([FromRoute] Guid organizationId) { @@ -38,11 +37,10 @@ public class OrganizationBillingController( return Error.NotFound(); } - var response = OrganizationMetadataResponse.From(metadata); - - return TypedResults.Ok(response); + return TypedResults.Ok(metadata); } + // TODO: Migrate to Query / OrganizationBillingVNextController [HttpGet("history")] public async Task GetHistoryAsync([FromRoute] Guid organizationId) { @@ -63,6 +61,7 @@ public class OrganizationBillingController( return TypedResults.Ok(billingInfo); } + // TODO: Migrate to Query / OrganizationBillingVNextController [HttpGet("invoices")] public async Task GetInvoicesAsync([FromRoute] Guid organizationId, [FromQuery] string? status = null, [FromQuery] string? startAfter = null) { @@ -87,6 +86,7 @@ public class OrganizationBillingController( return TypedResults.Ok(invoices); } + // TODO: Migrate to Query / OrganizationBillingVNextController [HttpGet("transactions")] public async Task GetTransactionsAsync([FromRoute] Guid organizationId, [FromQuery] DateTime? startAfter = null) { @@ -110,6 +110,7 @@ public class OrganizationBillingController( return TypedResults.Ok(transactions); } + // TODO: Can be removed once we do away with the organization-plans.component. [HttpGet] [SelfHosted(NotSelfHostedOnly = true)] public async Task GetBillingAsync(Guid organizationId) @@ -133,127 +134,7 @@ public class OrganizationBillingController( return TypedResults.Ok(response); } - [HttpGet("payment-method")] - public async Task GetPaymentMethodAsync([FromRoute] Guid organizationId) - { - if (!await currentContext.EditPaymentMethods(organizationId)) - { - return Error.Unauthorized(); - } - - var organization = await organizationRepository.GetByIdAsync(organizationId); - - if (organization == null) - { - return Error.NotFound(); - } - - var paymentMethod = await subscriberService.GetPaymentMethod(organization); - - var response = PaymentMethodResponse.From(paymentMethod); - - return TypedResults.Ok(response); - } - - [HttpPut("payment-method")] - public async Task UpdatePaymentMethodAsync( - [FromRoute] Guid organizationId, - [FromBody] UpdatePaymentMethodRequestBody requestBody) - { - if (!await currentContext.EditPaymentMethods(organizationId)) - { - return Error.Unauthorized(); - } - - var organization = await organizationRepository.GetByIdAsync(organizationId); - - if (organization == null) - { - return Error.NotFound(); - } - - var tokenizedPaymentSource = requestBody.PaymentSource.ToDomain(); - - var taxInformation = requestBody.TaxInformation.ToDomain(); - - await organizationBillingService.UpdatePaymentMethod(organization, tokenizedPaymentSource, taxInformation); - - return TypedResults.Ok(); - } - - [HttpPost("payment-method/verify-bank-account")] - public async Task VerifyBankAccountAsync( - [FromRoute] Guid organizationId, - [FromBody] VerifyBankAccountRequestBody requestBody) - { - if (!await currentContext.EditPaymentMethods(organizationId)) - { - return Error.Unauthorized(); - } - - if (requestBody.DescriptorCode.Length != 6 || !requestBody.DescriptorCode.StartsWith("SM")) - { - return Error.BadRequest("Statement descriptor should be a 6-character value that starts with 'SM'"); - } - - var organization = await organizationRepository.GetByIdAsync(organizationId); - - if (organization == null) - { - return Error.NotFound(); - } - - await subscriberService.VerifyBankAccount(organization, requestBody.DescriptorCode); - - return TypedResults.Ok(); - } - - [HttpGet("tax-information")] - public async Task GetTaxInformationAsync([FromRoute] Guid organizationId) - { - if (!await currentContext.EditPaymentMethods(organizationId)) - { - return Error.Unauthorized(); - } - - var organization = await organizationRepository.GetByIdAsync(organizationId); - - if (organization == null) - { - return Error.NotFound(); - } - - var taxInformation = await subscriberService.GetTaxInformation(organization); - - var response = TaxInformationResponse.From(taxInformation); - - return TypedResults.Ok(response); - } - - [HttpPut("tax-information")] - public async Task UpdateTaxInformationAsync( - [FromRoute] Guid organizationId, - [FromBody] TaxInformationRequestBody requestBody) - { - if (!await currentContext.EditPaymentMethods(organizationId)) - { - return Error.Unauthorized(); - } - - var organization = await organizationRepository.GetByIdAsync(organizationId); - - if (organization == null) - { - return Error.NotFound(); - } - - var taxInformation = requestBody.ToDomain(); - - await subscriberService.UpdateTaxInformation(organization, taxInformation); - - return TypedResults.Ok(); - } - + // TODO: Migrate to Command / OrganizationBillingVNextController [HttpPost("setup-business-unit")] [SelfHosted(NotSelfHostedOnly = true)] public async Task SetupBusinessUnitAsync( @@ -282,6 +163,7 @@ public class OrganizationBillingController( return TypedResults.Ok(providerId); } + // TODO: Migrate to Command / OrganizationBillingVNextController [HttpPost("change-frequency")] [SelfHosted(NotSelfHostedOnly = true)] public async Task ChangePlanSubscriptionFrequencyAsync( diff --git a/src/Api/Billing/Controllers/OrganizationSponsorshipsController.cs b/src/Api/Billing/Controllers/OrganizationSponsorshipsController.cs index 8c202752de..7ca85d52a8 100644 --- a/src/Api/Billing/Controllers/OrganizationSponsorshipsController.cs +++ b/src/Api/Billing/Controllers/OrganizationSponsorshipsController.cs @@ -89,19 +89,6 @@ public class OrganizationSponsorshipsController : Controller throw new BadRequestException("Free Bitwarden Families sponsorship has been disabled by your organization administrator."); } - if (!_featureService.IsEnabled(Bit.Core.FeatureFlagKeys.PM17772_AdminInitiatedSponsorships)) - { - if (model.IsAdminInitiated.GetValueOrDefault()) - { - throw new BadRequestException(); - } - - if (!string.IsNullOrWhiteSpace(model.Notes)) - { - model.Notes = null; - } - } - var sponsorship = await _createSponsorshipCommand.CreateSponsorshipAsync( sponsoringOrg, await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default), diff --git a/src/Api/Billing/Controllers/OrganizationsController.cs b/src/Api/Billing/Controllers/OrganizationsController.cs index 5494c5a90e..bca5605a8c 100644 --- a/src/Api/Billing/Controllers/OrganizationsController.cs +++ b/src/Api/Billing/Controllers/OrganizationsController.cs @@ -19,7 +19,6 @@ using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Exceptions; -using Bit.Core.Models.Business; using Bit.Core.OrganizationFeatures.OrganizationSubscriptions.Interface; using Bit.Core.Repositories; using Bit.Core.Services; @@ -37,7 +36,7 @@ public class OrganizationsController( IOrganizationUserRepository organizationUserRepository, IOrganizationService organizationService, IUserService userService, - IPaymentService paymentService, + IStripePaymentService paymentService, ICurrentContext currentContext, IGetCloudOrganizationLicenseQuery getCloudOrganizationLicenseQuery, GlobalSettings globalSettings, @@ -67,7 +66,8 @@ public class OrganizationsController( if (globalSettings.SelfHosted) { var orgLicense = await licensingService.ReadOrganizationLicenseAsync(organization); - return new OrganizationSubscriptionResponseModel(organization, orgLicense); + var claimsPrincipal = licensingService.GetClaimsPrincipalFromLicense(orgLicense); + return new OrganizationSubscriptionResponseModel(organization, orgLicense, claimsPrincipal); } var plan = await pricingClient.GetPlanOrThrow(organization.PlanType); @@ -248,53 +248,6 @@ public class OrganizationsController( await organizationService.ReinstateSubscriptionAsync(id); } - [HttpGet("{id:guid}/tax")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task GetTaxInfo(Guid id) - { - if (!await currentContext.OrganizationOwner(id)) - { - throw new NotFoundException(); - } - - var organization = await organizationRepository.GetByIdAsync(id); - if (organization == null) - { - throw new NotFoundException(); - } - - var taxInfo = await paymentService.GetTaxInfoAsync(organization); - return new TaxInfoResponseModel(taxInfo); - } - - [HttpPut("{id:guid}/tax")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PutTaxInfo(Guid id, [FromBody] ExpandedTaxInfoUpdateRequestModel model) - { - if (!await currentContext.OrganizationOwner(id)) - { - throw new NotFoundException(); - } - - var organization = await organizationRepository.GetByIdAsync(id); - if (organization == null) - { - throw new NotFoundException(); - } - - var taxInfo = new TaxInfo - { - TaxIdNumber = model.TaxId, - BillingAddressLine1 = model.Line1, - BillingAddressLine2 = model.Line2, - BillingAddressCity = model.City, - BillingAddressState = model.State, - BillingAddressPostalCode = model.PostalCode, - BillingAddressCountry = model.Country, - }; - await paymentService.SaveTaxInfoAsync(organization, taxInfo); - } - /// /// Tries to grant owner access to the Secrets Manager for the organization /// diff --git a/src/Api/Controllers/PlansController.cs b/src/Api/Billing/Controllers/PlansController.cs similarity index 66% rename from src/Api/Controllers/PlansController.cs rename to src/Api/Billing/Controllers/PlansController.cs index 11b070fb66..f9b5274780 100644 --- a/src/Api/Controllers/PlansController.cs +++ b/src/Api/Billing/Controllers/PlansController.cs @@ -3,10 +3,10 @@ using Bit.Core.Billing.Pricing; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; +namespace Bit.Api.Billing.Controllers; [Route("plans")] -[Authorize("Web")] +[Authorize("Application")] public class PlansController( IPricingClient pricingClient) : Controller { @@ -18,4 +18,11 @@ public class PlansController( var responses = plans.Select(plan => new PlanResponseModel(plan)); return new ListResponseModel(responses); } + + [HttpGet("premium")] + public async Task GetPremiumPlanAsync() + { + var premiumPlan = await pricingClient.GetAvailablePremiumPlan(); + return TypedResults.Ok(premiumPlan); + } } diff --git a/src/Api/Billing/Controllers/ProviderBillingController.cs b/src/Api/Billing/Controllers/ProviderBillingController.cs index f7d0593812..dfa705a329 100644 --- a/src/Api/Billing/Controllers/ProviderBillingController.cs +++ b/src/Api/Billing/Controllers/ProviderBillingController.cs @@ -1,7 +1,6 @@ // FIXME: Update this file to be null safe and then delete the line below #nullable disable -using Bit.Api.Billing.Models.Requests; using Bit.Api.Billing.Models.Responses; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Pricing; @@ -9,7 +8,6 @@ using Bit.Core.Billing.Providers.Models; using Bit.Core.Billing.Providers.Repositories; using Bit.Core.Billing.Providers.Services; using Bit.Core.Billing.Services; -using Bit.Core.Billing.Tax.Models; using Bit.Core.Context; using Bit.Core.Models.BitStripe; using Bit.Core.Services; @@ -34,6 +32,7 @@ public class ProviderBillingController( IStripeAdapter stripeAdapter, IUserService userService) : BaseProviderController(currentContext, logger, providerRepository, userService) { + // TODO: Migrate to Query / ProviderBillingVNextController [HttpGet("invoices")] public async Task GetInvoicesAsync([FromRoute] Guid providerId) { @@ -44,7 +43,7 @@ public class ProviderBillingController( return result; } - var invoices = await stripeAdapter.InvoiceListAsync(new StripeInvoiceListOptions + var invoices = await stripeAdapter.ListInvoicesAsync(new StripeInvoiceListOptions { Customer = provider.GatewayCustomerId }); @@ -54,6 +53,7 @@ public class ProviderBillingController( return TypedResults.Ok(response); } + // TODO: Migrate to Query / ProviderBillingVNextController [HttpGet("invoices/{invoiceId}")] public async Task GenerateClientInvoiceReportAsync([FromRoute] Guid providerId, string invoiceId) { @@ -76,51 +76,7 @@ public class ProviderBillingController( "text/csv"); } - [HttpPut("payment-method")] - public async Task UpdatePaymentMethodAsync( - [FromRoute] Guid providerId, - [FromBody] UpdatePaymentMethodRequestBody requestBody) - { - var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); - - if (provider == null) - { - return result; - } - - var tokenizedPaymentSource = requestBody.PaymentSource.ToDomain(); - var taxInformation = requestBody.TaxInformation.ToDomain(); - - await providerBillingService.UpdatePaymentMethod( - provider, - tokenizedPaymentSource, - taxInformation); - - return TypedResults.Ok(); - } - - [HttpPost("payment-method/verify-bank-account")] - public async Task VerifyBankAccountAsync( - [FromRoute] Guid providerId, - [FromBody] VerifyBankAccountRequestBody requestBody) - { - var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); - - if (provider == null) - { - return result; - } - - if (requestBody.DescriptorCode.Length != 6 || !requestBody.DescriptorCode.StartsWith("SM")) - { - return Error.BadRequest("Statement descriptor should be a 6-character value that starts with 'SM'"); - } - - await subscriberService.VerifyBankAccount(provider, requestBody.DescriptorCode); - - return TypedResults.Ok(); - } - + // TODO: Migrate to Query / ProviderBillingVNextController [HttpGet("subscription")] public async Task GetSubscriptionAsync([FromRoute] Guid providerId) { @@ -131,8 +87,8 @@ public class ProviderBillingController( return result; } - var subscription = await stripeAdapter.SubscriptionGetAsync(provider.GatewaySubscriptionId, - new SubscriptionGetOptions { Expand = ["customer.tax_ids", "test_clock"] }); + var subscription = await stripeAdapter.GetSubscriptionAsync(provider.GatewaySubscriptionId, + new SubscriptionGetOptions { Expand = ["customer.tax_ids", "discounts", "test_clock"] }); var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); @@ -140,7 +96,7 @@ public class ProviderBillingController( { var plan = await pricingClient.GetPlanOrThrow(providerPlan.PlanType); var priceId = ProviderPriceAdapter.GetPriceId(provider, subscription, plan.Type); - var price = await stripeAdapter.PriceGetAsync(priceId); + var price = await stripeAdapter.GetPriceAsync(priceId); var unitAmount = price.UnitAmountDecimal.HasValue ? price.UnitAmountDecimal.Value / 100M @@ -172,53 +128,4 @@ public class ProviderBillingController( return TypedResults.Ok(response); } - - [HttpGet("tax-information")] - public async Task GetTaxInformationAsync([FromRoute] Guid providerId) - { - var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); - - if (provider == null) - { - return result; - } - - var taxInformation = await subscriberService.GetTaxInformation(provider); - - var response = TaxInformationResponse.From(taxInformation); - - return TypedResults.Ok(response); - } - - [HttpPut("tax-information")] - public async Task UpdateTaxInformationAsync( - [FromRoute] Guid providerId, - [FromBody] TaxInformationRequestBody requestBody) - { - var (provider, result) = await TryGetBillableProviderForAdminOperation(providerId); - - if (provider == null) - { - return result; - } - - if (requestBody is not { Country: not null, PostalCode: not null }) - { - return Error.BadRequest("Country and postal code are required to update your tax information."); - } - - var taxInformation = new TaxInformation( - requestBody.Country, - requestBody.PostalCode, - requestBody.TaxId, - requestBody.TaxIdType, - requestBody.Line1, - requestBody.Line2, - requestBody.City, - requestBody.State); - - await subscriberService.UpdateTaxInformation(provider, taxInformation); - - return TypedResults.Ok(); - } } diff --git a/src/Api/Billing/Controllers/StripeController.cs b/src/Api/Billing/Controllers/StripeController.cs index 15fccd16f4..6cb10e3165 100644 --- a/src/Api/Billing/Controllers/StripeController.cs +++ b/src/Api/Billing/Controllers/StripeController.cs @@ -1,5 +1,5 @@ -using Bit.Core.Billing.Tax.Services; -using Bit.Core.Services; +using Bit.Core.Billing.Services; +using Bit.Core.Billing.Tax.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Http.HttpResults; using Microsoft.AspNetCore.Mvc; @@ -28,7 +28,7 @@ public class StripeController( Usage = "off_session" }; - var setupIntent = await stripeAdapter.SetupIntentCreate(options); + var setupIntent = await stripeAdapter.CreateSetupIntentAsync(options); return TypedResults.Ok(setupIntent.ClientSecret); } @@ -43,7 +43,7 @@ public class StripeController( Usage = "off_session" }; - var setupIntent = await stripeAdapter.SetupIntentCreate(options); + var setupIntent = await stripeAdapter.CreateSetupIntentAsync(options); return TypedResults.Ok(setupIntent.ClientSecret); } diff --git a/src/Api/Billing/Controllers/VNext/OrganizationBillingVNextController.cs b/src/Api/Billing/Controllers/VNext/OrganizationBillingVNextController.cs index 2f825f2cb9..64ec068a5e 100644 --- a/src/Api/Billing/Controllers/VNext/OrganizationBillingVNextController.cs +++ b/src/Api/Billing/Controllers/VNext/OrganizationBillingVNextController.cs @@ -4,6 +4,7 @@ using Bit.Api.Billing.Attributes; using Bit.Api.Billing.Models.Requests.Payment; using Bit.Api.Billing.Models.Requests.Subscriptions; using Bit.Api.Billing.Models.Requirements; +using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Commands; using Bit.Core.Billing.Organizations.Queries; @@ -25,6 +26,7 @@ public class OrganizationBillingVNextController( ICreateBitPayInvoiceForCreditCommand createBitPayInvoiceForCreditCommand, IGetBillingAddressQuery getBillingAddressQuery, IGetCreditQuery getCreditQuery, + IGetOrganizationMetadataQuery getOrganizationMetadataQuery, IGetOrganizationWarningsQuery getOrganizationWarningsQuery, IGetPaymentMethodQuery getPaymentMethodQuery, IRestartSubscriptionCommand restartSubscriptionCommand, @@ -113,6 +115,23 @@ public class OrganizationBillingVNextController( return Handle(result); } + [Authorize] + [HttpGet("metadata")] + [RequireFeature(FeatureFlagKeys.PM25379_UseNewOrganizationMetadataStructure)] + [InjectOrganization] + public async Task GetMetadataAsync( + [BindNever] Organization organization) + { + var metadata = await getOrganizationMetadataQuery.Run(organization); + + if (metadata == null) + { + return TypedResults.NotFound(); + } + + return TypedResults.Ok(metadata); + } + [Authorize] [HttpGet("warnings")] [InjectOrganization] diff --git a/src/Api/Billing/Controllers/VNext/SelfHostedAccountBillingController.cs b/src/Api/Billing/Controllers/VNext/SelfHostedAccountBillingVNextController.cs similarity index 92% rename from src/Api/Billing/Controllers/VNext/SelfHostedAccountBillingController.cs rename to src/Api/Billing/Controllers/VNext/SelfHostedAccountBillingVNextController.cs index 973a7d99a1..b86f29bdbc 100644 --- a/src/Api/Billing/Controllers/VNext/SelfHostedAccountBillingController.cs +++ b/src/Api/Billing/Controllers/VNext/SelfHostedAccountBillingVNextController.cs @@ -1,5 +1,4 @@ -#nullable enable -using Bit.Api.Billing.Attributes; +using Bit.Api.Billing.Attributes; using Bit.Api.Billing.Models.Requests.Premium; using Bit.Api.Utilities; using Bit.Core; @@ -17,7 +16,7 @@ namespace Bit.Api.Billing.Controllers.VNext; [Authorize("Application")] [Route("account/billing/vnext/self-host")] [SelfHosted(SelfHostedOnly = true)] -public class SelfHostedAccountBillingController( +public class SelfHostedAccountBillingVNextController( ICreatePremiumSelfHostedSubscriptionCommand createPremiumSelfHostedSubscriptionCommand) : BaseBillingController { [HttpPost("license")] diff --git a/src/Api/Billing/Controllers/VNext/SelfHostedOrganizationBillingVNextController.cs b/src/Api/Billing/Controllers/VNext/SelfHostedOrganizationBillingVNextController.cs new file mode 100644 index 0000000000..625a97c998 --- /dev/null +++ b/src/Api/Billing/Controllers/VNext/SelfHostedOrganizationBillingVNextController.cs @@ -0,0 +1,35 @@ +using Bit.Api.AdminConsole.Authorization; +using Bit.Api.AdminConsole.Authorization.Requirements; +using Bit.Api.Billing.Attributes; +using Bit.Core; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Organizations.Queries; +using Bit.Core.Utilities; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.ModelBinding; + +namespace Bit.Api.Billing.Controllers.VNext; + +[Authorize("Application")] +[Route("organizations/{organizationId:guid}/billing/vnext/self-host")] +[SelfHosted(SelfHostedOnly = true)] +public class SelfHostedOrganizationBillingVNextController( + IGetOrganizationMetadataQuery getOrganizationMetadataQuery) : BaseBillingController +{ + [Authorize] + [HttpGet("metadata")] + [RequireFeature(FeatureFlagKeys.PM25379_UseNewOrganizationMetadataStructure)] + [InjectOrganization] + public async Task GetMetadataAsync([BindNever] Organization organization) + { + var metadata = await getOrganizationMetadataQuery.Run(organization); + + if (metadata == null) + { + return TypedResults.NotFound(); + } + + return TypedResults.Ok(metadata); + } +} diff --git a/src/Api/Billing/Models/Requests/KeyPairRequestBody.cs b/src/Api/Billing/Models/Requests/KeyPairRequestBody.cs index 2fec3bd61d..9979141b6d 100644 --- a/src/Api/Billing/Models/Requests/KeyPairRequestBody.cs +++ b/src/Api/Billing/Models/Requests/KeyPairRequestBody.cs @@ -2,6 +2,7 @@ #nullable disable using System.ComponentModel.DataAnnotations; +using Bit.Core.KeyManagement.Models.Data; namespace Bit.Api.Billing.Models.Requests; @@ -12,4 +13,11 @@ public class KeyPairRequestBody public string PublicKey { get; set; } [Required(ErrorMessage = "'encryptedPrivateKey' must be provided")] public string EncryptedPrivateKey { get; set; } + + public PublicKeyEncryptionKeyPairData ToPublicKeyEncryptionKeyPairData() + { + return new PublicKeyEncryptionKeyPairData( + wrappedPrivateKey: EncryptedPrivateKey, + publicKey: PublicKey); + } } diff --git a/src/Api/Billing/Models/Requests/Payment/MinimalTokenizedPaymentMethodRequest.cs b/src/Api/Billing/Models/Requests/Payment/MinimalTokenizedPaymentMethodRequest.cs index b0e415c262..1311805ad4 100644 --- a/src/Api/Billing/Models/Requests/Payment/MinimalTokenizedPaymentMethodRequest.cs +++ b/src/Api/Billing/Models/Requests/Payment/MinimalTokenizedPaymentMethodRequest.cs @@ -7,7 +7,7 @@ namespace Bit.Api.Billing.Models.Requests.Payment; public class MinimalTokenizedPaymentMethodRequest { [Required] - [PaymentMethodTypeValidation] + [TokenizedPaymentMethodTypeValidation] public required string Type { get; set; } [Required] diff --git a/src/Api/Billing/Models/Requests/Payment/NonTokenizedPaymentMethodRequest.cs b/src/Api/Billing/Models/Requests/Payment/NonTokenizedPaymentMethodRequest.cs new file mode 100644 index 0000000000..d15bc73778 --- /dev/null +++ b/src/Api/Billing/Models/Requests/Payment/NonTokenizedPaymentMethodRequest.cs @@ -0,0 +1,21 @@ +using System.ComponentModel.DataAnnotations; +using Bit.Api.Billing.Attributes; +using Bit.Core.Billing.Payment.Models; + +namespace Bit.Api.Billing.Models.Requests.Payment; + +public class NonTokenizedPaymentMethodRequest +{ + [Required] + [NonTokenizedPaymentMethodTypeValidation] + public required string Type { get; set; } + + public NonTokenizedPaymentMethod ToDomain() + { + return Type switch + { + "accountCredit" => new NonTokenizedPaymentMethod { Type = NonTokenizablePaymentMethodType.AccountCredit }, + _ => throw new InvalidOperationException($"Invalid value for {nameof(NonTokenizedPaymentMethod)}.{nameof(NonTokenizedPaymentMethod.Type)}") + }; + } +} diff --git a/src/Api/Billing/Models/Requests/Premium/PremiumCloudHostedSubscriptionRequest.cs b/src/Api/Billing/Models/Requests/Premium/PremiumCloudHostedSubscriptionRequest.cs index 03f20ec9c1..0f9198fdad 100644 --- a/src/Api/Billing/Models/Requests/Premium/PremiumCloudHostedSubscriptionRequest.cs +++ b/src/Api/Billing/Models/Requests/Premium/PremiumCloudHostedSubscriptionRequest.cs @@ -4,10 +4,10 @@ using Bit.Core.Billing.Payment.Models; namespace Bit.Api.Billing.Models.Requests.Premium; -public class PremiumCloudHostedSubscriptionRequest +public class PremiumCloudHostedSubscriptionRequest : IValidatableObject { - [Required] - public required MinimalTokenizedPaymentMethodRequest TokenizedPaymentMethod { get; set; } + public MinimalTokenizedPaymentMethodRequest? TokenizedPaymentMethod { get; set; } + public NonTokenizedPaymentMethodRequest? NonTokenizedPaymentMethod { get; set; } [Required] public required MinimalBillingAddressRequest BillingAddress { get; set; } @@ -15,11 +15,38 @@ public class PremiumCloudHostedSubscriptionRequest [Range(0, 99)] public short AdditionalStorageGb { get; set; } = 0; - public (TokenizedPaymentMethod, BillingAddress, short) ToDomain() + + public (PaymentMethod, BillingAddress, short) ToDomain() { - var paymentMethod = TokenizedPaymentMethod.ToDomain(); + // Check if TokenizedPaymentMethod or NonTokenizedPaymentMethod is provided. + var tokenizedPaymentMethod = TokenizedPaymentMethod?.ToDomain(); + var nonTokenizedPaymentMethod = NonTokenizedPaymentMethod?.ToDomain(); + + PaymentMethod paymentMethod = tokenizedPaymentMethod != null + ? tokenizedPaymentMethod + : nonTokenizedPaymentMethod!; + var billingAddress = BillingAddress.ToDomain(); return (paymentMethod, billingAddress, AdditionalStorageGb); } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (TokenizedPaymentMethod == null && NonTokenizedPaymentMethod == null) + { + yield return new ValidationResult( + "Either TokenizedPaymentMethod or NonTokenizedPaymentMethod must be provided.", + new[] { nameof(TokenizedPaymentMethod), nameof(NonTokenizedPaymentMethod) } + ); + } + + if (TokenizedPaymentMethod != null && NonTokenizedPaymentMethod != null) + { + yield return new ValidationResult( + "Only one of TokenizedPaymentMethod or NonTokenizedPaymentMethod can be provided.", + new[] { nameof(TokenizedPaymentMethod), nameof(NonTokenizedPaymentMethod) } + ); + } + } } diff --git a/src/Api/Billing/Models/Requests/TaxInformationRequestBody.cs b/src/Api/Billing/Models/Requests/TaxInformationRequestBody.cs deleted file mode 100644 index a1b754a9dc..0000000000 --- a/src/Api/Billing/Models/Requests/TaxInformationRequestBody.cs +++ /dev/null @@ -1,31 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.ComponentModel.DataAnnotations; -using Bit.Core.Billing.Tax.Models; - -namespace Bit.Api.Billing.Models.Requests; - -public class TaxInformationRequestBody -{ - [Required] - public string Country { get; set; } - [Required] - public string PostalCode { get; set; } - public string TaxId { get; set; } - public string TaxIdType { get; set; } - public string Line1 { get; set; } - public string Line2 { get; set; } - public string City { get; set; } - public string State { get; set; } - - public TaxInformation ToDomain() => new( - Country, - PostalCode, - TaxId, - TaxIdType, - Line1, - Line2, - City, - State); -} diff --git a/src/Api/Billing/Models/Requests/TokenizedPaymentSourceRequestBody.cs b/src/Api/Billing/Models/Requests/TokenizedPaymentSourceRequestBody.cs deleted file mode 100644 index b469ce2576..0000000000 --- a/src/Api/Billing/Models/Requests/TokenizedPaymentSourceRequestBody.cs +++ /dev/null @@ -1,25 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.ComponentModel.DataAnnotations; -using Bit.Api.Utilities; -using Bit.Core.Billing.Models; -using Bit.Core.Enums; - -namespace Bit.Api.Billing.Models.Requests; - -public class TokenizedPaymentSourceRequestBody -{ - [Required] - [EnumMatches( - PaymentMethodType.BankAccount, - PaymentMethodType.Card, - PaymentMethodType.PayPal, - ErrorMessage = "'type' must be BankAccount, Card or PayPal")] - public PaymentMethodType Type { get; set; } - - [Required] - public string Token { get; set; } - - public TokenizedPaymentSource ToDomain() => new(Type, Token); -} diff --git a/src/Api/Billing/Models/Requests/UpdatePaymentMethodRequestBody.cs b/src/Api/Billing/Models/Requests/UpdatePaymentMethodRequestBody.cs deleted file mode 100644 index 05ab1e34c9..0000000000 --- a/src/Api/Billing/Models/Requests/UpdatePaymentMethodRequestBody.cs +++ /dev/null @@ -1,15 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.ComponentModel.DataAnnotations; - -namespace Bit.Api.Billing.Models.Requests; - -public class UpdatePaymentMethodRequestBody -{ - [Required] - public TokenizedPaymentSourceRequestBody PaymentSource { get; set; } - - [Required] - public TaxInformationRequestBody TaxInformation { get; set; } -} diff --git a/src/Api/Billing/Models/Requests/VerifyBankAccountRequestBody.cs b/src/Api/Billing/Models/Requests/VerifyBankAccountRequestBody.cs deleted file mode 100644 index e248d55dde..0000000000 --- a/src/Api/Billing/Models/Requests/VerifyBankAccountRequestBody.cs +++ /dev/null @@ -1,12 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.ComponentModel.DataAnnotations; - -namespace Bit.Api.Billing.Models.Requests; - -public class VerifyBankAccountRequestBody -{ - [Required] - public string DescriptorCode { get; set; } -} diff --git a/src/Api/Billing/Models/Responses/BillingPaymentResponseModel.cs b/src/Api/Billing/Models/Responses/BillingPaymentResponseModel.cs deleted file mode 100644 index f305e41c4f..0000000000 --- a/src/Api/Billing/Models/Responses/BillingPaymentResponseModel.cs +++ /dev/null @@ -1,20 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.Billing.Models; -using Bit.Core.Models.Api; - -namespace Bit.Api.Billing.Models.Responses; - -public class BillingPaymentResponseModel : ResponseModel -{ - public BillingPaymentResponseModel(BillingInfo billing) - : base("billingPayment") - { - Balance = billing.Balance; - PaymentSource = billing.PaymentSource != null ? new BillingSource(billing.PaymentSource) : null; - } - - public decimal Balance { get; set; } - public BillingSource PaymentSource { get; set; } -} diff --git a/src/Api/Billing/Models/Responses/OrganizationMetadataResponse.cs b/src/Api/Billing/Models/Responses/OrganizationMetadataResponse.cs deleted file mode 100644 index a13f267c3b..0000000000 --- a/src/Api/Billing/Models/Responses/OrganizationMetadataResponse.cs +++ /dev/null @@ -1,31 +0,0 @@ -using Bit.Core.Billing.Organizations.Models; - -namespace Bit.Api.Billing.Models.Responses; - -public record OrganizationMetadataResponse( - bool IsEligibleForSelfHost, - bool IsManaged, - bool IsOnSecretsManagerStandalone, - bool IsSubscriptionUnpaid, - bool HasSubscription, - bool HasOpenInvoice, - bool IsSubscriptionCanceled, - DateTime? InvoiceDueDate, - DateTime? InvoiceCreatedDate, - DateTime? SubPeriodEndDate, - int OrganizationOccupiedSeats) -{ - public static OrganizationMetadataResponse From(OrganizationMetadata metadata) - => new( - metadata.IsEligibleForSelfHost, - metadata.IsManaged, - metadata.IsOnSecretsManagerStandalone, - metadata.IsSubscriptionUnpaid, - metadata.HasSubscription, - metadata.HasOpenInvoice, - metadata.IsSubscriptionCanceled, - metadata.InvoiceDueDate, - metadata.InvoiceCreatedDate, - metadata.SubPeriodEndDate, - metadata.OrganizationOccupiedSeats); -} diff --git a/src/Api/Billing/Models/Responses/PaymentMethodResponse.cs b/src/Api/Billing/Models/Responses/PaymentMethodResponse.cs deleted file mode 100644 index a54ac0a876..0000000000 --- a/src/Api/Billing/Models/Responses/PaymentMethodResponse.cs +++ /dev/null @@ -1,18 +0,0 @@ -using Bit.Core.Billing.Models; -using Bit.Core.Billing.Tax.Models; - -namespace Bit.Api.Billing.Models.Responses; - -public record PaymentMethodResponse( - decimal AccountCredit, - PaymentSource PaymentSource, - string SubscriptionStatus, - TaxInformation TaxInformation) -{ - public static PaymentMethodResponse From(PaymentMethod paymentMethod) => - new( - paymentMethod.AccountCredit, - paymentMethod.PaymentSource, - paymentMethod.SubscriptionStatus, - paymentMethod.TaxInformation); -} diff --git a/src/Api/Billing/Models/Responses/PaymentSourceResponse.cs b/src/Api/Billing/Models/Responses/PaymentSourceResponse.cs deleted file mode 100644 index 2c9a63b1d0..0000000000 --- a/src/Api/Billing/Models/Responses/PaymentSourceResponse.cs +++ /dev/null @@ -1,16 +0,0 @@ -using Bit.Core.Billing.Models; -using Bit.Core.Enums; - -namespace Bit.Api.Billing.Models.Responses; - -public record PaymentSourceResponse( - PaymentMethodType Type, - string Description, - bool NeedsVerification) -{ - public static PaymentSourceResponse From(PaymentSource paymentMethod) - => new( - paymentMethod.Type, - paymentMethod.Description, - paymentMethod.NeedsVerification); -} diff --git a/src/Api/Billing/Models/Responses/ProviderSubscriptionResponse.cs b/src/Api/Billing/Models/Responses/ProviderSubscriptionResponse.cs index e5b868af9a..4b78127240 100644 --- a/src/Api/Billing/Models/Responses/ProviderSubscriptionResponse.cs +++ b/src/Api/Billing/Models/Responses/ProviderSubscriptionResponse.cs @@ -1,6 +1,7 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; using Bit.Core.Billing.Providers.Models; using Bit.Core.Billing.Tax.Models; @@ -10,7 +11,7 @@ namespace Bit.Api.Billing.Models.Responses; public record ProviderSubscriptionResponse( string Status, - DateTime CurrentPeriodEndDate, + DateTime? CurrentPeriodEndDate, decimal? DiscountPercentage, string CollectionMethod, IEnumerable Plans, @@ -51,10 +52,12 @@ public record ProviderSubscriptionResponse( var accountCredit = Convert.ToDecimal(subscription.Customer?.Balance) * -1 / 100; + var discount = subscription.Customer?.Discount ?? subscription.Discounts?.FirstOrDefault(); + return new ProviderSubscriptionResponse( subscription.Status, - subscription.CurrentPeriodEnd, - subscription.Customer?.Discount?.Coupon?.PercentOff, + subscription.GetCurrentPeriodEnd(), + discount?.Coupon?.PercentOff, subscription.CollectionMethod, providerPlanResponses, accountCredit, diff --git a/src/Api/Billing/Models/Responses/TaxInformationResponse.cs b/src/Api/Billing/Models/Responses/TaxInformationResponse.cs deleted file mode 100644 index 59e4934751..0000000000 --- a/src/Api/Billing/Models/Responses/TaxInformationResponse.cs +++ /dev/null @@ -1,23 +0,0 @@ -using Bit.Core.Billing.Tax.Models; - -namespace Bit.Api.Billing.Models.Responses; - -public record TaxInformationResponse( - string Country, - string PostalCode, - string TaxId, - string Line1, - string Line2, - string City, - string State) -{ - public static TaxInformationResponse From(TaxInformation taxInformation) - => new( - taxInformation.Country, - taxInformation.PostalCode, - taxInformation.TaxId, - taxInformation.Line1, - taxInformation.Line2, - taxInformation.City, - taxInformation.State); -} diff --git a/src/Api/Controllers/MiscController.cs b/src/Api/Controllers/MiscController.cs deleted file mode 100644 index 6f23a27fbf..0000000000 --- a/src/Api/Controllers/MiscController.cs +++ /dev/null @@ -1,45 +0,0 @@ -using Bit.Api.Models.Request; -using Bit.Core.Settings; -using Bit.Core.Utilities; -using Microsoft.AspNetCore.Authorization; -using Microsoft.AspNetCore.Mvc; -using Stripe; - -namespace Bit.Api.Controllers; - -public class MiscController : Controller -{ - private readonly BitPayClient _bitPayClient; - private readonly GlobalSettings _globalSettings; - - public MiscController( - BitPayClient bitPayClient, - GlobalSettings globalSettings) - { - _bitPayClient = bitPayClient; - _globalSettings = globalSettings; - } - - [Authorize("Application")] - [HttpPost("~/bitpay-invoice")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostBitPayInvoice([FromBody] BitPayInvoiceRequestModel model) - { - var invoice = await _bitPayClient.CreateInvoiceAsync(model.ToBitpayInvoice(_globalSettings)); - return invoice.Url; - } - - [Authorize("Application")] - [HttpPost("~/setup-payment")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostSetupPayment() - { - var options = new SetupIntentCreateOptions - { - Usage = "off_session" - }; - var service = new SetupIntentService(); - var setupIntent = await service.CreateAsync(options); - return setupIntent.ClientSecret; - } -} diff --git a/src/Api/Controllers/PhishingDomainsController.cs b/src/Api/Controllers/PhishingDomainsController.cs deleted file mode 100644 index f0c1a65648..0000000000 --- a/src/Api/Controllers/PhishingDomainsController.cs +++ /dev/null @@ -1,34 +0,0 @@ -using Bit.Core; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Microsoft.AspNetCore.Mvc; - -namespace Bit.Api.Controllers; - -[Route("phishing-domains")] -public class PhishingDomainsController(IPhishingDomainRepository phishingDomainRepository, IFeatureService featureService) : Controller -{ - [HttpGet] - public async Task>> GetPhishingDomainsAsync() - { - if (!featureService.IsEnabled(FeatureFlagKeys.PhishingDetection)) - { - return NotFound(); - } - - var domains = await phishingDomainRepository.GetActivePhishingDomainsAsync(); - return Ok(domains); - } - - [HttpGet("checksum")] - public async Task> GetChecksumAsync() - { - if (!featureService.IsEnabled(FeatureFlagKeys.PhishingDetection)) - { - return NotFound(); - } - - var checksum = await phishingDomainRepository.GetCurrentChecksumAsync(); - return Ok(checksum); - } -} diff --git a/src/Api/Controllers/SelfHosted/SelfHostedOrganizationSponsorshipsController.cs b/src/Api/Controllers/SelfHosted/SelfHostedOrganizationSponsorshipsController.cs index 198438201c..6865bc06da 100644 --- a/src/Api/Controllers/SelfHosted/SelfHostedOrganizationSponsorshipsController.cs +++ b/src/Api/Controllers/SelfHosted/SelfHostedOrganizationSponsorshipsController.cs @@ -55,19 +55,6 @@ public class SelfHostedOrganizationSponsorshipsController : Controller [HttpPost("{sponsoringOrgId}/families-for-enterprise")] public async Task CreateSponsorship(Guid sponsoringOrgId, [FromBody] OrganizationSponsorshipCreateRequestModel model) { - if (!_featureService.IsEnabled(Bit.Core.FeatureFlagKeys.PM17772_AdminInitiatedSponsorships)) - { - if (model.IsAdminInitiated.GetValueOrDefault()) - { - throw new BadRequestException(); - } - - if (!string.IsNullOrWhiteSpace(model.Notes)) - { - model.Notes = null; - } - } - await _offerSponsorshipCommand.CreateSponsorshipAsync( await _organizationRepository.GetByIdAsync(sponsoringOrgId), await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default), diff --git a/src/Api/Controllers/UsersController.cs b/src/Api/Controllers/UsersController.cs deleted file mode 100644 index 4dfd047d37..0000000000 --- a/src/Api/Controllers/UsersController.cs +++ /dev/null @@ -1,33 +0,0 @@ -using Bit.Api.Models.Response; -using Bit.Core.Exceptions; -using Bit.Core.Repositories; -using Microsoft.AspNetCore.Authorization; -using Microsoft.AspNetCore.Mvc; - -namespace Bit.Api.Controllers; - -[Route("users")] -[Authorize("Application")] -public class UsersController : Controller -{ - private readonly IUserRepository _userRepository; - - public UsersController( - IUserRepository userRepository) - { - _userRepository = userRepository; - } - - [HttpGet("{id}/public-key")] - public async Task Get(string id) - { - var guidId = new Guid(id); - var key = await _userRepository.GetPublicKeyAsync(guidId); - if (key == null) - { - throw new NotFoundException(); - } - - return new UserKeyResponseModel(guidId, key); - } -} diff --git a/src/Api/AdminConsole/Controllers/EventsController.cs b/src/Api/Dirt/Controllers/EventsController.cs similarity index 95% rename from src/Api/AdminConsole/Controllers/EventsController.cs rename to src/Api/Dirt/Controllers/EventsController.cs index f868f0b3b6..1ac83c1316 100644 --- a/src/Api/AdminConsole/Controllers/EventsController.cs +++ b/src/Api/Dirt/Controllers/EventsController.cs @@ -1,8 +1,10 @@ // FIXME: Update this file to be null safe and then delete the line below #nullable disable +using Bit.Api.Dirt.Models.Response; using Bit.Api.Models.Response; using Bit.Api.Utilities; +using Bit.Api.Utilities.DiagnosticTools; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Context; using Bit.Core.Enums; @@ -16,7 +18,7 @@ using Bit.Core.Vault.Repositories; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; +namespace Bit.Api.Dirt.Controllers; [Route("events")] [Authorize("Application")] @@ -31,10 +33,11 @@ public class EventsController : Controller private readonly ISecretRepository _secretRepository; private readonly IProjectRepository _projectRepository; private readonly IServiceAccountRepository _serviceAccountRepository; + private readonly ILogger _logger; + private readonly IFeatureService _featureService; - public EventsController( - IUserService userService, + public EventsController(IUserService userService, ICipherRepository cipherRepository, IOrganizationUserRepository organizationUserRepository, IProviderUserRepository providerUserRepository, @@ -42,7 +45,9 @@ public class EventsController : Controller ICurrentContext currentContext, ISecretRepository secretRepository, IProjectRepository projectRepository, - IServiceAccountRepository serviceAccountRepository) + IServiceAccountRepository serviceAccountRepository, + ILogger logger, + IFeatureService featureService) { _userService = userService; _cipherRepository = cipherRepository; @@ -53,6 +58,8 @@ public class EventsController : Controller _secretRepository = secretRepository; _projectRepository = projectRepository; _serviceAccountRepository = serviceAccountRepository; + _logger = logger; + _featureService = featureService; } [HttpGet("")] @@ -114,6 +121,9 @@ public class EventsController : Controller var result = await _eventRepository.GetManyByOrganizationAsync(orgId, dateRange.Item1, dateRange.Item2, new PageOptions { ContinuationToken = continuationToken }); var responses = result.Data.Select(e => new EventResponseModel(e)); + + _logger.LogAggregateData(_featureService, orgId, responses, continuationToken, start, end); + return new ListResponseModel(responses, result.ContinuationToken); } diff --git a/src/Api/Dirt/Controllers/HibpController.cs b/src/Api/Dirt/Controllers/HibpController.cs index d108fdbd4f..8060384502 100644 --- a/src/Api/Dirt/Controllers/HibpController.cs +++ b/src/Api/Dirt/Controllers/HibpController.cs @@ -66,7 +66,10 @@ public class HibpController : Controller } else if (response.StatusCode == HttpStatusCode.NotFound) { - return new NotFoundResult(); + /* 12/1/2025 - Per the HIBP API, If the domain does not have any email addresses in any breaches, + an HTTP 404 response will be returned. API also specifies that "404 Not found is the account could + not be found and has therefore not been pwned". Per REST semantics we will return 200 OK with empty array. */ + return Content("[]", "application/json"); } else if (response.StatusCode == HttpStatusCode.TooManyRequests && retry) { diff --git a/src/Api/Dirt/Controllers/OrganizationIntegrationConfigurationController.cs b/src/Api/Dirt/Controllers/OrganizationIntegrationConfigurationController.cs new file mode 100644 index 0000000000..4296aa3edd --- /dev/null +++ b/src/Api/Dirt/Controllers/OrganizationIntegrationConfigurationController.cs @@ -0,0 +1,93 @@ +using Bit.Api.Dirt.Models.Request; +using Bit.Api.Dirt.Models.Response; +using Bit.Core.Context; +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; +using Bit.Core.Exceptions; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; + +namespace Bit.Api.Dirt.Controllers; + +[Route("organizations/{organizationId:guid}/integrations/{integrationId:guid}/configurations")] +[Authorize("Application")] +public class OrganizationIntegrationConfigurationController( + ICurrentContext currentContext, + ICreateOrganizationIntegrationConfigurationCommand createCommand, + IUpdateOrganizationIntegrationConfigurationCommand updateCommand, + IDeleteOrganizationIntegrationConfigurationCommand deleteCommand, + IGetOrganizationIntegrationConfigurationsQuery getQuery) : Controller +{ + [HttpGet("")] + public async Task> GetAsync( + Guid organizationId, + Guid integrationId) + { + if (!await HasPermission(organizationId)) + { + throw new NotFoundException(); + } + + var configurations = await getQuery.GetManyByIntegrationAsync(organizationId, integrationId); + return configurations + .Select(configuration => new OrganizationIntegrationConfigurationResponseModel(configuration)) + .ToList(); + } + + [HttpPost("")] + public async Task CreateAsync( + Guid organizationId, + Guid integrationId, + [FromBody] OrganizationIntegrationConfigurationRequestModel model) + { + if (!await HasPermission(organizationId)) + { + throw new NotFoundException(); + } + + var configuration = model.ToOrganizationIntegrationConfiguration(integrationId); + var created = await createCommand.CreateAsync(organizationId, integrationId, configuration); + + return new OrganizationIntegrationConfigurationResponseModel(created); + } + + [HttpPut("{configurationId:guid}")] + public async Task UpdateAsync( + Guid organizationId, + Guid integrationId, + Guid configurationId, + [FromBody] OrganizationIntegrationConfigurationRequestModel model) + { + if (!await HasPermission(organizationId)) + { + throw new NotFoundException(); + } + + var configuration = model.ToOrganizationIntegrationConfiguration(integrationId); + var updated = await updateCommand.UpdateAsync(organizationId, integrationId, configurationId, configuration); + + return new OrganizationIntegrationConfigurationResponseModel(updated); + } + + [HttpDelete("{configurationId:guid}")] + public async Task DeleteAsync(Guid organizationId, Guid integrationId, Guid configurationId) + { + if (!await HasPermission(organizationId)) + { + throw new NotFoundException(); + } + + await deleteCommand.DeleteAsync(organizationId, integrationId, configurationId); + } + + [HttpPost("{configurationId:guid}/delete")] + [Obsolete("This endpoint is deprecated. Use DELETE method instead")] + public async Task PostDeleteAsync(Guid organizationId, Guid integrationId, Guid configurationId) + { + await DeleteAsync(organizationId, integrationId, configurationId); + } + + private async Task HasPermission(Guid organizationId) + { + return await currentContext.OrganizationOwner(organizationId); + } +} diff --git a/src/Api/AdminConsole/Controllers/OrganizationIntegrationController.cs b/src/Api/Dirt/Controllers/OrganizationIntegrationController.cs similarity index 59% rename from src/Api/AdminConsole/Controllers/OrganizationIntegrationController.cs rename to src/Api/Dirt/Controllers/OrganizationIntegrationController.cs index a12492949d..960db648c2 100644 --- a/src/Api/AdminConsole/Controllers/OrganizationIntegrationController.cs +++ b/src/Api/Dirt/Controllers/OrganizationIntegrationController.cs @@ -1,23 +1,21 @@ -using Bit.Api.AdminConsole.Models.Request.Organizations; -using Bit.Api.AdminConsole.Models.Response.Organizations; -using Bit.Core; +using Bit.Api.Dirt.Models.Request; +using Bit.Api.Dirt.Models.Response; using Bit.Core.Context; +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrations.Interfaces; using Bit.Core.Exceptions; -using Bit.Core.Repositories; -using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -#nullable enable +namespace Bit.Api.Dirt.Controllers; -namespace Bit.Api.AdminConsole.Controllers; - -[RequireFeature(FeatureFlagKeys.EventBasedOrganizationIntegrations)] [Route("organizations/{organizationId:guid}/integrations")] [Authorize("Application")] public class OrganizationIntegrationController( ICurrentContext currentContext, - IOrganizationIntegrationRepository integrationRepository) : Controller + ICreateOrganizationIntegrationCommand createCommand, + IUpdateOrganizationIntegrationCommand updateCommand, + IDeleteOrganizationIntegrationCommand deleteCommand, + IGetOrganizationIntegrationsQuery getQuery) : Controller { [HttpGet("")] public async Task> GetAsync(Guid organizationId) @@ -27,7 +25,7 @@ public class OrganizationIntegrationController( throw new NotFoundException(); } - var integrations = await integrationRepository.GetManyByOrganizationAsync(organizationId); + var integrations = await getQuery.GetManyByOrganizationAsync(organizationId); return integrations .Select(integration => new OrganizationIntegrationResponseModel(integration)) .ToList(); @@ -41,8 +39,10 @@ public class OrganizationIntegrationController( throw new NotFoundException(); } - var integration = await integrationRepository.CreateAsync(model.ToOrganizationIntegration(organizationId)); - return new OrganizationIntegrationResponseModel(integration); + var integration = model.ToOrganizationIntegration(organizationId); + var created = await createCommand.CreateAsync(integration); + + return new OrganizationIntegrationResponseModel(created); } [HttpPut("{integrationId:guid}")] @@ -53,14 +53,10 @@ public class OrganizationIntegrationController( throw new NotFoundException(); } - var integration = await integrationRepository.GetByIdAsync(integrationId); - if (integration is null || integration.OrganizationId != organizationId) - { - throw new NotFoundException(); - } + var integration = model.ToOrganizationIntegration(organizationId); + var updated = await updateCommand.UpdateAsync(organizationId, integrationId, integration); - await integrationRepository.ReplaceAsync(model.ToOrganizationIntegration(integration)); - return new OrganizationIntegrationResponseModel(integration); + return new OrganizationIntegrationResponseModel(updated); } [HttpDelete("{integrationId:guid}")] @@ -71,13 +67,7 @@ public class OrganizationIntegrationController( throw new NotFoundException(); } - var integration = await integrationRepository.GetByIdAsync(integrationId); - if (integration is null || integration.OrganizationId != organizationId) - { - throw new NotFoundException(); - } - - await integrationRepository.DeleteAsync(integration); + await deleteCommand.DeleteAsync(organizationId, integrationId); } [HttpPost("{integrationId:guid}/delete")] diff --git a/src/Api/Dirt/Controllers/OrganizationReportsController.cs b/src/Api/Dirt/Controllers/OrganizationReportsController.cs index bcd64b0bdf..fc9a1b2d84 100644 --- a/src/Api/Dirt/Controllers/OrganizationReportsController.cs +++ b/src/Api/Dirt/Controllers/OrganizationReportsController.cs @@ -1,4 +1,5 @@ -using Bit.Core.Context; +using Bit.Api.Dirt.Models.Response; +using Bit.Core.Context; using Bit.Core.Dirt.Reports.ReportFeatures.Interfaces; using Bit.Core.Dirt.Reports.ReportFeatures.Requests; using Bit.Core.Exceptions; @@ -61,8 +62,9 @@ public class OrganizationReportsController : Controller } var latestReport = await _getOrganizationReportQuery.GetLatestOrganizationReportAsync(organizationId); + var response = latestReport == null ? null : new OrganizationReportResponseModel(latestReport); - return Ok(latestReport); + return Ok(response); } [HttpGet("{organizationId}/{reportId}")] @@ -102,7 +104,8 @@ public class OrganizationReportsController : Controller } var report = await _addOrganizationReportCommand.AddOrganizationReportAsync(request); - return Ok(report); + var response = report == null ? null : new OrganizationReportResponseModel(report); + return Ok(response); } [HttpPatch("{organizationId}/{reportId}")] @@ -119,7 +122,8 @@ public class OrganizationReportsController : Controller } var updatedReport = await _updateOrganizationReportCommand.UpdateOrganizationReportAsync(request); - return Ok(updatedReport); + var response = new OrganizationReportResponseModel(updatedReport); + return Ok(response); } #endregion @@ -182,10 +186,10 @@ public class OrganizationReportsController : Controller { throw new BadRequestException("Report ID in the request body must match the route parameter"); } - var updatedReport = await _updateOrganizationReportSummaryCommand.UpdateOrganizationReportSummaryAsync(request); + var response = new OrganizationReportResponseModel(updatedReport); - return Ok(updatedReport); + return Ok(response); } #endregion @@ -228,7 +232,9 @@ public class OrganizationReportsController : Controller } var updatedReport = await _updateOrganizationReportDataCommand.UpdateOrganizationReportDataAsync(request); - return Ok(updatedReport); + var response = new OrganizationReportResponseModel(updatedReport); + + return Ok(response); } #endregion @@ -265,7 +271,6 @@ public class OrganizationReportsController : Controller { try { - if (!await _currentContext.AccessReports(organizationId)) { throw new NotFoundException(); @@ -282,10 +287,9 @@ public class OrganizationReportsController : Controller } var updatedReport = await _updateOrganizationReportApplicationDataCommand.UpdateOrganizationReportApplicationDataAsync(request); + var response = new OrganizationReportResponseModel(updatedReport); - - - return Ok(updatedReport); + return Ok(response); } catch (Exception ex) when (!(ex is BadRequestException || ex is NotFoundException)) { diff --git a/src/Api/AdminConsole/Controllers/SlackIntegrationController.cs b/src/Api/Dirt/Controllers/SlackIntegrationController.cs similarity index 88% rename from src/Api/AdminConsole/Controllers/SlackIntegrationController.cs rename to src/Api/Dirt/Controllers/SlackIntegrationController.cs index c8ff4f9f7c..e98ed0d3fa 100644 --- a/src/Api/AdminConsole/Controllers/SlackIntegrationController.cs +++ b/src/Api/Dirt/Controllers/SlackIntegrationController.cs @@ -1,20 +1,17 @@ using System.Text.Json; -using Bit.Api.AdminConsole.Models.Response.Organizations; -using Bit.Core; -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Api.Dirt.Models.Response; using Bit.Core.Context; -using Bit.Core.Enums; +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Repositories; +using Bit.Core.Dirt.Services; using Bit.Core.Exceptions; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.AdminConsole.Controllers; +namespace Bit.Api.Dirt.Controllers; -[RequireFeature(FeatureFlagKeys.EventBasedOrganizationIntegrations)] [Route("organizations")] [Authorize("Application")] public class SlackIntegrationController( @@ -32,7 +29,7 @@ public class SlackIntegrationController( } string? callbackUrl = Url.RouteUrl( - routeName: nameof(CreateAsync), + routeName: "SlackIntegration_Create", values: null, protocol: currentContext.HttpContext.Request.Scheme, host: currentContext.HttpContext.Request.Host.ToUriComponent() @@ -76,7 +73,7 @@ public class SlackIntegrationController( return Redirect(redirectUrl); } - [HttpGet("integrations/slack/create", Name = nameof(CreateAsync))] + [HttpGet("integrations/slack/create", Name = "SlackIntegration_Create")] [AllowAnonymous] public async Task CreateAsync([FromQuery] string code, [FromQuery] string state) { @@ -103,7 +100,7 @@ public class SlackIntegrationController( // Fetch token from Slack and store to DB string? callbackUrl = Url.RouteUrl( - routeName: nameof(CreateAsync), + routeName: "SlackIntegration_Create", values: null, protocol: currentContext.HttpContext.Request.Scheme, host: currentContext.HttpContext.Request.Host.ToUriComponent() diff --git a/src/Api/Dirt/Controllers/TeamsIntegrationController.cs b/src/Api/Dirt/Controllers/TeamsIntegrationController.cs new file mode 100644 index 0000000000..b2bd55017c --- /dev/null +++ b/src/Api/Dirt/Controllers/TeamsIntegrationController.cs @@ -0,0 +1,144 @@ +using System.Text.Json; +using Bit.Api.Dirt.Models.Response; +using Bit.Core.Context; +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Repositories; +using Bit.Core.Dirt.Services; +using Bit.Core.Exceptions; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; +using Microsoft.Bot.Builder; +using Microsoft.Bot.Builder.Integration.AspNet.Core; + +namespace Bit.Api.Dirt.Controllers; + +[Route("organizations")] +[Authorize("Application")] +public class TeamsIntegrationController( + ICurrentContext currentContext, + IOrganizationIntegrationRepository integrationRepository, + IBot bot, + IBotFrameworkHttpAdapter adapter, + ITeamsService teamsService, + TimeProvider timeProvider) : Controller +{ + [HttpGet("{organizationId:guid}/integrations/teams/redirect")] + public async Task RedirectAsync(Guid organizationId) + { + if (!await currentContext.OrganizationOwner(organizationId)) + { + throw new NotFoundException(); + } + + var callbackUrl = Url.RouteUrl( + routeName: "TeamsIntegration_Create", + values: null, + protocol: currentContext.HttpContext.Request.Scheme, + host: currentContext.HttpContext.Request.Host.ToUriComponent() + ); + if (string.IsNullOrEmpty(callbackUrl)) + { + throw new BadRequestException("Unable to build callback Url"); + } + + var integrations = await integrationRepository.GetManyByOrganizationAsync(organizationId); + var integration = integrations.FirstOrDefault(i => i.Type == IntegrationType.Teams); + + if (integration is null) + { + // No teams integration exists, create Initiated version + integration = await integrationRepository.CreateAsync(new OrganizationIntegration + { + OrganizationId = organizationId, + Type = IntegrationType.Teams, + Configuration = null, + }); + } + else if (integration.Configuration is not null) + { + // A Completed (fully configured) Teams integration already exists, throw to prevent overriding + throw new BadRequestException("There already exists a Teams integration for this organization"); + + } // An Initiated teams integration exits, re-use it and kick off a new OAuth flow + + var state = IntegrationOAuthState.FromIntegration(integration, timeProvider); + var redirectUrl = teamsService.GetRedirectUrl( + callbackUrl: callbackUrl, + state: state.ToString() + ); + + if (string.IsNullOrEmpty(redirectUrl)) + { + throw new NotFoundException(); + } + + return Redirect(redirectUrl); + } + + [HttpGet("integrations/teams/create", Name = "TeamsIntegration_Create")] + [AllowAnonymous] + public async Task CreateAsync([FromQuery] string code, [FromQuery] string state) + { + var oAuthState = IntegrationOAuthState.FromString(state: state, timeProvider: timeProvider); + if (oAuthState is null) + { + throw new NotFoundException(); + } + + // Fetch existing Initiated record + var integration = await integrationRepository.GetByIdAsync(oAuthState.IntegrationId); + if (integration is null || + integration.Type != IntegrationType.Teams || + integration.Configuration is not null) + { + throw new NotFoundException(); + } + + // Verify Organization matches hash + if (!oAuthState.ValidateOrg(integration.OrganizationId)) + { + throw new NotFoundException(); + } + + var callbackUrl = Url.RouteUrl( + routeName: "TeamsIntegration_Create", + values: null, + protocol: currentContext.HttpContext.Request.Scheme, + host: currentContext.HttpContext.Request.Host.ToUriComponent() + ); + if (string.IsNullOrEmpty(callbackUrl)) + { + throw new BadRequestException("Unable to build callback Url"); + } + + var token = await teamsService.ObtainTokenViaOAuth(code, callbackUrl); + if (string.IsNullOrEmpty(token)) + { + throw new BadRequestException("Invalid response from Teams."); + } + + var teams = await teamsService.GetJoinedTeamsAsync(token); + + if (!teams.Any()) + { + throw new BadRequestException("No teams were found."); + } + + var teamsIntegration = new TeamsIntegration(TenantId: teams[0].TenantId, Teams: teams); + integration.Configuration = JsonSerializer.Serialize(teamsIntegration); + await integrationRepository.UpsertAsync(integration); + + var location = $"/organizations/{integration.OrganizationId}/integrations/{integration.Id}"; + return Created(location, new OrganizationIntegrationResponseModel(integration)); + } + + [Route("integrations/teams/incoming")] + [AllowAnonymous] + [HttpPost] + public async Task IncomingPostAsync() + { + await adapter.ProcessAsync(Request, Response, bot); + } +} diff --git a/src/Api/Dirt/Models/Request/OrganizationIntegrationConfigurationRequestModel.cs b/src/Api/Dirt/Models/Request/OrganizationIntegrationConfigurationRequestModel.cs new file mode 100644 index 0000000000..e918bea2d6 --- /dev/null +++ b/src/Api/Dirt/Models/Request/OrganizationIntegrationConfigurationRequestModel.cs @@ -0,0 +1,27 @@ +using Bit.Core.Dirt.Entities; +using Bit.Core.Enums; + +namespace Bit.Api.Dirt.Models.Request; + +public class OrganizationIntegrationConfigurationRequestModel +{ + public string? Configuration { get; set; } + + public EventType? EventType { get; set; } + + public string? Filters { get; set; } + + public string? Template { get; set; } + + public OrganizationIntegrationConfiguration ToOrganizationIntegrationConfiguration(Guid organizationIntegrationId) + { + return new OrganizationIntegrationConfiguration() + { + OrganizationIntegrationId = organizationIntegrationId, + Configuration = Configuration, + Filters = Filters, + EventType = EventType, + Template = Template + }; + } +} diff --git a/src/Api/AdminConsole/Models/Request/Organizations/OrgnizationIntegrationRequestModel.cs b/src/Api/Dirt/Models/Request/OrganizationIntegrationRequestModel.cs similarity index 92% rename from src/Api/AdminConsole/Models/Request/Organizations/OrgnizationIntegrationRequestModel.cs rename to src/Api/Dirt/Models/Request/OrganizationIntegrationRequestModel.cs index 92d65ab8fe..259671bd66 100644 --- a/src/Api/AdminConsole/Models/Request/Organizations/OrgnizationIntegrationRequestModel.cs +++ b/src/Api/Dirt/Models/Request/OrganizationIntegrationRequestModel.cs @@ -1,10 +1,10 @@ using System.ComponentModel.DataAnnotations; using System.Text.Json; -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; -using Bit.Core.Enums; +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.Models.Data.EventIntegrations; -namespace Bit.Api.AdminConsole.Models.Request.Organizations; +namespace Bit.Api.Dirt.Models.Request; public class OrganizationIntegrationRequestModel : IValidatableObject { @@ -35,7 +35,7 @@ public class OrganizationIntegrationRequestModel : IValidatableObject case IntegrationType.CloudBillingSync or IntegrationType.Scim: yield return new ValidationResult($"{nameof(Type)} integrations are not yet supported.", [nameof(Type)]); break; - case IntegrationType.Slack: + case IntegrationType.Slack or IntegrationType.Teams: yield return new ValidationResult($"{nameof(Type)} integrations cannot be created directly.", [nameof(Type)]); break; case IntegrationType.Webhook: diff --git a/src/Api/AdminConsole/Models/Response/EventResponseModel.cs b/src/Api/Dirt/Models/Response/EventResponseModel.cs similarity index 98% rename from src/Api/AdminConsole/Models/Response/EventResponseModel.cs rename to src/Api/Dirt/Models/Response/EventResponseModel.cs index c259bc3bc4..bfcc50c84e 100644 --- a/src/Api/AdminConsole/Models/Response/EventResponseModel.cs +++ b/src/Api/Dirt/Models/Response/EventResponseModel.cs @@ -2,7 +2,7 @@ using Bit.Core.Models.Api; using Bit.Core.Models.Data; -namespace Bit.Api.Models.Response; +namespace Bit.Api.Dirt.Models.Response; public class EventResponseModel : ResponseModel { diff --git a/src/Api/AdminConsole/Models/Response/Organizations/OrganizationIntegrationConfigurationResponseModel.cs b/src/Api/Dirt/Models/Response/OrganizationIntegrationConfigurationResponseModel.cs similarity index 83% rename from src/Api/AdminConsole/Models/Response/Organizations/OrganizationIntegrationConfigurationResponseModel.cs rename to src/Api/Dirt/Models/Response/OrganizationIntegrationConfigurationResponseModel.cs index c7906318e8..62a3aea405 100644 --- a/src/Api/AdminConsole/Models/Response/Organizations/OrganizationIntegrationConfigurationResponseModel.cs +++ b/src/Api/Dirt/Models/Response/OrganizationIntegrationConfigurationResponseModel.cs @@ -1,18 +1,14 @@ -using Bit.Core.AdminConsole.Entities; +using Bit.Core.Dirt.Entities; using Bit.Core.Enums; using Bit.Core.Models.Api; -#nullable enable - -namespace Bit.Api.AdminConsole.Models.Response.Organizations; +namespace Bit.Api.Dirt.Models.Response; public class OrganizationIntegrationConfigurationResponseModel : ResponseModel { public OrganizationIntegrationConfigurationResponseModel(OrganizationIntegrationConfiguration organizationIntegrationConfiguration, string obj = "organizationIntegrationConfiguration") : base(obj) { - ArgumentNullException.ThrowIfNull(organizationIntegrationConfiguration); - Id = organizationIntegrationConfiguration.Id; Configuration = organizationIntegrationConfiguration.Configuration; CreationDate = organizationIntegrationConfiguration.CreationDate; diff --git a/src/Api/AdminConsole/Models/Response/Organizations/OrganizationIntegrationResponseModel.cs b/src/Api/Dirt/Models/Response/OrganizationIntegrationResponseModel.cs similarity index 68% rename from src/Api/AdminConsole/Models/Response/Organizations/OrganizationIntegrationResponseModel.cs rename to src/Api/Dirt/Models/Response/OrganizationIntegrationResponseModel.cs index 5368f78e39..60e885fe82 100644 --- a/src/Api/AdminConsole/Models/Response/Organizations/OrganizationIntegrationResponseModel.cs +++ b/src/Api/Dirt/Models/Response/OrganizationIntegrationResponseModel.cs @@ -1,8 +1,10 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.Enums; +using System.Text.Json; +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.Models.Data.EventIntegrations; using Bit.Core.Models.Api; -namespace Bit.Api.AdminConsole.Models.Response.Organizations; +namespace Bit.Api.Dirt.Models.Response; public class OrganizationIntegrationResponseModel : ResponseModel { @@ -35,6 +37,16 @@ public class OrganizationIntegrationResponseModel : ResponseModel ? OrganizationIntegrationStatus.Initiated : OrganizationIntegrationStatus.Completed, + // If present and the configuration is null, OAuth has been initiated, and we are + // waiting on the return OAuth call. If Configuration is not null and IsCompleted is true, + // then we've received the app install bot callback, and it's Completed. Otherwise, + // it is In Progress while we await the app install bot callback. + IntegrationType.Teams => string.IsNullOrWhiteSpace(Configuration) + ? OrganizationIntegrationStatus.Initiated + : (JsonSerializer.Deserialize(Configuration)?.IsCompleted ?? false) + ? OrganizationIntegrationStatus.Completed + : OrganizationIntegrationStatus.InProgress, + // HEC and Datadog should only be allowed to be created non-null. // If they are null, they are Invalid IntegrationType.Hec => string.IsNullOrWhiteSpace(Configuration) diff --git a/src/Api/Dirt/Models/Response/OrganizationReportResponseModel.cs b/src/Api/Dirt/Models/Response/OrganizationReportResponseModel.cs new file mode 100644 index 0000000000..e477e5b806 --- /dev/null +++ b/src/Api/Dirt/Models/Response/OrganizationReportResponseModel.cs @@ -0,0 +1,38 @@ +using Bit.Core.Dirt.Entities; + +namespace Bit.Api.Dirt.Models.Response; + +public class OrganizationReportResponseModel +{ + public Guid Id { get; set; } + public Guid OrganizationId { get; set; } + public string? ReportData { get; set; } + public string? ContentEncryptionKey { get; set; } + public string? SummaryData { get; set; } + public string? ApplicationData { get; set; } + public int? PasswordCount { get; set; } + public int? PasswordAtRiskCount { get; set; } + public int? MemberCount { get; set; } + public DateTime? CreationDate { get; set; } = null; + public DateTime? RevisionDate { get; set; } = null; + + public OrganizationReportResponseModel(OrganizationReport organizationReport) + { + if (organizationReport == null) + { + return; + } + + Id = organizationReport.Id; + OrganizationId = organizationReport.OrganizationId; + ReportData = organizationReport.ReportData; + ContentEncryptionKey = organizationReport.ContentEncryptionKey; + SummaryData = organizationReport.SummaryData; + ApplicationData = organizationReport.ApplicationData; + PasswordCount = organizationReport.PasswordCount; + PasswordAtRiskCount = organizationReport.PasswordAtRiskCount; + MemberCount = organizationReport.MemberCount; + CreationDate = organizationReport.CreationDate; + RevisionDate = organizationReport.RevisionDate; + } +} diff --git a/src/Api/Dirt/Public/Controllers/EventsController.cs b/src/Api/Dirt/Public/Controllers/EventsController.cs new file mode 100644 index 0000000000..8c76137489 --- /dev/null +++ b/src/Api/Dirt/Public/Controllers/EventsController.cs @@ -0,0 +1,132 @@ +using System.Net; +using Bit.Api.Dirt.Public.Models; +using Bit.Api.Models.Public.Response; +using Bit.Api.Utilities.DiagnosticTools; +using Bit.Core.Context; +using Bit.Core.Models.Data; +using Bit.Core.Repositories; +using Bit.Core.SecretsManager.Repositories; +using Bit.Core.Services; +using Bit.Core.Vault.Repositories; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; + +namespace Bit.Api.Dirt.Public.Controllers; + +[Route("public/events")] +[Authorize("Organization")] +public class EventsController : Controller +{ + private readonly IEventRepository _eventRepository; + private readonly ICipherRepository _cipherRepository; + private readonly ICurrentContext _currentContext; + private readonly ISecretRepository _secretRepository; + private readonly IProjectRepository _projectRepository; + private readonly IUserService _userService; + private readonly ILogger _logger; + private readonly IFeatureService _featureService; + + public EventsController( + IEventRepository eventRepository, + ICipherRepository cipherRepository, + ICurrentContext currentContext, + ISecretRepository secretRepository, + IProjectRepository projectRepository, + IUserService userService, + ILogger logger, + IFeatureService featureService) + { + _eventRepository = eventRepository; + _cipherRepository = cipherRepository; + _currentContext = currentContext; + _secretRepository = secretRepository; + _projectRepository = projectRepository; + _userService = userService; + _logger = logger; + _featureService = featureService; + } + + /// + /// List all events. + /// + /// + /// Returns a filtered list of your organization's event logs, paged by a continuation token. + /// If no filters are provided, it will return the last 30 days of event for the organization. + /// + [HttpGet] + [ProducesResponseType(typeof(PagedListResponseModel), (int)HttpStatusCode.OK)] + public async Task List([FromQuery] EventFilterRequestModel request) + { + if (!_currentContext.OrganizationId.HasValue) + { + return new JsonResult(new PagedListResponseModel([], "")); + } + + var organizationId = _currentContext.OrganizationId.Value; + var dateRange = request.ToDateRange(); + var result = new PagedResult(); + if (request.ActingUserId.HasValue) + { + result = await _eventRepository.GetManyByOrganizationActingUserAsync( + organizationId, request.ActingUserId.Value, dateRange.Item1, dateRange.Item2, + new PageOptions { ContinuationToken = request.ContinuationToken }); + } + else if (request.ItemId.HasValue) + { + var cipher = await _cipherRepository.GetByIdAsync(request.ItemId.Value); + if (cipher != null && cipher.OrganizationId == organizationId) + { + result = await _eventRepository.GetManyByCipherAsync( + cipher, dateRange.Item1, dateRange.Item2, + new PageOptions { ContinuationToken = request.ContinuationToken }); + } + } + else if (request.SecretId.HasValue) + { + var secret = await _secretRepository.GetByIdAsync(request.SecretId.Value); + + if (secret == null) + { + secret = new Core.SecretsManager.Entities.Secret { Id = request.SecretId.Value, OrganizationId = organizationId }; + } + + if (secret.OrganizationId == organizationId) + { + result = await _eventRepository.GetManyBySecretAsync( + secret, dateRange.Item1, dateRange.Item2, + new PageOptions { ContinuationToken = request.ContinuationToken }); + } + else + { + return new JsonResult(new PagedListResponseModel([], "")); + } + } + else if (request.ProjectId.HasValue) + { + var project = await _projectRepository.GetByIdAsync(request.ProjectId.Value); + if (project != null && project.OrganizationId == organizationId) + { + result = await _eventRepository.GetManyByProjectAsync( + project, dateRange.Item1, dateRange.Item2, + new PageOptions { ContinuationToken = request.ContinuationToken }); + } + else + { + return new JsonResult(new PagedListResponseModel([], "")); + } + } + else + { + result = await _eventRepository.GetManyByOrganizationAsync( + organizationId, dateRange.Item1, dateRange.Item2, + new PageOptions { ContinuationToken = request.ContinuationToken }); + } + + var eventResponses = result.Data.Select(e => new EventResponseModel(e)); + var response = new PagedListResponseModel(eventResponses, result.ContinuationToken ?? ""); + + _logger.LogAggregateData(_featureService, organizationId, response, request); + + return new JsonResult(response); + } +} diff --git a/src/Api/AdminConsole/Public/Models/Request/EventFilterRequestModel.cs b/src/Api/Dirt/Public/Models/EventFilterRequestModel.cs similarity index 81% rename from src/Api/AdminConsole/Public/Models/Request/EventFilterRequestModel.cs rename to src/Api/Dirt/Public/Models/EventFilterRequestModel.cs index 2d96425d55..20984c2cb0 100644 --- a/src/Api/AdminConsole/Public/Models/Request/EventFilterRequestModel.cs +++ b/src/Api/Dirt/Public/Models/EventFilterRequestModel.cs @@ -3,7 +3,7 @@ using Bit.Core.Exceptions; -namespace Bit.Api.Models.Public.Request; +namespace Bit.Api.Dirt.Public.Models; public class EventFilterRequestModel { @@ -24,6 +24,14 @@ public class EventFilterRequestModel /// public Guid? ItemId { get; set; } /// + /// The unique identifier of the related secret that the event describes. + /// + public Guid? SecretId { get; set; } + /// + /// The unique identifier of the related project that the event describes. + /// + public Guid? ProjectId { get; set; } + /// /// A cursor for use in pagination. /// public string ContinuationToken { get; set; } diff --git a/src/Api/AdminConsole/Public/Models/Response/EventResponseModel.cs b/src/Api/Dirt/Public/Models/EventResponseModel.cs similarity index 98% rename from src/Api/AdminConsole/Public/Models/Response/EventResponseModel.cs rename to src/Api/Dirt/Public/Models/EventResponseModel.cs index 3e1de2747a..77c0b5a275 100644 --- a/src/Api/AdminConsole/Public/Models/Response/EventResponseModel.cs +++ b/src/Api/Dirt/Public/Models/EventResponseModel.cs @@ -1,8 +1,9 @@ using System.ComponentModel.DataAnnotations; +using Bit.Api.Models.Public.Response; using Bit.Core.Enums; using Bit.Core.Models.Data; -namespace Bit.Api.Models.Public.Response; +namespace Bit.Api.Dirt.Public.Models; /// /// An event log. diff --git a/src/Api/Jobs/JobsHostedService.cs b/src/Api/Jobs/JobsHostedService.cs index 0178f6d68b..a9626dc90e 100644 --- a/src/Api/Jobs/JobsHostedService.cs +++ b/src/Api/Jobs/JobsHostedService.cs @@ -59,13 +59,6 @@ public class JobsHostedService : BaseJobsHostedService .StartNow() .WithCronSchedule("0 0 * * * ?") .Build(); - var updatePhishingDomainsTrigger = TriggerBuilder.Create() - .WithIdentity("UpdatePhishingDomainsTrigger") - .StartNow() - .WithSimpleSchedule(x => x - .WithIntervalInHours(24) - .RepeatForever()) - .Build(); var updateOrgSubscriptionsTrigger = TriggerBuilder.Create() .WithIdentity("UpdateOrgSubscriptionsTrigger") .StartNow() @@ -81,7 +74,6 @@ public class JobsHostedService : BaseJobsHostedService new Tuple(typeof(ValidateUsersJob), everyTopOfTheSixthHourTrigger), new Tuple(typeof(ValidateOrganizationsJob), everyTwelfthHourAndThirtyMinutesTrigger), new Tuple(typeof(ValidateOrganizationDomainJob), validateOrganizationDomainTrigger), - new Tuple(typeof(UpdatePhishingDomainsJob), updatePhishingDomainsTrigger), new (typeof(OrganizationSubscriptionUpdateJob), updateOrgSubscriptionsTrigger), }; @@ -111,7 +103,6 @@ public class JobsHostedService : BaseJobsHostedService services.AddTransient(); services.AddTransient(); services.AddTransient(); - services.AddTransient(); services.AddTransient(); } diff --git a/src/Api/Jobs/UpdatePhishingDomainsJob.cs b/src/Api/Jobs/UpdatePhishingDomainsJob.cs deleted file mode 100644 index 355f2af69b..0000000000 --- a/src/Api/Jobs/UpdatePhishingDomainsJob.cs +++ /dev/null @@ -1,97 +0,0 @@ -using Bit.Core; -using Bit.Core.Jobs; -using Bit.Core.PhishingDomainFeatures.Interfaces; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Bit.Core.Settings; -using Quartz; - -namespace Bit.Api.Jobs; - -public class UpdatePhishingDomainsJob : BaseJob -{ - private readonly GlobalSettings _globalSettings; - private readonly IPhishingDomainRepository _phishingDomainRepository; - private readonly ICloudPhishingDomainQuery _cloudPhishingDomainQuery; - private readonly IFeatureService _featureService; - public UpdatePhishingDomainsJob( - GlobalSettings globalSettings, - IPhishingDomainRepository phishingDomainRepository, - ICloudPhishingDomainQuery cloudPhishingDomainQuery, - IFeatureService featureService, - ILogger logger) - : base(logger) - { - _globalSettings = globalSettings; - _phishingDomainRepository = phishingDomainRepository; - _cloudPhishingDomainQuery = cloudPhishingDomainQuery; - _featureService = featureService; - } - - protected override async Task ExecuteJobAsync(IJobExecutionContext context) - { - if (!_featureService.IsEnabled(FeatureFlagKeys.PhishingDetection)) - { - _logger.LogInformation(Constants.BypassFiltersEventId, "Skipping phishing domain update. Feature flag is disabled."); - return; - } - - if (string.IsNullOrWhiteSpace(_globalSettings.PhishingDomain?.UpdateUrl)) - { - _logger.LogInformation(Constants.BypassFiltersEventId, "Skipping phishing domain update. No URL configured."); - return; - } - - if (_globalSettings.SelfHosted && !_globalSettings.EnableCloudCommunication) - { - _logger.LogInformation(Constants.BypassFiltersEventId, "Skipping phishing domain update. Cloud communication is disabled in global settings."); - return; - } - - var remoteChecksum = await _cloudPhishingDomainQuery.GetRemoteChecksumAsync(); - if (string.IsNullOrWhiteSpace(remoteChecksum)) - { - _logger.LogWarning(Constants.BypassFiltersEventId, "Could not retrieve remote checksum. Skipping update."); - return; - } - - var currentChecksum = await _phishingDomainRepository.GetCurrentChecksumAsync(); - - if (string.Equals(currentChecksum, remoteChecksum, StringComparison.OrdinalIgnoreCase)) - { - _logger.LogInformation(Constants.BypassFiltersEventId, - "Phishing domains list is up to date (checksum: {Checksum}). Skipping update.", - currentChecksum); - return; - } - - _logger.LogInformation(Constants.BypassFiltersEventId, - "Checksums differ (current: {CurrentChecksum}, remote: {RemoteChecksum}). Fetching updated domains from {Source}.", - currentChecksum, remoteChecksum, _globalSettings.SelfHosted ? "Bitwarden cloud API" : "external source"); - - try - { - var domains = await _cloudPhishingDomainQuery.GetPhishingDomainsAsync(); - if (!domains.Contains("phishing.testcategory.com", StringComparer.OrdinalIgnoreCase)) - { - domains.Add("phishing.testcategory.com"); - } - - if (domains.Count > 0) - { - _logger.LogInformation(Constants.BypassFiltersEventId, "Updating {Count} phishing domains with checksum {Checksum}.", - domains.Count, remoteChecksum); - await _phishingDomainRepository.UpdatePhishingDomainsAsync(domains, remoteChecksum); - _logger.LogInformation(Constants.BypassFiltersEventId, "Successfully updated phishing domains."); - } - else - { - _logger.LogWarning(Constants.BypassFiltersEventId, "No valid domains found in the response. Skipping update."); - } - } - catch (Exception ex) - { - _logger.LogError(Constants.BypassFiltersEventId, ex, "Error updating phishing domains."); - } - } -} diff --git a/src/Api/KeyManagement/Controllers/AccountsKeyManagementController.cs b/src/Api/KeyManagement/Controllers/AccountsKeyManagementController.cs index 9fc0e9a75a..a124616e30 100644 --- a/src/Api/KeyManagement/Controllers/AccountsKeyManagementController.cs +++ b/src/Api/KeyManagement/Controllers/AccountsKeyManagementController.cs @@ -1,8 +1,8 @@ -#nullable enable -using Bit.Api.AdminConsole.Models.Request.Organizations; +using Bit.Api.AdminConsole.Models.Request.Organizations; using Bit.Api.Auth.Models.Request; using Bit.Api.Auth.Models.Request.WebAuthn; using Bit.Api.KeyManagement.Models.Requests; +using Bit.Api.KeyManagement.Models.Responses; using Bit.Api.KeyManagement.Validators; using Bit.Api.Tools.Models.Request; using Bit.Api.Vault.Models.Request; @@ -14,6 +14,7 @@ using Bit.Core.Entities; using Bit.Core.Exceptions; using Bit.Core.KeyManagement.Commands.Interfaces; using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Queries.Interfaces; using Bit.Core.KeyManagement.UserKey; using Bit.Core.Repositories; using Bit.Core.Services; @@ -45,11 +46,14 @@ public class AccountsKeyManagementController : Controller private readonly IRotationValidator, IEnumerable> _webauthnKeyValidator; private readonly IRotationValidator, IEnumerable> _deviceValidator; + private readonly IKeyConnectorConfirmationDetailsQuery _keyConnectorConfirmationDetailsQuery; + private readonly ISetKeyConnectorKeyCommand _setKeyConnectorKeyCommand; public AccountsKeyManagementController(IUserService userService, IFeatureService featureService, IOrganizationUserRepository organizationUserRepository, IEmergencyAccessRepository emergencyAccessRepository, + IKeyConnectorConfirmationDetailsQuery keyConnectorConfirmationDetailsQuery, IRegenerateUserAsymmetricKeysCommand regenerateUserAsymmetricKeysCommand, IRotateUserAccountKeysCommand rotateUserKeyCommandV2, IRotationValidator, IEnumerable> cipherValidator, @@ -59,8 +63,10 @@ public class AccountsKeyManagementController : Controller emergencyAccessValidator, IRotationValidator, IReadOnlyList> organizationUserValidator, - IRotationValidator, IEnumerable> webAuthnKeyValidator, - IRotationValidator, IEnumerable> deviceValidator) + IRotationValidator, IEnumerable> + webAuthnKeyValidator, + IRotationValidator, IEnumerable> deviceValidator, + ISetKeyConnectorKeyCommand setKeyConnectorKeyCommand) { _userService = userService; _featureService = featureService; @@ -75,12 +81,14 @@ public class AccountsKeyManagementController : Controller _organizationUserValidator = organizationUserValidator; _webauthnKeyValidator = webAuthnKeyValidator; _deviceValidator = deviceValidator; + _keyConnectorConfirmationDetailsQuery = keyConnectorConfirmationDetailsQuery; + _setKeyConnectorKeyCommand = setKeyConnectorKeyCommand; } [HttpPost("key-management/regenerate-keys")] public async Task RegenerateKeysAsync([FromBody] KeyRegenerationRequestModel request) { - if (!_featureService.IsEnabled(FeatureFlagKeys.PrivateKeyRegeneration)) + if (!_featureService.IsEnabled(FeatureFlagKeys.PrivateKeyRegeneration) && !_featureService.IsEnabled(FeatureFlagKeys.DataRecoveryTool)) { throw new NotFoundException(); } @@ -106,8 +114,7 @@ public class AccountsKeyManagementController : Controller { OldMasterKeyAuthenticationHash = model.OldMasterKeyAuthenticationHash, - UserKeyEncryptedAccountPrivateKey = model.AccountKeys.UserKeyEncryptedAccountPrivateKey, - AccountPublicKey = model.AccountKeys.AccountPublicKey, + AccountKeys = model.AccountKeys.ToAccountKeysData(), MasterPasswordUnlockData = model.AccountUnlockData.MasterPasswordUnlockData.ToUnlockData(), EmergencyAccesses = await _emergencyAccessValidator.ValidateAsync(user, model.AccountUnlockData.EmergencyAccessUnlockData), @@ -143,18 +150,28 @@ public class AccountsKeyManagementController : Controller throw new UnauthorizedAccessException(); } - var result = await _userService.SetKeyConnectorKeyAsync(model.ToUser(user), model.Key, model.OrgIdentifier); - if (result.Succeeded) + if (model.IsV2Request()) { - return; + // V2 account registration + await _setKeyConnectorKeyCommand.SetKeyConnectorKeyForUserAsync(user, model.ToKeyConnectorKeysData()); } - - foreach (var error in result.Errors) + else { - ModelState.AddModelError(string.Empty, error.Description); - } + // V1 account registration + // TODO removed with https://bitwarden.atlassian.net/browse/PM-27328 + var result = await _userService.SetKeyConnectorKeyAsync(model.ToUser(user), model.Key, model.OrgIdentifier); + if (result.Succeeded) + { + return; + } - throw new BadRequestException(ModelState); + foreach (var error in result.Errors) + { + ModelState.AddModelError(string.Empty, error.Description); + } + + throw new BadRequestException(ModelState); + } } [HttpPost("convert-to-key-connector")] @@ -179,4 +196,17 @@ public class AccountsKeyManagementController : Controller throw new BadRequestException(ModelState); } + + [HttpGet("key-connector/confirmation-details/{orgSsoIdentifier}")] + public async Task GetKeyConnectorConfirmationDetailsAsync(string orgSsoIdentifier) + { + var user = await _userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var details = await _keyConnectorConfirmationDetailsQuery.Run(orgSsoIdentifier, user.Id); + return new KeyConnectorConfirmationDetailsResponseModel(details); + } } diff --git a/src/Api/KeyManagement/Controllers/UsersController.cs b/src/Api/KeyManagement/Controllers/UsersController.cs new file mode 100644 index 0000000000..cfd2f8ee29 --- /dev/null +++ b/src/Api/KeyManagement/Controllers/UsersController.cs @@ -0,0 +1,39 @@ +using Bit.Core.Exceptions; +using Bit.Core.KeyManagement.Models.Api.Response; +using Bit.Core.KeyManagement.Queries.Interfaces; +using Bit.Core.Repositories; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; +using UserKeyResponseModel = Bit.Api.Models.Response.UserKeyResponseModel; + + +namespace Bit.Api.KeyManagement.Controllers; + +[Route("users")] +[Authorize("Application")] +public class UsersController : Controller +{ + private readonly IUserRepository _userRepository; + private readonly IUserAccountKeysQuery _userAccountKeysQuery; + + public UsersController(IUserRepository userRepository, IUserAccountKeysQuery userAccountKeysQuery) + { + _userRepository = userRepository; + _userAccountKeysQuery = userAccountKeysQuery; + } + + [HttpGet("{id}/public-key")] + public async Task GetPublicKeyAsync([FromRoute] Guid id) + { + var key = await _userRepository.GetPublicKeyAsync(id) ?? throw new NotFoundException(); + return new UserKeyResponseModel(id, key); + } + + [HttpGet("{id}/keys")] + public async Task GetAccountKeysAsync([FromRoute] Guid id) + { + var user = await _userRepository.GetByIdAsync(id) ?? throw new NotFoundException(); + var accountKeys = await _userAccountKeysQuery.Run(user) ?? throw new NotFoundException("User account keys not found."); + return new PublicKeysResponseModel(accountKeys); + } +} diff --git a/src/Api/KeyManagement/Models/Requests/AccountKeysRequestModel.cs b/src/Api/KeyManagement/Models/Requests/AccountKeysRequestModel.cs deleted file mode 100644 index 7c7de4d210..0000000000 --- a/src/Api/KeyManagement/Models/Requests/AccountKeysRequestModel.cs +++ /dev/null @@ -1,10 +0,0 @@ -#nullable enable -using Bit.Core.Utilities; - -namespace Bit.Api.KeyManagement.Models.Requests; - -public class AccountKeysRequestModel -{ - [EncryptedString] public required string UserKeyEncryptedAccountPrivateKey { get; set; } - public required string AccountPublicKey { get; set; } -} diff --git a/src/Api/KeyManagement/Models/Requests/KeyRegenerationRequestModel.cs b/src/Api/KeyManagement/Models/Requests/KeyRegenerationRequestModel.cs index 495d13cccd..767cfd3f9b 100644 --- a/src/Api/KeyManagement/Models/Requests/KeyRegenerationRequestModel.cs +++ b/src/Api/KeyManagement/Models/Requests/KeyRegenerationRequestModel.cs @@ -1,5 +1,4 @@ -#nullable enable -using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Models.Data; using Bit.Core.Utilities; namespace Bit.Api.KeyManagement.Models.Requests; diff --git a/src/Api/KeyManagement/Models/Requests/RotateAccountKeysAndDataRequestModel.cs b/src/Api/KeyManagement/Models/Requests/RotateAccountKeysAndDataRequestModel.cs index b0b19e2bd3..3510be9546 100644 --- a/src/Api/KeyManagement/Models/Requests/RotateAccountKeysAndDataRequestModel.cs +++ b/src/Api/KeyManagement/Models/Requests/RotateAccountKeysAndDataRequestModel.cs @@ -1,5 +1,5 @@ -#nullable enable -using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations; +using Bit.Core.KeyManagement.Models.Api.Request; namespace Bit.Api.KeyManagement.Models.Requests; diff --git a/src/Api/KeyManagement/Models/Requests/SetKeyConnectorKeyRequestModel.cs b/src/Api/KeyManagement/Models/Requests/SetKeyConnectorKeyRequestModel.cs index 9f52a97383..6cd13fdf83 100644 --- a/src/Api/KeyManagement/Models/Requests/SetKeyConnectorKeyRequestModel.cs +++ b/src/Api/KeyManagement/Models/Requests/SetKeyConnectorKeyRequestModel.cs @@ -1,36 +1,112 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations; using Bit.Core.Auth.Models.Api.Request.Accounts; using Bit.Core.Entities; using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.KeyManagement.Models.Api.Request; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.Utilities; namespace Bit.Api.KeyManagement.Models.Requests; -public class SetKeyConnectorKeyRequestModel +public class SetKeyConnectorKeyRequestModel : IValidatableObject { - [Required] - public string Key { get; set; } - [Required] - public KeysRequestModel Keys { get; set; } - [Required] - public KdfType Kdf { get; set; } - [Required] - public int KdfIterations { get; set; } - public int? KdfMemory { get; set; } - public int? KdfParallelism { get; set; } - [Required] - public string OrgIdentifier { get; set; } + // TODO will be removed with https://bitwarden.atlassian.net/browse/PM-27328 + [Obsolete("Use KeyConnectorKeyWrappedUserKey instead")] + public string? Key { get; set; } + [Obsolete("Use AccountKeys instead")] + public KeysRequestModel? Keys { get; set; } + [Obsolete("Not used anymore")] + public KdfType? Kdf { get; set; } + [Obsolete("Not used anymore")] + public int? KdfIterations { get; set; } + [Obsolete("Not used anymore")] + public int? KdfMemory { get; set; } + [Obsolete("Not used anymore")] + public int? KdfParallelism { get; set; } + + [EncryptedString] + public string? KeyConnectorKeyWrappedUserKey { get; set; } + public AccountKeysRequestModel? AccountKeys { get; set; } + + [Required] + public required string OrgIdentifier { get; init; } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (IsV2Request()) + { + // V2 registration + yield break; + } + + // V1 registration + // TODO removed with https://bitwarden.atlassian.net/browse/PM-27328 + if (string.IsNullOrEmpty(Key)) + { + yield return new ValidationResult("Key must be supplied."); + } + + if (Keys == null) + { + yield return new ValidationResult("Keys must be supplied."); + } + + if (Kdf == null) + { + yield return new ValidationResult("Kdf must be supplied."); + } + + if (KdfIterations == null) + { + yield return new ValidationResult("KdfIterations must be supplied."); + } + + if (Kdf == KdfType.Argon2id) + { + if (KdfMemory == null) + { + yield return new ValidationResult("KdfMemory must be supplied when Kdf is Argon2id."); + } + + if (KdfParallelism == null) + { + yield return new ValidationResult("KdfParallelism must be supplied when Kdf is Argon2id."); + } + } + } + + public bool IsV2Request() + { + return !string.IsNullOrEmpty(KeyConnectorKeyWrappedUserKey) && AccountKeys != null; + } + + // TODO removed with https://bitwarden.atlassian.net/browse/PM-27328 public User ToUser(User existingUser) { - existingUser.Kdf = Kdf; - existingUser.KdfIterations = KdfIterations; + existingUser.Kdf = Kdf!.Value; + existingUser.KdfIterations = KdfIterations!.Value; existingUser.KdfMemory = KdfMemory; existingUser.KdfParallelism = KdfParallelism; existingUser.Key = Key; - Keys.ToUser(existingUser); + Keys!.ToUser(existingUser); return existingUser; } + + public KeyConnectorKeysData ToKeyConnectorKeysData() + { + // TODO remove validation with https://bitwarden.atlassian.net/browse/PM-27328 + if (string.IsNullOrEmpty(KeyConnectorKeyWrappedUserKey) || AccountKeys == null) + { + throw new BadRequestException("KeyConnectorKeyWrappedUserKey and AccountKeys must be supplied."); + } + + return new KeyConnectorKeysData + { + KeyConnectorKeyWrappedUserKey = KeyConnectorKeyWrappedUserKey, + AccountKeys = AccountKeys, + OrgIdentifier = OrgIdentifier + }; + } } diff --git a/src/Api/KeyManagement/Models/Requests/UnlockDataRequestModel.cs b/src/Api/KeyManagement/Models/Requests/UnlockDataRequestModel.cs index 3af944110c..01e5dd7017 100644 --- a/src/Api/KeyManagement/Models/Requests/UnlockDataRequestModel.cs +++ b/src/Api/KeyManagement/Models/Requests/UnlockDataRequestModel.cs @@ -1,5 +1,4 @@ -#nullable enable -using Bit.Api.AdminConsole.Models.Request.Organizations; +using Bit.Api.AdminConsole.Models.Request.Organizations; using Bit.Api.Auth.Models.Request; using Bit.Api.Auth.Models.Request.Accounts; using Bit.Api.Auth.Models.Request.WebAuthn; diff --git a/src/Api/KeyManagement/Models/Requests/UserDataRequestModel.cs b/src/Api/KeyManagement/Models/Requests/UserDataRequestModel.cs index f854d82bcc..df922fcda0 100644 --- a/src/Api/KeyManagement/Models/Requests/UserDataRequestModel.cs +++ b/src/Api/KeyManagement/Models/Requests/UserDataRequestModel.cs @@ -1,5 +1,4 @@ -#nullable enable -using Bit.Api.Tools.Models.Request; +using Bit.Api.Tools.Models.Request; using Bit.Api.Vault.Models.Request; namespace Bit.Api.KeyManagement.Models.Requests; diff --git a/src/Api/KeyManagement/Models/Responses/KeyConnectorConfirmationDetailsResponseModel.cs b/src/Api/KeyManagement/Models/Responses/KeyConnectorConfirmationDetailsResponseModel.cs new file mode 100644 index 0000000000..68d2c689df --- /dev/null +++ b/src/Api/KeyManagement/Models/Responses/KeyConnectorConfirmationDetailsResponseModel.cs @@ -0,0 +1,24 @@ +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.Models.Api; + +namespace Bit.Api.KeyManagement.Models.Responses; + +public class KeyConnectorConfirmationDetailsResponseModel : ResponseModel +{ + private const string _objectName = "keyConnectorConfirmationDetails"; + + public KeyConnectorConfirmationDetailsResponseModel(KeyConnectorConfirmationDetails details, + string obj = _objectName) : base(obj) + { + ArgumentNullException.ThrowIfNull(details); + + OrganizationName = details.OrganizationName; + } + + public KeyConnectorConfirmationDetailsResponseModel() : base(_objectName) + { + OrganizationName = string.Empty; + } + + public string OrganizationName { get; set; } +} diff --git a/src/Api/Models/Public/Response/CollectionResponseModel.cs b/src/Api/Models/Public/Response/CollectionResponseModel.cs index 04ae565a27..9e830aeea8 100644 --- a/src/Api/Models/Public/Response/CollectionResponseModel.cs +++ b/src/Api/Models/Public/Response/CollectionResponseModel.cs @@ -2,6 +2,7 @@ #nullable disable using System.ComponentModel.DataAnnotations; +using System.Text.Json.Serialization; using Bit.Api.AdminConsole.Public.Models.Response; using Bit.Core.Entities; using Bit.Core.Models.Data; @@ -13,6 +14,12 @@ namespace Bit.Api.Models.Public.Response; /// public class CollectionResponseModel : CollectionBaseModel, IResponseModel { + [JsonConstructor] + public CollectionResponseModel() + { + + } + public CollectionResponseModel(Collection collection, IEnumerable groups) { if (collection == null) diff --git a/src/Api/Models/Request/BitPayInvoiceRequestModel.cs b/src/Api/Models/Request/BitPayInvoiceRequestModel.cs deleted file mode 100644 index d27736d712..0000000000 --- a/src/Api/Models/Request/BitPayInvoiceRequestModel.cs +++ /dev/null @@ -1,73 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.ComponentModel.DataAnnotations; -using Bit.Core.Settings; - -namespace Bit.Api.Models.Request; - -public class BitPayInvoiceRequestModel : IValidatableObject -{ - public Guid? UserId { get; set; } - public Guid? OrganizationId { get; set; } - public Guid? ProviderId { get; set; } - public bool Credit { get; set; } - [Required] - public decimal? Amount { get; set; } - public string ReturnUrl { get; set; } - public string Name { get; set; } - public string Email { get; set; } - - public BitPayLight.Models.Invoice.Invoice ToBitpayInvoice(GlobalSettings globalSettings) - { - var inv = new BitPayLight.Models.Invoice.Invoice - { - Price = Convert.ToDouble(Amount.Value), - Currency = "USD", - RedirectUrl = ReturnUrl, - Buyer = new BitPayLight.Models.Invoice.Buyer - { - Email = Email, - Name = Name - }, - NotificationUrl = globalSettings.BitPay.NotificationUrl, - FullNotifications = true, - ExtendedNotifications = true - }; - - var posData = string.Empty; - if (UserId.HasValue) - { - posData = "userId:" + UserId.Value; - } - else if (OrganizationId.HasValue) - { - posData = "organizationId:" + OrganizationId.Value; - } - else if (ProviderId.HasValue) - { - posData = "providerId:" + ProviderId.Value; - } - - if (Credit) - { - posData += ",accountCredit:1"; - inv.ItemDesc = "Bitwarden Account Credit"; - } - else - { - inv.ItemDesc = "Bitwarden"; - } - - inv.PosData = posData; - return inv; - } - - public IEnumerable Validate(ValidationContext validationContext) - { - if (!UserId.HasValue && !OrganizationId.HasValue && !ProviderId.HasValue) - { - yield return new ValidationResult("User, Organization or Provider is required."); - } - } -} diff --git a/src/Api/Models/Request/Organizations/OrganizationVerifyBankRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationVerifyBankRequestModel.cs deleted file mode 100644 index 71f6873800..0000000000 --- a/src/Api/Models/Request/Organizations/OrganizationVerifyBankRequestModel.cs +++ /dev/null @@ -1,13 +0,0 @@ -using System.ComponentModel.DataAnnotations; - -namespace Bit.Api.Models.Request.Organizations; - -public class OrganizationVerifyBankRequestModel -{ - [Required] - [Range(1, 99)] - public int? Amount1 { get; set; } - [Required] - [Range(1, 99)] - public int? Amount2 { get; set; } -} diff --git a/src/Api/Models/Response/ConfigResponseModel.cs b/src/Api/Models/Response/ConfigResponseModel.cs index 20bc3f9e10..d748254206 100644 --- a/src/Api/Models/Response/ConfigResponseModel.cs +++ b/src/Api/Models/Response/ConfigResponseModel.cs @@ -1,6 +1,7 @@ // FIXME: Update this file to be null safe and then delete the line below #nullable disable +using Bit.Core; using Bit.Core.Enums; using Bit.Core.Models.Api; using Bit.Core.Services; @@ -45,7 +46,8 @@ public class ConfigResponseModel : ResponseModel Sso = globalSettings.BaseServiceUri.Sso }; FeatureStates = featureService.GetAll(); - Push = PushSettings.Build(globalSettings); + var webPushEnabled = FeatureStates.TryGetValue(FeatureFlagKeys.WebPush, out var webPushEnabledValue) ? (bool)webPushEnabledValue : false; + Push = PushSettings.Build(webPushEnabled, globalSettings); Settings = new ServerSettingsResponseModel { DisableUserRegistration = globalSettings.DisableUserRegistration @@ -74,9 +76,9 @@ public class PushSettings public PushTechnologyType PushTechnology { get; private init; } public string VapidPublicKey { get; private init; } - public static PushSettings Build(IGlobalSettings globalSettings) + public static PushSettings Build(bool webPushEnabled, IGlobalSettings globalSettings) { - var vapidPublicKey = globalSettings.WebPush.VapidPublicKey; + var vapidPublicKey = webPushEnabled ? globalSettings.WebPush.VapidPublicKey : null; var pushTechnology = vapidPublicKey != null ? PushTechnologyType.WebPush : PushTechnologyType.SignalR; return new() { diff --git a/src/Api/Models/Response/KeysResponseModel.cs b/src/Api/Models/Response/KeysResponseModel.cs index cfc1a6a0a1..4c877e0bfc 100644 --- a/src/Api/Models/Response/KeysResponseModel.cs +++ b/src/Api/Models/Response/KeysResponseModel.cs @@ -1,27 +1,32 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.Entities; +using Bit.Core.KeyManagement.Models.Api.Response; +using Bit.Core.KeyManagement.Models.Data; using Bit.Core.Models.Api; namespace Bit.Api.Models.Response; public class KeysResponseModel : ResponseModel { - public KeysResponseModel(User user) + public KeysResponseModel(UserAccountKeysData accountKeys, string? masterKeyWrappedUserKey) : base("keys") { - if (user == null) + if (masterKeyWrappedUserKey != null) { - throw new ArgumentNullException(nameof(user)); + Key = masterKeyWrappedUserKey; } - Key = user.Key; - PublicKey = user.PublicKey; - PrivateKey = user.PrivateKey; + PublicKey = accountKeys.PublicKeyEncryptionKeyPairData.PublicKey; + PrivateKey = accountKeys.PublicKeyEncryptionKeyPairData.WrappedPrivateKey; + AccountKeys = new PrivateKeysResponseModel(accountKeys); } - public string Key { get; set; } + /// + /// The master key wrapped user key. The master key can either be a master-password master key or a + /// key-connector master key. + /// + public string? Key { get; set; } + [Obsolete("Use AccountKeys.PublicKeyEncryptionKeyPair.PublicKey instead")] public string PublicKey { get; set; } + [Obsolete("Use AccountKeys.PublicKeyEncryptionKeyPair.WrappedPrivateKey instead")] public string PrivateKey { get; set; } + public PrivateKeysResponseModel AccountKeys { get; set; } } diff --git a/src/Api/Models/Response/ProfileResponseModel.cs b/src/Api/Models/Response/ProfileResponseModel.cs index cbdfaf0f16..30ba05b6a6 100644 --- a/src/Api/Models/Response/ProfileResponseModel.cs +++ b/src/Api/Models/Response/ProfileResponseModel.cs @@ -5,6 +5,8 @@ using Bit.Api.AdminConsole.Models.Response; using Bit.Api.AdminConsole.Models.Response.Providers; using Bit.Core.AdminConsole.Models.Data.Provider; using Bit.Core.Entities; +using Bit.Core.KeyManagement.Models.Api.Response; +using Bit.Core.KeyManagement.Models.Data; using Bit.Core.Models.Api; using Bit.Core.Models.Data.Organizations.OrganizationUsers; @@ -13,6 +15,7 @@ namespace Bit.Api.Models.Response; public class ProfileResponseModel : ResponseModel { public ProfileResponseModel(User user, + UserAccountKeysData userAccountKeysData, IEnumerable organizationsUserDetails, IEnumerable providerUserDetails, IEnumerable providerUserOrganizationDetails, @@ -35,6 +38,7 @@ public class ProfileResponseModel : ResponseModel TwoFactorEnabled = twoFactorEnabled; Key = user.Key; PrivateKey = user.PrivateKey; + AccountKeys = userAccountKeysData != null ? new PrivateKeysResponseModel(userAccountKeysData) : null; SecurityStamp = user.SecurityStamp; ForcePasswordReset = user.ForcePasswordReset; UsesKeyConnector = user.UsesKeyConnector; @@ -60,7 +64,9 @@ public class ProfileResponseModel : ResponseModel public string Culture { get; set; } public bool TwoFactorEnabled { get; set; } public string Key { get; set; } + [Obsolete("Use AccountKeys instead.")] public string PrivateKey { get; set; } + public PrivateKeysResponseModel AccountKeys { get; set; } public string SecurityStamp { get; set; } public bool ForcePasswordReset { get; set; } public bool UsesKeyConnector { get; set; } diff --git a/src/Api/Models/Response/SubscriptionResponseModel.cs b/src/Api/Models/Response/SubscriptionResponseModel.cs index 7038bee2a7..32d12aa416 100644 --- a/src/Api/Models/Response/SubscriptionResponseModel.cs +++ b/src/Api/Models/Response/SubscriptionResponseModel.cs @@ -1,6 +1,7 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - +using System.Security.Claims; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Licenses; +using Bit.Core.Billing.Licenses.Extensions; using Bit.Core.Billing.Models.Business; using Bit.Core.Entities; using Bit.Core.Models.Api; @@ -11,7 +12,17 @@ namespace Bit.Api.Models.Response; public class SubscriptionResponseModel : ResponseModel { - public SubscriptionResponseModel(User user, SubscriptionInfo subscription, UserLicense license) + + /// The user entity containing storage and premium subscription information + /// Subscription information retrieved from the payment provider (Stripe/Braintree) + /// The user's license containing expiration and feature entitlements + /// + /// Whether to include discount information in the response. + /// Set to true when the PM23341_Milestone_2 feature flag is enabled AND + /// you want to expose Milestone 2 discount information to the client. + /// The discount will only be included if it matches the specific Milestone 2 coupon ID. + /// + public SubscriptionResponseModel(User user, SubscriptionInfo subscription, UserLicense license, bool includeMilestone2Discount = false) : base("subscription") { Subscription = subscription.Subscription != null ? new BillingSubscription(subscription.Subscription) : null; @@ -22,9 +33,54 @@ public class SubscriptionResponseModel : ResponseModel MaxStorageGb = user.MaxStorageGb; License = license; Expiration = License.Expires; + + // Only display the Milestone 2 subscription discount on the subscription page. + CustomerDiscount = ShouldIncludeMilestone2Discount(includeMilestone2Discount, subscription.CustomerDiscount) + ? new BillingCustomerDiscount(subscription.CustomerDiscount!) + : null; } - public SubscriptionResponseModel(User user, UserLicense license = null) + /// The user entity containing storage and premium subscription information + /// Subscription information retrieved from the payment provider (Stripe/Braintree) + /// The user's license containing expiration and feature entitlements + /// The claims principal containing cryptographically secure token claims + /// + /// Whether to include discount information in the response. + /// Set to true when the PM23341_Milestone_2 feature flag is enabled AND + /// you want to expose Milestone 2 discount information to the client. + /// The discount will only be included if it matches the specific Milestone 2 coupon ID. + /// + public SubscriptionResponseModel(User user, SubscriptionInfo? subscription, UserLicense license, ClaimsPrincipal? claimsPrincipal, bool includeMilestone2Discount = false) + : base("subscription") + { + Subscription = subscription?.Subscription != null ? new BillingSubscription(subscription.Subscription) : null; + UpcomingInvoice = subscription?.UpcomingInvoice != null ? + new BillingSubscriptionUpcomingInvoice(subscription.UpcomingInvoice) : null; + StorageName = user.Storage.HasValue ? CoreHelpers.ReadableBytesSize(user.Storage.Value) : null; + StorageGb = user.Storage.HasValue ? Math.Round(user.Storage.Value / 1073741824D, 2) : 0; // 1 GB + MaxStorageGb = user.MaxStorageGb; + License = license; + + // CRITICAL: When a license has a Token (JWT), ALWAYS use the expiration from the token claim + // The token's expiration is cryptographically secured and cannot be tampered with + // The file's Expires property can be manually edited and should NOT be trusted for display + if (claimsPrincipal != null) + { + Expiration = claimsPrincipal.GetValue(UserLicenseConstants.Expires); + } + else + { + // No token - use the license file expiration (for older licenses without tokens) + Expiration = License.Expires; + } + + // Only display the Milestone 2 subscription discount on the subscription page. + CustomerDiscount = ShouldIncludeMilestone2Discount(includeMilestone2Discount, subscription?.CustomerDiscount) + ? new BillingCustomerDiscount(subscription!.CustomerDiscount!) + : null; + } + + public SubscriptionResponseModel(User user, UserLicense? license = null) : base("subscription") { StorageName = user.Storage.HasValue ? CoreHelpers.ReadableBytesSize(user.Storage.Value) : null; @@ -38,21 +94,109 @@ public class SubscriptionResponseModel : ResponseModel } } - public string StorageName { get; set; } + public string? StorageName { get; set; } public double? StorageGb { get; set; } public short? MaxStorageGb { get; set; } - public BillingSubscriptionUpcomingInvoice UpcomingInvoice { get; set; } - public BillingSubscription Subscription { get; set; } - public UserLicense License { get; set; } + public BillingSubscriptionUpcomingInvoice? UpcomingInvoice { get; set; } + public BillingSubscription? Subscription { get; set; } + /// + /// Customer discount information from Stripe for the Milestone 2 subscription discount. + /// Only includes the specific Milestone 2 coupon (cm3nHfO1) when it's a perpetual discount (no expiration). + /// This is for display purposes only and does not affect Stripe's automatic discount application. + /// Other discounts may still apply in Stripe billing but are not included in this response. + /// + /// Null when: + /// - The PM23341_Milestone_2 feature flag is disabled + /// - There is no active discount + /// - The discount coupon ID doesn't match the Milestone 2 coupon (cm3nHfO1) + /// - The instance is self-hosted + /// + /// + public BillingCustomerDiscount? CustomerDiscount { get; set; } + public UserLicense? License { get; set; } public DateTime? Expiration { get; set; } + + /// + /// Determines whether the Milestone 2 discount should be included in the response. + /// + /// Whether the feature flag is enabled and discount should be considered. + /// The customer discount from subscription info, if any. + /// True if the discount should be included; false otherwise. + private static bool ShouldIncludeMilestone2Discount( + bool includeMilestone2Discount, + SubscriptionInfo.BillingCustomerDiscount? customerDiscount) + { + return includeMilestone2Discount && + customerDiscount != null && + customerDiscount.Id == StripeConstants.CouponIDs.Milestone2SubscriptionDiscount && + customerDiscount.Active; + } } -public class BillingCustomerDiscount(SubscriptionInfo.BillingCustomerDiscount discount) +/// +/// Customer discount information from Stripe billing. +/// +public class BillingCustomerDiscount { - public string Id { get; } = discount.Id; - public bool Active { get; } = discount.Active; - public decimal? PercentOff { get; } = discount.PercentOff; - public List AppliesTo { get; } = discount.AppliesTo; + /// + /// The Stripe coupon ID (e.g., "cm3nHfO1"). + /// + public string? Id { get; } + + /// + /// Whether the discount is a recurring/perpetual discount with no expiration date. + /// + /// This property is true only when the discount has no end date, meaning it applies + /// indefinitely to all future renewals. This is a product decision for Milestone 2 + /// to only display perpetual discounts in the UI. + /// + /// + /// Note: This does NOT indicate whether the discount is "currently active" in the billing sense. + /// A discount with a future end date is functionally active and will be applied by Stripe, + /// but this property will be false because it has an expiration date. + /// + /// + public bool Active { get; } + + /// + /// Percentage discount applied to the subscription (e.g., 20.0 for 20% off). + /// Null if this is an amount-based discount. + /// + public decimal? PercentOff { get; } + + /// + /// Fixed amount discount in USD (e.g., 14.00 for $14 off). + /// Converted from Stripe's cent-based values (1400 cents → $14.00). + /// Null if this is a percentage-based discount. + /// Note: Stripe stores amounts in the smallest currency unit. This value is always in USD. + /// + public decimal? AmountOff { get; } + + /// + /// List of Stripe product IDs that this discount applies to (e.g., ["prod_premium", "prod_families"]). + /// + /// Null: discount applies to all products with no restrictions (AppliesTo not specified in Stripe). + /// Empty list: discount restricted to zero products (edge case - AppliesTo.Products = [] in Stripe). + /// Non-empty list: discount applies only to the specified product IDs. + /// + /// + public IReadOnlyList? AppliesTo { get; } + + /// + /// Creates a BillingCustomerDiscount from a SubscriptionInfo.BillingCustomerDiscount. + /// + /// The discount to convert. Must not be null. + /// Thrown when discount is null. + public BillingCustomerDiscount(SubscriptionInfo.BillingCustomerDiscount discount) + { + ArgumentNullException.ThrowIfNull(discount); + + Id = discount.Id; + Active = discount.Active; + PercentOff = discount.PercentOff; + AmountOff = discount.AmountOff; + AppliesTo = discount.AppliesTo; + } } public class BillingSubscription @@ -83,10 +227,10 @@ public class BillingSubscription public DateTime? PeriodEndDate { get; set; } public DateTime? CancelledDate { get; set; } public bool CancelAtEndDate { get; set; } - public string Status { get; set; } + public string? Status { get; set; } public bool Cancelled { get; set; } public IEnumerable Items { get; set; } = new List(); - public string CollectionMethod { get; set; } + public string? CollectionMethod { get; set; } public DateTime? SuspensionDate { get; set; } public DateTime? UnpaidPeriodEndDate { get; set; } public int? GracePeriod { get; set; } @@ -104,11 +248,11 @@ public class BillingSubscription AddonSubscriptionItem = item.AddonSubscriptionItem; } - public string ProductId { get; set; } - public string Name { get; set; } + public string? ProductId { get; set; } + public string? Name { get; set; } public decimal Amount { get; set; } public int Quantity { get; set; } - public string Interval { get; set; } + public string? Interval { get; set; } public bool SponsoredSubscriptionItem { get; set; } public bool AddonSubscriptionItem { get; set; } } diff --git a/src/Api/Program.cs b/src/Api/Program.cs index 6023f51c6d..bf924af47f 100644 --- a/src/Api/Program.cs +++ b/src/Api/Program.cs @@ -1,9 +1,4 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using AspNetCoreRateLimit; -using Bit.Core.Utilities; -using Microsoft.IdentityModel.Tokens; +using Bit.Core.Utilities; namespace Bit.Api; @@ -17,32 +12,8 @@ public class Program .ConfigureWebHostDefaults(webBuilder => { webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, (e, globalSettings) => - { - var context = e.Properties["SourceContext"].ToString(); - if (e.Exception != null && - (e.Exception.GetType() == typeof(SecurityTokenValidationException) || - e.Exception.Message == "Bad security stamp.")) - { - return false; - } - - if ( - context.Contains(typeof(IpRateLimitMiddleware).FullName)) - { - return e.Level >= globalSettings.MinLogLevel.ApiSettings.IpRateLimit; - } - - if (context.Contains("Duende.IdentityServer.Validation.TokenValidator") || - context.Contains("Duende.IdentityServer.Validation.TokenRequestValidator")) - { - return e.Level >= globalSettings.MinLogLevel.ApiSettings.IdentityToken; - } - - return e.Level >= globalSettings.MinLogLevel.ApiSettings.Default; - })); }) + .AddSerilogFileLogging() .Build() .Run(); } diff --git a/src/Api/Public/Controllers/CollectionsController.cs b/src/Api/Public/Controllers/CollectionsController.cs index 8615113906..a567062a5e 100644 --- a/src/Api/Public/Controllers/CollectionsController.cs +++ b/src/Api/Public/Controllers/CollectionsController.cs @@ -65,10 +65,11 @@ public class CollectionsController : Controller [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] public async Task List() { - var collections = await _collectionRepository.GetManySharedCollectionsByOrganizationIdAsync( - _currentContext.OrganizationId.Value); - // TODO: Get all CollectionGroup associations for the organization and marry them up here for the response. - var collectionResponses = collections.Select(c => new CollectionResponseModel(c, null)); + var collections = await _collectionRepository.GetManyByOrganizationIdWithAccessAsync(_currentContext.OrganizationId.Value); + + var collectionResponses = collections.Select(c => + new CollectionResponseModel(c.Item1, c.Item2.Groups)); + var response = new ListResponseModel(collectionResponses); return new JsonResult(response); } diff --git a/src/Api/SecretsManager/Controllers/SecretVersionsController.cs b/src/Api/SecretsManager/Controllers/SecretVersionsController.cs new file mode 100644 index 0000000000..86e2d1f7e9 --- /dev/null +++ b/src/Api/SecretsManager/Controllers/SecretVersionsController.cs @@ -0,0 +1,337 @@ +using Bit.Api.Models.Response; +using Bit.Api.SecretsManager.Models.Request; +using Bit.Api.SecretsManager.Models.Response; +using Bit.Core.Auth.Identity; +using Bit.Core.Context; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.SecretsManager.Repositories; +using Bit.Core.Services; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; + +namespace Bit.Api.SecretsManager.Controllers; + +[Authorize("secrets")] +public class SecretVersionsController : Controller +{ + private readonly ICurrentContext _currentContext; + private readonly ISecretVersionRepository _secretVersionRepository; + private readonly ISecretRepository _secretRepository; + private readonly IUserService _userService; + private readonly IOrganizationUserRepository _organizationUserRepository; + + public SecretVersionsController( + ICurrentContext currentContext, + ISecretVersionRepository secretVersionRepository, + ISecretRepository secretRepository, + IUserService userService, + IOrganizationUserRepository organizationUserRepository) + { + _currentContext = currentContext; + _secretVersionRepository = secretVersionRepository; + _secretRepository = secretRepository; + _userService = userService; + _organizationUserRepository = organizationUserRepository; + } + + [HttpGet("secrets/{secretId}/versions")] + public async Task> GetVersionsBySecretIdAsync([FromRoute] Guid secretId) + { + var secret = await _secretRepository.GetByIdAsync(secretId); + if (secret == null || !_currentContext.AccessSecretsManager(secret.OrganizationId)) + { + throw new NotFoundException(); + } + + // For service accounts and organization API, skip user-level access checks + if (_currentContext.IdentityClientType == IdentityClientType.ServiceAccount || + _currentContext.IdentityClientType == IdentityClientType.Organization) + { + // Already verified Secrets Manager access above + var versionList = await _secretVersionRepository.GetManyBySecretIdAsync(secretId); + var responseList = versionList.Select(v => new SecretVersionResponseModel(v)); + return new ListResponseModel(responseList); + } + + var userId = _userService.GetProperUserId(User); + if (!userId.HasValue) + { + throw new NotFoundException(); + } + + var orgAdmin = await _currentContext.OrganizationAdmin(secret.OrganizationId); + var accessClient = AccessClientHelper.ToAccessClient(_currentContext.IdentityClientType, orgAdmin); + + var access = await _secretRepository.AccessToSecretAsync(secretId, userId.Value, accessClient); + if (!access.Read) + { + throw new NotFoundException(); + } + + var versions = await _secretVersionRepository.GetManyBySecretIdAsync(secretId); + var responses = versions.Select(v => new SecretVersionResponseModel(v)); + + return new ListResponseModel(responses); + } + + [HttpGet("secret-versions/{id}")] + public async Task GetByIdAsync([FromRoute] Guid id) + { + var secretVersion = await _secretVersionRepository.GetByIdAsync(id); + if (secretVersion == null) + { + throw new NotFoundException(); + } + + var secret = await _secretRepository.GetByIdAsync(secretVersion.SecretId); + if (secret == null || !_currentContext.AccessSecretsManager(secret.OrganizationId)) + { + throw new NotFoundException(); + } + + // For service accounts and organization API, skip user-level access checks + if (_currentContext.IdentityClientType == IdentityClientType.ServiceAccount || + _currentContext.IdentityClientType == IdentityClientType.Organization) + { + // Already verified Secrets Manager access above + return new SecretVersionResponseModel(secretVersion); + } + + var userId = _userService.GetProperUserId(User); + if (!userId.HasValue) + { + throw new NotFoundException(); + } + + var orgAdmin = await _currentContext.OrganizationAdmin(secret.OrganizationId); + var accessClient = AccessClientHelper.ToAccessClient(_currentContext.IdentityClientType, orgAdmin); + + var access = await _secretRepository.AccessToSecretAsync(secretVersion.SecretId, userId.Value, accessClient); + if (!access.Read) + { + throw new NotFoundException(); + } + + return new SecretVersionResponseModel(secretVersion); + } + + [HttpPost("secret-versions/get-by-ids")] + public async Task> GetManyByIdsAsync([FromBody] List ids) + { + if (!ids.Any()) + { + throw new BadRequestException("No version IDs provided."); + } + + // Get all versions + var versions = (await _secretVersionRepository.GetManyByIdsAsync(ids)).ToList(); + if (!versions.Any()) + { + throw new NotFoundException(); + } + + // Get all associated secrets and check permissions + var secretIds = versions.Select(v => v.SecretId).Distinct().ToList(); + var secrets = (await _secretRepository.GetManyByIds(secretIds)).ToList(); + + if (!secrets.Any()) + { + throw new NotFoundException(); + } + + // Ensure all secrets belong to the same organization + var organizationId = secrets.First().OrganizationId; + if (secrets.Any(s => s.OrganizationId != organizationId) || + !_currentContext.AccessSecretsManager(organizationId)) + { + throw new NotFoundException(); + } + + // For service accounts and organization API, skip user-level access checks + if (_currentContext.IdentityClientType == IdentityClientType.ServiceAccount || + _currentContext.IdentityClientType == IdentityClientType.Organization) + { + // Already verified Secrets Manager access and organization ownership above + var serviceAccountResponses = versions.Select(v => new SecretVersionResponseModel(v)); + return new ListResponseModel(serviceAccountResponses); + } + + var userId = _userService.GetProperUserId(User); + if (!userId.HasValue) + { + throw new NotFoundException(); + } + + var isAdmin = await _currentContext.OrganizationAdmin(organizationId); + var accessClient = AccessClientHelper.ToAccessClient(_currentContext.IdentityClientType, isAdmin); + + // Verify read access to all associated secrets + var accessResults = await _secretRepository.AccessToSecretsAsync(secretIds, userId.Value, accessClient); + if (accessResults.Values.Any(access => !access.Read)) + { + throw new NotFoundException(); + } + + var responses = versions.Select(v => new SecretVersionResponseModel(v)); + return new ListResponseModel(responses); + } + + [HttpPut("secrets/{secretId}/versions/restore")] + public async Task RestoreVersionAsync([FromRoute] Guid secretId, [FromBody] RestoreSecretVersionRequestModel request) + { + if (!(_currentContext.IdentityClientType == IdentityClientType.User || _currentContext.IdentityClientType == IdentityClientType.ServiceAccount)) + { + throw new NotFoundException(); + } + + var secret = await _secretRepository.GetByIdAsync(secretId); + if (secret == null || !_currentContext.AccessSecretsManager(secret.OrganizationId)) + { + throw new NotFoundException(); + } + + // Get the version first to validate it belongs to this secret + var version = await _secretVersionRepository.GetByIdAsync(request.VersionId); + if (version == null || version.SecretId != secretId) + { + throw new NotFoundException(); + } + + // Store the current value before restoration + var currentValue = secret.Value; + + // For service accounts and organization API, skip user-level access checks + if (_currentContext.IdentityClientType == IdentityClientType.ServiceAccount) + { + // Save current value as a version before restoring + if (currentValue != version.Value) + { + var editorUserId = _userService.GetProperUserId(User); + if (editorUserId.HasValue) + { + var currentVersionSnapshot = new Core.SecretsManager.Entities.SecretVersion + { + SecretId = secretId, + Value = currentValue!, + VersionDate = DateTime.UtcNow, + EditorServiceAccountId = editorUserId.Value + }; + + await _secretVersionRepository.CreateAsync(currentVersionSnapshot); + } + } + + // Already verified Secrets Manager access above + secret.Value = version.Value; + secret.RevisionDate = DateTime.UtcNow; + var updatedSec = await _secretRepository.UpdateAsync(secret); + return new SecretResponseModel(updatedSec, true, true); + } + + var userId = _userService.GetProperUserId(User); + if (!userId.HasValue) + { + throw new NotFoundException(); + } + + var orgAdmin = await _currentContext.OrganizationAdmin(secret.OrganizationId); + var accessClient = AccessClientHelper.ToAccessClient(_currentContext.IdentityClientType, orgAdmin); + + var access = await _secretRepository.AccessToSecretAsync(secretId, userId.Value, accessClient); + if (!access.Write) + { + throw new NotFoundException(); + } + + // Save current value as a version before restoring + if (currentValue != version.Value) + { + var orgUser = await _organizationUserRepository.GetByOrganizationAsync(secret.OrganizationId, userId.Value); + if (orgUser == null) + { + throw new NotFoundException(); + } + + var currentVersionSnapshot = new Core.SecretsManager.Entities.SecretVersion + { + SecretId = secretId, + Value = currentValue!, + VersionDate = DateTime.UtcNow, + EditorOrganizationUserId = orgUser.Id + }; + + await _secretVersionRepository.CreateAsync(currentVersionSnapshot); + } + + // Update the secret with the version's value + secret.Value = version.Value; + secret.RevisionDate = DateTime.UtcNow; + + var updatedSecret = await _secretRepository.UpdateAsync(secret); + + return new SecretResponseModel(updatedSecret, true, true); + } + + [HttpPost("secret-versions/delete")] + public async Task BulkDeleteAsync([FromBody] List ids) + { + if (!ids.Any()) + { + throw new BadRequestException("No version IDs provided."); + } + + var secretVersions = (await _secretVersionRepository.GetManyByIdsAsync(ids)).ToList(); + if (secretVersions.Count != ids.Count) + { + throw new NotFoundException(); + } + + // Ensure all versions belong to secrets in the same organization + var secretIds = secretVersions.Select(v => v.SecretId).Distinct().ToList(); + var secrets = await _secretRepository.GetManyByIds(secretIds); + var secretsList = secrets.ToList(); + + if (!secretsList.Any()) + { + throw new NotFoundException(); + } + + var organizationId = secretsList.First().OrganizationId; + if (secretsList.Any(s => s.OrganizationId != organizationId) || + !_currentContext.AccessSecretsManager(organizationId)) + { + throw new NotFoundException(); + } + + // For service accounts and organization API, skip user-level access checks + if (_currentContext.IdentityClientType == IdentityClientType.ServiceAccount || + _currentContext.IdentityClientType == IdentityClientType.Organization) + { + // Already verified Secrets Manager access and organization ownership above + await _secretVersionRepository.DeleteManyByIdAsync(ids); + return Ok(); + } + + var userId = _userService.GetProperUserId(User); + if (!userId.HasValue) + { + throw new NotFoundException(); + } + + var orgAdmin = await _currentContext.OrganizationAdmin(organizationId); + var accessClient = AccessClientHelper.ToAccessClient(_currentContext.IdentityClientType, orgAdmin); + + // Verify write access to all associated secrets + var accessResults = await _secretRepository.AccessToSecretsAsync(secretIds, userId.Value, accessClient); + if (accessResults.Values.Any(access => !access.Write)) + { + throw new NotFoundException(); + } + + await _secretVersionRepository.DeleteManyByIdAsync(ids); + + return Ok(); + } +} diff --git a/src/Api/SecretsManager/Controllers/SecretsController.cs b/src/Api/SecretsManager/Controllers/SecretsController.cs index e263b9747d..dcfe1be111 100644 --- a/src/Api/SecretsManager/Controllers/SecretsController.cs +++ b/src/Api/SecretsManager/Controllers/SecretsController.cs @@ -8,6 +8,7 @@ using Bit.Core.Auth.Identity; using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Exceptions; +using Bit.Core.Repositories; using Bit.Core.SecretsManager.AuthorizationRequirements; using Bit.Core.SecretsManager.Commands.Secrets.Interfaces; using Bit.Core.SecretsManager.Entities; @@ -29,6 +30,7 @@ public class SecretsController : Controller private readonly ICurrentContext _currentContext; private readonly IProjectRepository _projectRepository; private readonly ISecretRepository _secretRepository; + private readonly ISecretVersionRepository _secretVersionRepository; private readonly ICreateSecretCommand _createSecretCommand; private readonly IUpdateSecretCommand _updateSecretCommand; private readonly IDeleteSecretCommand _deleteSecretCommand; @@ -38,11 +40,13 @@ public class SecretsController : Controller private readonly IUserService _userService; private readonly IEventService _eventService; private readonly IAuthorizationService _authorizationService; + private readonly IOrganizationUserRepository _organizationUserRepository; public SecretsController( ICurrentContext currentContext, IProjectRepository projectRepository, ISecretRepository secretRepository, + ISecretVersionRepository secretVersionRepository, ICreateSecretCommand createSecretCommand, IUpdateSecretCommand updateSecretCommand, IDeleteSecretCommand deleteSecretCommand, @@ -51,11 +55,13 @@ public class SecretsController : Controller ISecretAccessPoliciesUpdatesQuery secretAccessPoliciesUpdatesQuery, IUserService userService, IEventService eventService, - IAuthorizationService authorizationService) + IAuthorizationService authorizationService, + IOrganizationUserRepository organizationUserRepository) { _currentContext = currentContext; _projectRepository = projectRepository; _secretRepository = secretRepository; + _secretVersionRepository = secretVersionRepository; _createSecretCommand = createSecretCommand; _updateSecretCommand = updateSecretCommand; _deleteSecretCommand = deleteSecretCommand; @@ -65,6 +71,7 @@ public class SecretsController : Controller _userService = userService; _eventService = eventService; _authorizationService = authorizationService; + _organizationUserRepository = organizationUserRepository; } @@ -190,6 +197,44 @@ public class SecretsController : Controller } } + // Create a version record if the value changed + if (updateRequest.ValueChanged) + { + // Store the old value before updating + var oldValue = secret.Value; + var userId = _userService.GetProperUserId(User)!.Value; + Guid? editorServiceAccountId = null; + Guid? editorOrganizationUserId = null; + + if (_currentContext.IdentityClientType == IdentityClientType.ServiceAccount) + { + editorServiceAccountId = userId; + } + else if (_currentContext.IdentityClientType == IdentityClientType.User) + { + var orgUser = await _organizationUserRepository.GetByOrganizationAsync(secret.OrganizationId, userId); + if (orgUser != null) + { + editorOrganizationUserId = orgUser.Id; + } + else + { + throw new NotFoundException(); + } + } + + var secretVersion = new SecretVersion + { + SecretId = id, + Value = oldValue, + VersionDate = DateTime.UtcNow, + EditorServiceAccountId = editorServiceAccountId, + EditorOrganizationUserId = editorOrganizationUserId + }; + + await _secretVersionRepository.CreateAsync(secretVersion); + } + var result = await _updateSecretCommand.UpdateAsync(updatedSecret, accessPoliciesUpdates); await LogSecretEventAsync(secret, EventType.Secret_Edited); diff --git a/src/Api/SecretsManager/Controllers/SecretsManagerEventsController.cs b/src/Api/SecretsManager/Controllers/SecretsManagerEventsController.cs index af162fe399..0f467a4c78 100644 --- a/src/Api/SecretsManager/Controllers/SecretsManagerEventsController.cs +++ b/src/Api/SecretsManager/Controllers/SecretsManagerEventsController.cs @@ -1,6 +1,7 @@ // FIXME: Update this file to be null safe and then delete the line below #nullable disable +using Bit.Api.Dirt.Models.Response; using Bit.Api.Models.Response; using Bit.Api.Utilities; using Bit.Core.Exceptions; diff --git a/src/Api/SecretsManager/Models/Request/RestoreSecretVersionRequestModel.cs b/src/Api/SecretsManager/Models/Request/RestoreSecretVersionRequestModel.cs new file mode 100644 index 0000000000..19a6b35a75 --- /dev/null +++ b/src/Api/SecretsManager/Models/Request/RestoreSecretVersionRequestModel.cs @@ -0,0 +1,9 @@ +using System.ComponentModel.DataAnnotations; + +namespace Bit.Api.SecretsManager.Models.Request; + +public class RestoreSecretVersionRequestModel +{ + [Required] + public Guid VersionId { get; set; } +} diff --git a/src/Api/SecretsManager/Models/Request/SecretUpdateRequestModel.cs b/src/Api/SecretsManager/Models/Request/SecretUpdateRequestModel.cs index b95bc9e500..9d19e1d8cc 100644 --- a/src/Api/SecretsManager/Models/Request/SecretUpdateRequestModel.cs +++ b/src/Api/SecretsManager/Models/Request/SecretUpdateRequestModel.cs @@ -28,6 +28,8 @@ public class SecretUpdateRequestModel : IValidatableObject public SecretAccessPoliciesRequestsModel AccessPoliciesRequests { get; set; } + public bool ValueChanged { get; set; } = false; + public Secret ToSecret(Secret secret) { secret.Key = Key; diff --git a/src/Api/SecretsManager/Models/Response/SecretVersionResponseModel.cs b/src/Api/SecretsManager/Models/Response/SecretVersionResponseModel.cs new file mode 100644 index 0000000000..07b8e88f7e --- /dev/null +++ b/src/Api/SecretsManager/Models/Response/SecretVersionResponseModel.cs @@ -0,0 +1,28 @@ +using Bit.Core.Models.Api; +using Bit.Core.SecretsManager.Entities; + +namespace Bit.Api.SecretsManager.Models.Response; + +public class SecretVersionResponseModel : ResponseModel +{ + private const string _objectName = "secretVersion"; + + public Guid Id { get; set; } + public Guid SecretId { get; set; } + public string Value { get; set; } = string.Empty; + public DateTime VersionDate { get; set; } + public Guid? EditorServiceAccountId { get; set; } + public Guid? EditorOrganizationUserId { get; set; } + + public SecretVersionResponseModel() : base(_objectName) { } + + public SecretVersionResponseModel(SecretVersion secretVersion) : base(_objectName) + { + Id = secretVersion.Id; + SecretId = secretVersion.SecretId; + Value = secretVersion.Value; + VersionDate = secretVersion.VersionDate; + EditorServiceAccountId = secretVersion.EditorServiceAccountId; + EditorOrganizationUserId = secretVersion.EditorOrganizationUserId; + } +} diff --git a/src/Api/Startup.cs b/src/Api/Startup.cs index cc50a1b362..2f16470cd4 100644 --- a/src/Api/Startup.cs +++ b/src/Api/Startup.cs @@ -94,9 +94,6 @@ public class Startup services.AddMemoryCache(); services.AddDistributedCache(globalSettings); - // BitPay - services.AddSingleton(); - if (!globalSettings.SelfHosted) { services.AddIpRateLimiting(globalSettings); @@ -190,7 +187,6 @@ public class Startup services.AddBillingOperations(); services.AddReportingServices(); services.AddImportServices(); - services.AddPhishingDomainServices(globalSettings); services.AddSendServices(); @@ -219,7 +215,7 @@ public class Startup config.Conventions.Add(new PublicApiControllersModelConvention()); }); - services.AddSwagger(globalSettings, Environment); + services.AddSwaggerGen(globalSettings, Environment); Jobs.JobsHostedService.AddJobsServices(services, globalSettings.SelfHosted); services.AddHostedService(); @@ -229,19 +225,19 @@ public class Startup services.AddHostedService(); } - // Add SlackService for OAuth API requests - if configured + // Add Event Integrations services + services.AddEventIntegrationsCommandsQueries(globalSettings); services.AddSlackService(globalSettings); + services.AddTeamsService(globalSettings); } public void Configure( IApplicationBuilder app, IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, GlobalSettings globalSettings, ILogger logger) { IdentityModelEventSource.ShowPII = true; - app.UseSerilog(env, appLifetime, globalSettings); // Add general security headers app.UseMiddleware(); @@ -296,17 +292,59 @@ public class Startup }); // Add Swagger + // Note that the swagger.json generation is configured in the call to AddSwaggerGen above. if (Environment.IsDevelopment() || globalSettings.SelfHosted) { + // adds the middleware to serve the swagger.json while the server is running app.UseSwagger(config => { config.RouteTemplate = "specs/{documentName}/swagger.json"; + + // Remove all Bitwarden cloud servers and only register the local server config.PreSerializeFilters.Add((swaggerDoc, httpReq) => - swaggerDoc.Servers = new List + { + swaggerDoc.Servers.Clear(); + swaggerDoc.Servers.Add(new OpenApiServer { - new OpenApiServer { Url = globalSettings.BaseServiceUri.Api } + Url = globalSettings.BaseServiceUri.Api, }); + + swaggerDoc.Components.SecuritySchemes.Clear(); + swaggerDoc.Components.SecuritySchemes.Add("oauth2-client-credentials", new OpenApiSecurityScheme + { + Type = SecuritySchemeType.OAuth2, + Flows = new OpenApiOAuthFlows + { + ClientCredentials = new OpenApiOAuthFlow + { + TokenUrl = new Uri($"{globalSettings.BaseServiceUri.Identity}/connect/token"), + Scopes = new Dictionary + { + { ApiScopes.ApiOrganization, "Organization APIs" } + } + } + } + }); + + swaggerDoc.SecurityRequirements.Clear(); + swaggerDoc.SecurityRequirements.Add(new OpenApiSecurityRequirement + { + { + new OpenApiSecurityScheme + { + Reference = new OpenApiReference + { + Type = ReferenceType.SecurityScheme, + Id = "oauth2-client-credentials" + } + }, + [ApiScopes.ApiOrganization] + } + }); + }); }); + + // adds the middleware to display the web UI app.UseSwaggerUI(config => { config.DocumentTitle = "Bitwarden API Documentation"; @@ -325,6 +363,6 @@ public class Startup } // Log startup - logger.LogInformation(Constants.BypassFiltersEventId, globalSettings.ProjectName + " started."); + logger.LogInformation(Constants.BypassFiltersEventId, "{Project} started.", globalSettings.ProjectName); } } diff --git a/src/Api/Tools/Controllers/ImportCiphersController.cs b/src/Api/Tools/Controllers/ImportCiphersController.cs index 88028420b7..8b3ec5e26c 100644 --- a/src/Api/Tools/Controllers/ImportCiphersController.cs +++ b/src/Api/Tools/Controllers/ImportCiphersController.cs @@ -74,10 +74,14 @@ public class ImportCiphersController : Controller throw new BadRequestException("You cannot import this much data at once."); } + if (model.Ciphers.Any(c => c.ArchivedDate.HasValue)) + { + throw new BadRequestException("You cannot import archived items into an organization."); + } + var orgId = new Guid(organizationId); var collections = model.Collections.Select(c => c.ToCollection(orgId)).ToList(); - //An User is allowed to import if CanCreate Collections or has AccessToImportExport var authorized = await CheckOrgImportPermission(collections, orgId); if (!authorized) @@ -156,7 +160,7 @@ public class ImportCiphersController : Controller if (existingCollections.Any() && (await _authorizationService.AuthorizeAsync(User, existingCollections, BulkCollectionOperations.ImportCiphers)).Succeeded) { return true; - }; + } return false; } diff --git a/src/Api/Tools/Controllers/SendsController.cs b/src/Api/Tools/Controllers/SendsController.cs index c02e9b0c20..c54a9b90c9 100644 --- a/src/Api/Tools/Controllers/SendsController.cs +++ b/src/Api/Tools/Controllers/SendsController.cs @@ -166,7 +166,7 @@ public class SendsController : Controller } catch (Exception e) { - _logger.LogError(e, $"Uncaught exception occurred while handling event grid event: {JsonSerializer.Serialize(eventGridEvent)}"); + _logger.LogError(e, "Uncaught exception occurred while handling event grid event: {Event}", JsonSerializer.Serialize(eventGridEvent)); return; } } diff --git a/src/Api/Utilities/DiagnosticTools/EventDiagnosticLogger.cs b/src/Api/Utilities/DiagnosticTools/EventDiagnosticLogger.cs new file mode 100644 index 0000000000..af34931181 --- /dev/null +++ b/src/Api/Utilities/DiagnosticTools/EventDiagnosticLogger.cs @@ -0,0 +1,87 @@ +using Bit.Api.Dirt.Public.Models; +using Bit.Api.Models.Public.Response; +using Bit.Core; +using Bit.Core.Services; + +namespace Bit.Api.Utilities.DiagnosticTools; + +public static class EventDiagnosticLogger +{ + public static void LogAggregateData( + this ILogger logger, + IFeatureService featureService, + Guid organizationId, + PagedListResponseModel data, EventFilterRequestModel request) + { + try + { + if (!featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging)) + { + return; + } + + var orderedRecords = data.Data.OrderBy(e => e.Date).ToList(); + var recordCount = orderedRecords.Count; + var newestRecordDate = orderedRecords.LastOrDefault()?.Date.ToString("o"); + var oldestRecordDate = orderedRecords.FirstOrDefault()?.Date.ToString("o"); ; + var hasMore = !string.IsNullOrEmpty(data.ContinuationToken); + + logger.LogInformation( + "Events query for Organization:{OrgId}. Event count:{Count} newest record:{newestRecord} oldest record:{oldestRecord} HasMore:{HasMore} " + + "Request Filters Start:{QueryStart} End:{QueryEnd} ActingUserId:{ActingUserId} ItemId:{ItemId},", + organizationId, + recordCount, + newestRecordDate, + oldestRecordDate, + hasMore, + request.Start?.ToString("o"), + request.End?.ToString("o"), + request.ActingUserId, + request.ItemId); + } + catch (Exception exception) + { + logger.LogWarning(exception, "Unexpected exception from EventDiagnosticLogger.LogAggregateData"); + } + } + + public static void LogAggregateData( + this ILogger logger, + IFeatureService featureService, + Guid organizationId, + IEnumerable data, + string? continuationToken, + DateTime? queryStart = null, + DateTime? queryEnd = null) + { + + try + { + if (!featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging)) + { + return; + } + + var orderedRecords = data.OrderBy(e => e.Date).ToList(); + var recordCount = orderedRecords.Count; + var newestRecordDate = orderedRecords.LastOrDefault()?.Date.ToString("o"); + var oldestRecordDate = orderedRecords.FirstOrDefault()?.Date.ToString("o"); ; + var hasMore = !string.IsNullOrEmpty(continuationToken); + + logger.LogInformation( + "Events query for Organization:{OrgId}. Event count:{Count} newest record:{newestRecord} oldest record:{oldestRecord} HasMore:{HasMore} " + + "Request Filters Start:{QueryStart} End:{QueryEnd}", + organizationId, + recordCount, + newestRecordDate, + oldestRecordDate, + hasMore, + queryStart?.ToString("o"), + queryEnd?.ToString("o")); + } + catch (Exception exception) + { + logger.LogWarning(exception, "Unexpected exception from EventDiagnosticLogger.LogAggregateData"); + } + } +} diff --git a/src/Api/Utilities/ExceptionHandlerFilterAttribute.cs b/src/Api/Utilities/ExceptionHandlerFilterAttribute.cs index 91079d5040..1caa7cf841 100644 --- a/src/Api/Utilities/ExceptionHandlerFilterAttribute.cs +++ b/src/Api/Utilities/ExceptionHandlerFilterAttribute.cs @@ -152,7 +152,7 @@ public class ExceptionHandlerFilterAttribute : ExceptionFilterAttribute else { var logger = context.HttpContext.RequestServices.GetRequiredService>(); - logger.LogError(0, exception, exception.Message); + logger.LogError(0, exception, "Unhandled exception"); errorMessage = "An unhandled server error has occurred."; context.HttpContext.Response.StatusCode = 500; } diff --git a/src/Api/Utilities/ServiceCollectionExtensions.cs b/src/Api/Utilities/ServiceCollectionExtensions.cs index 6af688f548..b773abf6ef 100644 --- a/src/Api/Utilities/ServiceCollectionExtensions.cs +++ b/src/Api/Utilities/ServiceCollectionExtensions.cs @@ -1,15 +1,11 @@ using Bit.Api.AdminConsole.Authorization; using Bit.Api.Tools.Authorization; -using Bit.Core.Auth.IdentityServer; -using Bit.Core.PhishingDomainFeatures; -using Bit.Core.PhishingDomainFeatures.Interfaces; -using Bit.Core.Repositories; -using Bit.Core.Repositories.Implementations; using Bit.Core.Settings; using Bit.Core.Utilities; using Bit.Core.Vault.Authorization.SecurityTasks; using Bit.SharedWeb.Health; using Bit.SharedWeb.Swagger; +using Bit.SharedWeb.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.OpenApi.Models; @@ -17,7 +13,10 @@ namespace Bit.Api.Utilities; public static class ServiceCollectionExtensions { - public static void AddSwagger(this IServiceCollection services, GlobalSettings globalSettings, IWebHostEnvironment environment) + /// + /// Configures the generation of swagger.json OpenAPI spec. + /// + public static void AddSwaggerGen(this IServiceCollection services, GlobalSettings globalSettings, IWebHostEnvironment environment) { services.AddSwaggerGen(config => { @@ -36,6 +35,8 @@ public static class ServiceCollectionExtensions organizations tools for managing members, collections, groups, event logs, and policies. If you are looking for the Vault Management API, refer instead to [this document](https://bitwarden.com/help/vault-management-api/). + + **Note:** your authorization must match the server you have selected. """, License = new OpenApiLicense { @@ -46,36 +47,20 @@ public static class ServiceCollectionExtensions config.SwaggerDoc("internal", new OpenApiInfo { Title = "Bitwarden Internal API", Version = "latest" }); - config.AddSecurityDefinition("oauth2-client-credentials", new OpenApiSecurityScheme - { - Type = SecuritySchemeType.OAuth2, - Flows = new OpenApiOAuthFlows - { - ClientCredentials = new OpenApiOAuthFlow - { - TokenUrl = new Uri($"{globalSettings.BaseServiceUri.Identity}/connect/token"), - Scopes = new Dictionary - { - { ApiScopes.ApiOrganization, "Organization APIs" }, - }, - } - }, - }); + // Configure Bitwarden cloud US and EU servers. These will appear in the swagger.json build artifact + // used for our help center. These are overwritten with the local server when running in self-hosted + // or dev mode (see Api Startup.cs). + config.AddSwaggerServerWithSecurity( + serverId: "US_server", + serverUrl: "https://api.bitwarden.com", + identityTokenUrl: "https://identity.bitwarden.com/connect/token", + serverDescription: "US server"); - config.AddSecurityRequirement(new OpenApiSecurityRequirement - { - { - new OpenApiSecurityScheme - { - Reference = new OpenApiReference - { - Type = ReferenceType.SecurityScheme, - Id = "oauth2-client-credentials" - }, - }, - new[] { ApiScopes.ApiOrganization } - } - }); + config.AddSwaggerServerWithSecurity( + serverId: "EU_server", + serverUrl: "https://api.bitwarden.eu", + identityTokenUrl: "https://identity.bitwarden.eu/connect/token", + serverDescription: "EU server"); config.DescribeAllParametersInCamelCase(); // config.UseReferencedDefinitionsForEnums(); @@ -114,25 +99,4 @@ public static class ServiceCollectionExtensions // Admin Console authorization handlers services.AddAdminConsoleAuthorizationHandlers(); } - - public static void AddPhishingDomainServices(this IServiceCollection services, GlobalSettings globalSettings) - { - services.AddHttpClient("PhishingDomains", client => - { - client.DefaultRequestHeaders.Add("User-Agent", globalSettings.SelfHosted ? "Bitwarden Self-Hosted" : "Bitwarden"); - client.Timeout = TimeSpan.FromSeconds(1000); // the source list is very slow - }); - - services.AddSingleton(); - services.AddSingleton(); - - if (globalSettings.SelfHosted) - { - services.AddScoped(); - } - else - { - services.AddScoped(); - } - } } diff --git a/src/Api/Vault/Controllers/CiphersController.cs b/src/Api/Vault/Controllers/CiphersController.cs index 06c88ad9bb..6a506cc01f 100644 --- a/src/Api/Vault/Controllers/CiphersController.cs +++ b/src/Api/Vault/Controllers/CiphersController.cs @@ -1,6 +1,7 @@ // FIXME: Update this file to be null safe and then delete the line below #nullable disable +using System.Globalization; using System.Text.Json; using Azure.Messaging.EventGrid; using Bit.Api.Auth.Models.Request.Accounts; @@ -401,8 +402,9 @@ public class CiphersController : Controller { var org = _currentContext.GetOrganization(organizationId); - // If we're not an "admin" or if we're not a provider user we don't need to check the ciphers - if (org is not ({ Type: OrganizationUserType.Owner or OrganizationUserType.Admin } or { Permissions.EditAnyCollection: true }) || await _currentContext.ProviderUserForOrgAsync(organizationId)) + // If we're not an "admin" we don't need to check the ciphers + if (org is not ({ Type: OrganizationUserType.Owner or OrganizationUserType.Admin } or + { Permissions.EditAnyCollection: true })) { return false; } @@ -415,8 +417,9 @@ public class CiphersController : Controller { var org = _currentContext.GetOrganization(organizationId); - // If we're not an "admin" or if we're a provider user we don't need to check the ciphers - if (org is not ({ Type: OrganizationUserType.Owner or OrganizationUserType.Admin } or { Permissions.EditAnyCollection: true }) || await _currentContext.ProviderUserForOrgAsync(organizationId)) + // If we're not an "admin" we don't need to check the ciphers + if (org is not ({ Type: OrganizationUserType.Owner or OrganizationUserType.Admin } or + { Permissions.EditAnyCollection: true })) { return false; } @@ -757,7 +760,7 @@ public class CiphersController : Controller ValidateClientVersionForFido2CredentialSupport(cipher); var original = cipher.Clone(); - await _cipherService.ShareAsync(original, model.Cipher.ToCipher(cipher), new Guid(model.Cipher.OrganizationId), + await _cipherService.ShareAsync(original, model.Cipher.ToCipher(cipher, user.Id), new Guid(model.Cipher.OrganizationId), model.CollectionIds.Select(c => new Guid(c)), user.Id, model.Cipher.LastKnownRevisionDate); var sharedCipher = await GetByIdAsync(id, user.Id); @@ -1351,7 +1354,7 @@ public class CiphersController : Controller } var (attachmentId, uploadUrl) = await _cipherService.CreateAttachmentForDelayedUploadAsync(cipher, - request.Key, request.FileName, request.FileSize, request.AdminRequest, user.Id); + request.Key, request.FileName, request.FileSize, request.AdminRequest, user.Id, request.LastKnownRevisionDate); return new AttachmentUploadDataResponseModel { AttachmentId = attachmentId, @@ -1425,10 +1428,12 @@ public class CiphersController : Controller throw new NotFoundException(); } + // Extract lastKnownRevisionDate from form data if present + DateTime? lastKnownRevisionDate = GetLastKnownRevisionDateFromForm(); await Request.GetFileAsync(async (stream, fileName, key) => { await _cipherService.CreateAttachmentAsync(cipher, stream, fileName, key, - Request.ContentLength.GetValueOrDefault(0), user.Id); + Request.ContentLength.GetValueOrDefault(0), user.Id, false, lastKnownRevisionDate); }); return new CipherResponseModel( @@ -1454,10 +1459,13 @@ public class CiphersController : Controller throw new NotFoundException(); } + // Extract lastKnownRevisionDate from form data if present + DateTime? lastKnownRevisionDate = GetLastKnownRevisionDateFromForm(); + await Request.GetFileAsync(async (stream, fileName, key) => { await _cipherService.CreateAttachmentAsync(cipher, stream, fileName, key, - Request.ContentLength.GetValueOrDefault(0), userId, true); + Request.ContentLength.GetValueOrDefault(0), userId, true, lastKnownRevisionDate); }); return new CipherMiniResponseModel(cipher, _globalSettings, cipher.OrganizationUseTotp); @@ -1578,7 +1586,7 @@ public class CiphersController : Controller } catch (Exception e) { - _logger.LogError(e, $"Uncaught exception occurred while handling event grid event: {JsonSerializer.Serialize(eventGridEvent)}"); + _logger.LogError(e, "Uncaught exception occurred while handling event grid event: {Event}", JsonSerializer.Serialize(eventGridEvent)); return; } } @@ -1615,4 +1623,19 @@ public class CiphersController : Controller { return await _cipherRepository.GetByIdAsync(cipherId, userId); } + + private DateTime? GetLastKnownRevisionDateFromForm() + { + DateTime? lastKnownRevisionDate = null; + if (Request.Form.TryGetValue("lastKnownRevisionDate", out var dateValue)) + { + if (!DateTime.TryParse(dateValue, CultureInfo.InvariantCulture, DateTimeStyles.RoundtripKind, out var parsedDate)) + { + throw new BadRequestException("Invalid lastKnownRevisionDate format."); + } + lastKnownRevisionDate = parsedDate; + } + + return lastKnownRevisionDate; + } } diff --git a/src/Api/Vault/Controllers/SyncController.cs b/src/Api/Vault/Controllers/SyncController.cs index 54f1b9e70b..6ac8d06ba0 100644 --- a/src/Api/Vault/Controllers/SyncController.cs +++ b/src/Api/Vault/Controllers/SyncController.cs @@ -11,6 +11,8 @@ using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Queries.Interfaces; using Bit.Core.Models.Data; using Bit.Core.Repositories; using Bit.Core.Services; @@ -42,6 +44,7 @@ public class SyncController : Controller private readonly IFeatureService _featureService; private readonly IApplicationCacheService _applicationCacheService; private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; + private readonly IUserAccountKeysQuery _userAccountKeysQuery; public SyncController( IUserService userService, @@ -57,7 +60,8 @@ public class SyncController : Controller ICurrentContext currentContext, IFeatureService featureService, IApplicationCacheService applicationCacheService, - ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery) + ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, + IUserAccountKeysQuery userAccountKeysQuery) { _userService = userService; _folderRepository = folderRepository; @@ -73,6 +77,7 @@ public class SyncController : Controller _featureService = featureService; _applicationCacheService = applicationCacheService; _twoFactorIsEnabledQuery = twoFactorIsEnabledQuery; + _userAccountKeysQuery = userAccountKeysQuery; } [HttpGet("")] @@ -116,7 +121,14 @@ public class SyncController : Controller var organizationAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); - var response = new SyncResponseModel(_globalSettings, user, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationAbilities, + UserAccountKeysData userAccountKeys = null; + // JIT TDE users and some broken/old users may not have a private key. + if (!string.IsNullOrWhiteSpace(user.PrivateKey)) + { + userAccountKeys = await _userAccountKeysQuery.Run(user); + } + + var response = new SyncResponseModel(_globalSettings, user, userAccountKeys, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationAbilities, organizationIdsClaimingActiveUser, organizationUserDetails, providerUserDetails, providerUserOrganizationDetails, folders, collections, ciphers, collectionCiphersGroupDict, excludeDomains, policies, sends); return response; diff --git a/src/Api/Vault/Models/Request/AttachmentRequestModel.cs b/src/Api/Vault/Models/Request/AttachmentRequestModel.cs index 96c66c6044..eef70bf4e4 100644 --- a/src/Api/Vault/Models/Request/AttachmentRequestModel.cs +++ b/src/Api/Vault/Models/Request/AttachmentRequestModel.cs @@ -9,4 +9,9 @@ public class AttachmentRequestModel public string FileName { get; set; } public long FileSize { get; set; } public bool AdminRequest { get; set; } = false; + + /// + /// The last known revision date of the Cipher that this attachment belongs to. + /// + public DateTime? LastKnownRevisionDate { get; set; } } diff --git a/src/Api/Vault/Models/Request/CipherRequestModel.cs b/src/Api/Vault/Models/Request/CipherRequestModel.cs index b0589a62f9..18a1aec559 100644 --- a/src/Api/Vault/Models/Request/CipherRequestModel.cs +++ b/src/Api/Vault/Models/Request/CipherRequestModel.cs @@ -84,7 +84,7 @@ public class CipherRequestModel return existingCipher; } - public Cipher ToCipher(Cipher existingCipher) + public Cipher ToCipher(Cipher existingCipher, Guid? userId = null) { // If Data field is provided, use it directly if (!string.IsNullOrWhiteSpace(Data)) @@ -124,9 +124,12 @@ public class CipherRequestModel } } + var userIdKey = userId.HasValue ? userId.ToString().ToUpperInvariant() : null; existingCipher.Reprompt = Reprompt; existingCipher.Key = Key; existingCipher.ArchivedDate = ArchivedDate; + existingCipher.Folders = UpdateUserSpecificJsonField(existingCipher.Folders, userIdKey, FolderId); + existingCipher.Favorites = UpdateUserSpecificJsonField(existingCipher.Favorites, userIdKey, Favorite); var hasAttachments2 = (Attachments2?.Count ?? 0) > 0; var hasAttachments = (Attachments?.Count ?? 0) > 0; @@ -291,6 +294,37 @@ public class CipherRequestModel KeyFingerprint = SSHKey.KeyFingerprint, }; } + + /// + /// Updates a JSON string representing a dictionary by adding, updating, or removing a key-value pair + /// based on the provided userIdKey and newValue. + /// + private static string UpdateUserSpecificJsonField(string existingJson, string userIdKey, object newValue) + { + if (userIdKey == null) + { + return existingJson; + } + + var jsonDict = string.IsNullOrWhiteSpace(existingJson) + ? new Dictionary() + : JsonSerializer.Deserialize>(existingJson) ?? new Dictionary(); + + var shouldRemove = newValue == null || + (newValue is string strValue && string.IsNullOrWhiteSpace(strValue)) || + (newValue is bool boolValue && !boolValue); + + if (shouldRemove) + { + jsonDict.Remove(userIdKey); + } + else + { + jsonDict[userIdKey] = newValue is string str ? str.ToUpperInvariant() : newValue; + } + + return jsonDict.Count == 0 ? null : JsonSerializer.Serialize(jsonDict); + } } public class CipherWithIdRequestModel : CipherRequestModel diff --git a/src/Api/Vault/Models/Response/SyncResponseModel.cs b/src/Api/Vault/Models/Response/SyncResponseModel.cs index e19defce51..1981ac834e 100644 --- a/src/Api/Vault/Models/Response/SyncResponseModel.cs +++ b/src/Api/Vault/Models/Response/SyncResponseModel.cs @@ -7,7 +7,8 @@ using Bit.Api.Tools.Models.Response; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Models.Data.Provider; using Bit.Core.Entities; -using Bit.Core.KeyManagement.Models.Response; +using Bit.Core.KeyManagement.Models.Api.Response; +using Bit.Core.KeyManagement.Models.Data; using Bit.Core.Models.Api; using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations; @@ -24,6 +25,7 @@ public class SyncResponseModel() : ResponseModel("sync") public SyncResponseModel( GlobalSettings globalSettings, User user, + UserAccountKeysData userAccountKeysData, bool userTwoFactorEnabled, bool userHasPremiumFromOrganization, IDictionary organizationAbilities, @@ -40,7 +42,7 @@ public class SyncResponseModel() : ResponseModel("sync") IEnumerable sends) : this() { - Profile = new ProfileResponseModel(user, organizationUserDetails, providerUserDetails, + Profile = new ProfileResponseModel(user, userAccountKeysData, organizationUserDetails, providerUserDetails, providerUserOrganizationDetails, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationIdsClaimingingUser); Folders = folders.Select(f => new FolderResponseModel(f)); Ciphers = ciphers.Select(cipher => diff --git a/src/Api/appsettings.Development.json b/src/Api/appsettings.Development.json index 82fb951261..deb0a35d84 100644 --- a/src/Api/appsettings.Development.json +++ b/src/Api/appsettings.Development.json @@ -38,9 +38,6 @@ "storage": { "connectionString": "UseDevelopmentStorage=true" }, - "phishingDomain": { - "updateUrl": "https://phish.co.za/latest/phishing-domains-ACTIVE.txt", - "checksumUrl": "https://raw.githubusercontent.com/Phishing-Database/checksums/refs/heads/master/phishing-domains-ACTIVE.txt.sha256" - } + "pricingUri": "https://billingpricing.qa.bitwarden.pw" } } diff --git a/src/Api/appsettings.json b/src/Api/appsettings.json index f8a69dcfac..8850c3d269 100644 --- a/src/Api/appsettings.json +++ b/src/Api/appsettings.json @@ -32,9 +32,6 @@ "send": { "connectionString": "SECRET" }, - "sentry": { - "dsn": "SECRET" - }, "notificationHub": { "connectionString": "SECRET", "hubName": "SECRET" @@ -64,16 +61,14 @@ "bitPay": { "production": false, "token": "SECRET", - "notificationUrl": "https://bitwarden.com/SECRET" + "notificationUrl": "https://bitwarden.com/SECRET", + "webhookKey": "SECRET" }, "amazon": { "accessKeyId": "SECRET", "accessKeySecret": "SECRET", "region": "SECRET" }, - "phishingDomain": { - "updateUrl": "SECRET" - }, "distributedIpRateLimiting": { "enabled": true, "maxRedisTimeoutsThreshold": 10, diff --git a/src/Billing/Billing.csproj b/src/Billing/Billing.csproj index e2b7447eb7..69999dc795 100644 --- a/src/Billing/Billing.csproj +++ b/src/Billing/Billing.csproj @@ -1,9 +1,17 @@  + bitwarden-Billing + + + false + false + false + + @@ -11,7 +19,7 @@ - + diff --git a/src/Billing/BillingSettings.cs b/src/Billing/BillingSettings.cs index 32630e4a4a..2830f603ac 100644 --- a/src/Billing/BillingSettings.cs +++ b/src/Billing/BillingSettings.cs @@ -7,15 +7,9 @@ public class BillingSettings { public virtual string JobsKey { get; set; } public virtual string StripeWebhookKey { get; set; } - public virtual string StripeWebhookSecret { get; set; } - public virtual string StripeWebhookSecret20231016 { get; set; } - public virtual string StripeWebhookSecret20240620 { get; set; } - public virtual string BitPayWebhookKey { get; set; } + public virtual string StripeWebhookSecret20250827Basil { get; set; } public virtual string AppleWebhookKey { get; set; } - public virtual FreshDeskSettings FreshDesk { get; set; } = new FreshDeskSettings(); - public virtual string FreshsalesApiKey { get; set; } public virtual PayPalSettings PayPal { get; set; } = new PayPalSettings(); - public virtual OnyxSettings Onyx { get; set; } = new OnyxSettings(); public class PayPalSettings { @@ -24,26 +18,4 @@ public class BillingSettings public virtual string WebhookKey { get; set; } } - public class FreshDeskSettings - { - public virtual string ApiKey { get; set; } - public virtual string WebhookKey { get; set; } - /// - /// Indicates the data center region. Valid values are "US" and "EU" - /// - public virtual string Region { get; set; } - public virtual string UserFieldName { get; set; } - public virtual string OrgFieldName { get; set; } - - public virtual bool RemoveNewlinesInReplies { get; set; } = false; - public virtual string AutoReplyGreeting { get; set; } = string.Empty; - public virtual string AutoReplySalutation { get; set; } = string.Empty; - } - - public class OnyxSettings - { - public virtual string ApiKey { get; set; } - public virtual string BaseUrl { get; set; } - public virtual int PersonaId { get; set; } - } } diff --git a/src/Billing/Constants/BitPayInvoiceStatus.cs b/src/Billing/Constants/BitPayInvoiceStatus.cs deleted file mode 100644 index b9c1e5834d..0000000000 --- a/src/Billing/Constants/BitPayInvoiceStatus.cs +++ /dev/null @@ -1,7 +0,0 @@ -namespace Bit.Billing.Constants; - -public static class BitPayInvoiceStatus -{ - public const string Confirmed = "confirmed"; - public const string Complete = "complete"; -} diff --git a/src/Billing/Controllers/BitPayController.cs b/src/Billing/Controllers/BitPayController.cs index 111ffabc2b..f55b4523af 100644 --- a/src/Billing/Controllers/BitPayController.cs +++ b/src/Billing/Controllers/BitPayController.cs @@ -1,125 +1,79 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Globalization; -using Bit.Billing.Constants; +using System.Globalization; using Bit.Billing.Models; using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Payment.Clients; using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Repositories; using Bit.Core.Services; +using Bit.Core.Settings; using Bit.Core.Utilities; +using BitPayLight.Models.Invoice; using Microsoft.AspNetCore.Mvc; using Microsoft.Data.SqlClient; -using Microsoft.Extensions.Options; namespace Bit.Billing.Controllers; +using static BitPayConstants; +using static StripeConstants; + [Route("bitpay")] [ApiExplorerSettings(IgnoreApi = true)] -public class BitPayController : Controller +public class BitPayController( + GlobalSettings globalSettings, + IBitPayClient bitPayClient, + ITransactionRepository transactionRepository, + IOrganizationRepository organizationRepository, + IUserRepository userRepository, + IProviderRepository providerRepository, + IMailService mailService, + IStripePaymentService paymentService, + ILogger logger, + IPremiumUserBillingService premiumUserBillingService) + : Controller { - private readonly BillingSettings _billingSettings; - private readonly BitPayClient _bitPayClient; - private readonly ITransactionRepository _transactionRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly IUserRepository _userRepository; - private readonly IProviderRepository _providerRepository; - private readonly IMailService _mailService; - private readonly IPaymentService _paymentService; - private readonly ILogger _logger; - private readonly IPremiumUserBillingService _premiumUserBillingService; - - public BitPayController( - IOptions billingSettings, - BitPayClient bitPayClient, - ITransactionRepository transactionRepository, - IOrganizationRepository organizationRepository, - IUserRepository userRepository, - IProviderRepository providerRepository, - IMailService mailService, - IPaymentService paymentService, - ILogger logger, - IPremiumUserBillingService premiumUserBillingService) - { - _billingSettings = billingSettings?.Value; - _bitPayClient = bitPayClient; - _transactionRepository = transactionRepository; - _organizationRepository = organizationRepository; - _userRepository = userRepository; - _providerRepository = providerRepository; - _mailService = mailService; - _paymentService = paymentService; - _logger = logger; - _premiumUserBillingService = premiumUserBillingService; - } - [HttpPost("ipn")] public async Task PostIpn([FromBody] BitPayEventModel model, [FromQuery] string key) { - if (!CoreHelpers.FixedTimeEquals(key, _billingSettings.BitPayWebhookKey)) + if (!CoreHelpers.FixedTimeEquals(key, globalSettings.BitPay.WebhookKey)) { - return new BadRequestResult(); - } - if (model == null || string.IsNullOrWhiteSpace(model.Data?.Id) || - string.IsNullOrWhiteSpace(model.Event?.Name)) - { - return new BadRequestResult(); + return new BadRequestObjectResult("Invalid key"); } - if (model.Event.Name != BitPayNotificationCode.InvoiceConfirmed) - { - // Only processing confirmed invoice events for now. - return new OkResult(); - } - - var invoice = await _bitPayClient.GetInvoiceAsync(model.Data.Id); - if (invoice == null) - { - // Request forged...? - _logger.LogWarning("Invoice not found. #{InvoiceId}", model.Data.Id); - return new BadRequestResult(); - } - - if (invoice.Status != BitPayInvoiceStatus.Confirmed && invoice.Status != BitPayInvoiceStatus.Complete) - { - _logger.LogWarning("Invoice status of '{InvoiceStatus}' is not acceptable. #{InvoiceId}", invoice.Status, invoice.Id); - return new BadRequestResult(); - } + var invoice = await bitPayClient.GetInvoice(model.Data.Id); if (invoice.Currency != "USD") { - // Only process USD payments - _logger.LogWarning("Non USD payment received. #{InvoiceId}", invoice.Id); - return new OkResult(); + logger.LogWarning("Received BitPay invoice webhook for invoice ({InvoiceID}) with non-USD currency: {Currency}", invoice.Id, invoice.Currency); + return new BadRequestObjectResult("Cannot process non-USD payments"); } var (organizationId, userId, providerId) = GetIdsFromPosData(invoice); - if (!organizationId.HasValue && !userId.HasValue && !providerId.HasValue) + if ((!organizationId.HasValue && !userId.HasValue && !providerId.HasValue) || !invoice.PosData.Contains(PosDataKeys.AccountCredit)) { - return new OkResult(); + logger.LogWarning("Received BitPay invoice webhook for invoice ({InvoiceID}) that had invalid POS data: {PosData}", invoice.Id, invoice.PosData); + return new BadRequestObjectResult("Invalid POS data"); } - var isAccountCredit = IsAccountCredit(invoice); - if (!isAccountCredit) + if (invoice.Status != InvoiceStatuses.Complete) { - // Only processing credits - _logger.LogWarning("Non-credit payment received. #{InvoiceId}", invoice.Id); - return new OkResult(); + logger.LogInformation("Received valid BitPay invoice webhook for invoice ({InvoiceID}) that is not yet complete: {Status}", + invoice.Id, invoice.Status); + return new OkObjectResult("Waiting for invoice to be completed"); } - var transaction = await _transactionRepository.GetByGatewayIdAsync(GatewayType.BitPay, invoice.Id); - if (transaction != null) + var existingTransaction = await transactionRepository.GetByGatewayIdAsync(GatewayType.BitPay, invoice.Id); + if (existingTransaction != null) { - _logger.LogWarning("Already processed this invoice. #{InvoiceId}", invoice.Id); - return new OkResult(); + logger.LogWarning("Already processed BitPay invoice webhook for invoice ({InvoiceID})", invoice.Id); + return new OkObjectResult("Invoice already processed"); } try { - var tx = new Transaction + var transaction = new Transaction { Amount = Convert.ToDecimal(invoice.Price), CreationDate = GetTransactionDate(invoice), @@ -132,50 +86,47 @@ public class BitPayController : Controller PaymentMethodType = PaymentMethodType.BitPay, Details = $"{invoice.Currency}, BitPay {invoice.Id}" }; - await _transactionRepository.CreateAsync(tx); - string billingEmail = null; - if (tx.OrganizationId.HasValue) + await transactionRepository.CreateAsync(transaction); + + var billingEmail = ""; + if (transaction.OrganizationId.HasValue) { - var org = await _organizationRepository.GetByIdAsync(tx.OrganizationId.Value); - if (org != null) + var organization = await organizationRepository.GetByIdAsync(transaction.OrganizationId.Value); + if (organization != null) { - billingEmail = org.BillingEmailAddress(); - if (await _paymentService.CreditAccountAsync(org, tx.Amount)) + billingEmail = organization.BillingEmailAddress(); + if (await paymentService.CreditAccountAsync(organization, transaction.Amount)) { - await _organizationRepository.ReplaceAsync(org); + await organizationRepository.ReplaceAsync(organization); } } } - else if (tx.UserId.HasValue) + else if (transaction.UserId.HasValue) { - var user = await _userRepository.GetByIdAsync(tx.UserId.Value); + var user = await userRepository.GetByIdAsync(transaction.UserId.Value); if (user != null) { billingEmail = user.BillingEmailAddress(); - await _premiumUserBillingService.Credit(user, tx.Amount); + await premiumUserBillingService.Credit(user, transaction.Amount); } } - else if (tx.ProviderId.HasValue) + else if (transaction.ProviderId.HasValue) { - var provider = await _providerRepository.GetByIdAsync(tx.ProviderId.Value); + var provider = await providerRepository.GetByIdAsync(transaction.ProviderId.Value); if (provider != null) { billingEmail = provider.BillingEmailAddress(); - if (await _paymentService.CreditAccountAsync(provider, tx.Amount)) + if (await paymentService.CreditAccountAsync(provider, transaction.Amount)) { - await _providerRepository.ReplaceAsync(provider); + await providerRepository.ReplaceAsync(provider); } } } - else - { - _logger.LogError("Received BitPay account credit transaction that didn't have a user, org, or provider. Invoice#{InvoiceId}", invoice.Id); - } if (!string.IsNullOrWhiteSpace(billingEmail)) { - await _mailService.SendAddedCreditAsync(billingEmail, tx.Amount); + await mailService.SendAddedCreditAsync(billingEmail, transaction.Amount); } } // Catch foreign key violations because user/org could have been deleted. @@ -186,58 +137,34 @@ public class BitPayController : Controller return new OkResult(); } - private bool IsAccountCredit(BitPayLight.Models.Invoice.Invoice invoice) + private static DateTime GetTransactionDate(Invoice invoice) { - return invoice != null && invoice.PosData != null && invoice.PosData.Contains("accountCredit:1"); + var transactions = invoice.Transactions?.Where(transaction => + transaction.Type == null && !string.IsNullOrWhiteSpace(transaction.Confirmations) && + transaction.Confirmations != "0").ToList(); + + return transactions?.Count == 1 + ? DateTime.Parse(transactions.First().ReceivedTime, CultureInfo.InvariantCulture, DateTimeStyles.RoundtripKind) + : CoreHelpers.FromEpocMilliseconds(invoice.CurrentTime); } - private DateTime GetTransactionDate(BitPayLight.Models.Invoice.Invoice invoice) + public (Guid? OrganizationId, Guid? UserId, Guid? ProviderId) GetIdsFromPosData(Invoice invoice) { - var transactions = invoice.Transactions?.Where(t => t.Type == null && - !string.IsNullOrWhiteSpace(t.Confirmations) && t.Confirmations != "0"); - if (transactions != null && transactions.Count() == 1) + if (invoice.PosData is null or { Length: 0 } || !invoice.PosData.Contains(':')) { - return DateTime.Parse(transactions.First().ReceivedTime, CultureInfo.InvariantCulture, - DateTimeStyles.RoundtripKind); - } - return CoreHelpers.FromEpocMilliseconds(invoice.CurrentTime); - } - - public Tuple GetIdsFromPosData(BitPayLight.Models.Invoice.Invoice invoice) - { - Guid? orgId = null; - Guid? userId = null; - Guid? providerId = null; - - if (invoice == null || string.IsNullOrWhiteSpace(invoice.PosData) || !invoice.PosData.Contains(':')) - { - return new Tuple(null, null, null); + return new ValueTuple(null, null, null); } - var mainParts = invoice.PosData.Split(','); - foreach (var mainPart in mainParts) - { - var parts = mainPart.Split(':'); + var ids = invoice.PosData + .Split(',') + .Select(part => part.Split(':')) + .Where(parts => parts.Length == 2 && Guid.TryParse(parts[1], out _)) + .ToDictionary(parts => parts[0], parts => Guid.Parse(parts[1])); - if (parts.Length <= 1 || !Guid.TryParse(parts[1], out var id)) - { - continue; - } - - switch (parts[0]) - { - case "userId": - userId = id; - break; - case "organizationId": - orgId = id; - break; - case "providerId": - providerId = id; - break; - } - } - - return new Tuple(orgId, userId, providerId); + return new ValueTuple( + ids.TryGetValue(MetadataKeys.OrganizationId, out var id) ? id : null, + ids.TryGetValue(MetadataKeys.UserId, out id) ? id : null, + ids.TryGetValue(MetadataKeys.ProviderId, out id) ? id : null + ); } } diff --git a/src/Billing/Controllers/FreshdeskController.cs b/src/Billing/Controllers/FreshdeskController.cs deleted file mode 100644 index 66d4f47d92..0000000000 --- a/src/Billing/Controllers/FreshdeskController.cs +++ /dev/null @@ -1,383 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.ComponentModel.DataAnnotations; -using System.Net.Http.Headers; -using System.Reflection; -using System.Text; -using System.Text.Json; -using System.Web; -using Bit.Billing.Models; -using Bit.Core.Repositories; -using Bit.Core.Settings; -using Bit.Core.Utilities; -using Markdig; -using Microsoft.AspNetCore.Mvc; -using Microsoft.Extensions.Options; - -namespace Bit.Billing.Controllers; - -[Route("freshdesk")] -public class FreshdeskController : Controller -{ - private readonly BillingSettings _billingSettings; - private readonly IUserRepository _userRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly ILogger _logger; - private readonly GlobalSettings _globalSettings; - private readonly IHttpClientFactory _httpClientFactory; - - public FreshdeskController( - IUserRepository userRepository, - IOrganizationRepository organizationRepository, - IOptions billingSettings, - ILogger logger, - GlobalSettings globalSettings, - IHttpClientFactory httpClientFactory) - { - _billingSettings = billingSettings?.Value; - _userRepository = userRepository; - _organizationRepository = organizationRepository; - _logger = logger; - _globalSettings = globalSettings; - _httpClientFactory = httpClientFactory; - } - - [HttpPost("webhook")] - public async Task PostWebhook([FromQuery, Required] string key, - [FromBody, Required] FreshdeskWebhookModel model) - { - if (string.IsNullOrWhiteSpace(key) || !CoreHelpers.FixedTimeEquals(key, _billingSettings.FreshDesk.WebhookKey)) - { - return new BadRequestResult(); - } - - try - { - var ticketId = model.TicketId; - var ticketContactEmail = model.TicketContactEmail; - var ticketTags = model.TicketTags; - if (string.IsNullOrWhiteSpace(ticketId) || string.IsNullOrWhiteSpace(ticketContactEmail)) - { - return new BadRequestResult(); - } - - var updateBody = new Dictionary(); - var note = string.Empty; - note += $"
  • Region: {_billingSettings.FreshDesk.Region}
  • "; - var customFields = new Dictionary(); - var user = await _userRepository.GetByEmailAsync(ticketContactEmail); - if (user == null) - { - note += $"
  • No user found: {ticketContactEmail}
  • "; - await CreateNote(ticketId, note); - } - - if (user != null) - { - var userLink = $"{_globalSettings.BaseServiceUri.Admin}/users/edit/{user.Id}"; - note += $"
  • User, {user.Email}: {userLink}
  • "; - customFields.Add(_billingSettings.FreshDesk.UserFieldName, userLink); - var tags = new HashSet(); - if (user.Premium) - { - tags.Add("Premium"); - } - var orgs = await _organizationRepository.GetManyByUserIdAsync(user.Id); - - foreach (var org in orgs) - { - // Prevent org names from injecting any additional HTML - var orgName = HttpUtility.HtmlEncode(org.Name); - var orgNote = $"{orgName} ({org.Seats.GetValueOrDefault()}): " + - $"{_globalSettings.BaseServiceUri.Admin}/organizations/edit/{org.Id}"; - note += $"
  • Org, {orgNote}
  • "; - if (!customFields.Any(kvp => kvp.Key == _billingSettings.FreshDesk.OrgFieldName)) - { - customFields.Add(_billingSettings.FreshDesk.OrgFieldName, orgNote); - } - else - { - customFields[_billingSettings.FreshDesk.OrgFieldName] += $"\n{orgNote}"; - } - - var planName = GetAttribute(org.PlanType).Name.Split(" ").FirstOrDefault(); - if (!string.IsNullOrWhiteSpace(planName)) - { - tags.Add(string.Format("Org: {0}", planName)); - } - } - if (tags.Any()) - { - var tagsToUpdate = tags.ToList(); - if (!string.IsNullOrWhiteSpace(ticketTags)) - { - var splitTicketTags = ticketTags.Split(','); - for (var i = 0; i < splitTicketTags.Length; i++) - { - tagsToUpdate.Insert(i, splitTicketTags[i]); - } - } - updateBody.Add("tags", tagsToUpdate); - } - - if (customFields.Any()) - { - updateBody.Add("custom_fields", customFields); - } - var updateRequest = new HttpRequestMessage(HttpMethod.Put, - string.Format("https://bitwarden.freshdesk.com/api/v2/tickets/{0}", ticketId)) - { - Content = JsonContent.Create(updateBody), - }; - await CallFreshdeskApiAsync(updateRequest); - await CreateNote(ticketId, note); - } - - return new OkResult(); - } - catch (Exception e) - { - _logger.LogError(e, "Error processing freshdesk webhook."); - return new BadRequestResult(); - } - } - - [HttpPost("webhook-onyx-ai")] - public async Task PostWebhookOnyxAi([FromQuery, Required] string key, - [FromBody, Required] FreshdeskOnyxAiWebhookModel model) - { - // ensure that the key is from Freshdesk - if (!IsValidRequestFromFreshdesk(key)) - { - return new BadRequestResult(); - } - - // if there is no description, then we don't send anything to onyx - if (string.IsNullOrEmpty(model.TicketDescriptionText.Trim())) - { - return Ok(); - } - - // create the onyx `answer-with-citation` request - var onyxRequestModel = new OnyxAnswerWithCitationRequestModel(model.TicketDescriptionText, _billingSettings.Onyx.PersonaId); - var onyxRequest = new HttpRequestMessage(HttpMethod.Post, - string.Format("{0}/query/answer-with-citation", _billingSettings.Onyx.BaseUrl)) - { - Content = JsonContent.Create(onyxRequestModel, mediaType: new MediaTypeHeaderValue("application/json")), - }; - var (_, onyxJsonResponse) = await CallOnyxApi(onyxRequest); - - // the CallOnyxApi will return a null if we have an error response - if (onyxJsonResponse?.Answer == null || !string.IsNullOrEmpty(onyxJsonResponse?.ErrorMsg)) - { - _logger.LogWarning("Error getting answer from Onyx AI. Freshdesk model: {model}\r\n Onyx query {query}\r\nresponse: {response}. ", - JsonSerializer.Serialize(model), - JsonSerializer.Serialize(onyxRequestModel), - JsonSerializer.Serialize(onyxJsonResponse)); - - return Ok(); // return ok so we don't retry - } - - // add the answer as a note to the ticket - await AddAnswerNoteToTicketAsync(onyxJsonResponse.Answer, model.TicketId); - - return Ok(); - } - - [HttpPost("webhook-onyx-ai-reply")] - public async Task PostWebhookOnyxAiReply([FromQuery, Required] string key, - [FromBody, Required] FreshdeskOnyxAiWebhookModel model) - { - // NOTE: - // at this time, this endpoint is a duplicate of `webhook-onyx-ai` - // eventually, we will merge both endpoints into one webhook for Freshdesk - - // ensure that the key is from Freshdesk - if (!IsValidRequestFromFreshdesk(key) || !ModelState.IsValid) - { - return new BadRequestResult(); - } - - // if there is no description, then we don't send anything to onyx - if (string.IsNullOrEmpty(model.TicketDescriptionText.Trim())) - { - return Ok(); - } - - // create the onyx `answer-with-citation` request - var onyxRequestModel = new OnyxAnswerWithCitationRequestModel(model.TicketDescriptionText, _billingSettings.Onyx.PersonaId); - var onyxRequest = new HttpRequestMessage(HttpMethod.Post, - string.Format("{0}/query/answer-with-citation", _billingSettings.Onyx.BaseUrl)) - { - Content = JsonContent.Create(onyxRequestModel, mediaType: new MediaTypeHeaderValue("application/json")), - }; - var (_, onyxJsonResponse) = await CallOnyxApi(onyxRequest); - - // the CallOnyxApi will return a null if we have an error response - if (onyxJsonResponse?.Answer == null || !string.IsNullOrEmpty(onyxJsonResponse?.ErrorMsg)) - { - _logger.LogWarning("Error getting answer from Onyx AI. Freshdesk model: {model}\r\n Onyx query {query}\r\nresponse: {response}. ", - JsonSerializer.Serialize(model), - JsonSerializer.Serialize(onyxRequestModel), - JsonSerializer.Serialize(onyxJsonResponse)); - - return Ok(); // return ok so we don't retry - } - - // add the reply to the ticket - await AddReplyToTicketAsync(onyxJsonResponse.Answer, model.TicketId); - - return Ok(); - } - - private bool IsValidRequestFromFreshdesk(string key) - { - if (string.IsNullOrWhiteSpace(key) - || !CoreHelpers.FixedTimeEquals(key, _billingSettings.FreshDesk.WebhookKey)) - { - return false; - } - - return true; - } - - private async Task CreateNote(string ticketId, string note) - { - var noteBody = new Dictionary - { - { "body", $"
      {note}
    " }, - { "private", true } - }; - var noteRequest = new HttpRequestMessage(HttpMethod.Post, - string.Format("https://bitwarden.freshdesk.com/api/v2/tickets/{0}/notes", ticketId)) - { - Content = JsonContent.Create(noteBody), - }; - await CallFreshdeskApiAsync(noteRequest); - } - - private async Task AddAnswerNoteToTicketAsync(string note, string ticketId) - { - // if there is no content, then we don't need to add a note - if (string.IsNullOrWhiteSpace(note)) - { - return; - } - - var noteBody = new Dictionary - { - { "body", $"Onyx AI:
      {note}
    " }, - { "private", true } - }; - - var noteRequest = new HttpRequestMessage(HttpMethod.Post, - string.Format("https://bitwarden.freshdesk.com/api/v2/tickets/{0}/notes", ticketId)) - { - Content = JsonContent.Create(noteBody), - }; - - var addNoteResponse = await CallFreshdeskApiAsync(noteRequest); - if (addNoteResponse.StatusCode != System.Net.HttpStatusCode.Created) - { - _logger.LogError("Error adding note to Freshdesk ticket. Ticket Id: {0}. Status: {1}", - ticketId, addNoteResponse.ToString()); - } - } - - private async Task AddReplyToTicketAsync(string note, string ticketId) - { - // if there is no content, then we don't need to add a note - if (string.IsNullOrWhiteSpace(note)) - { - return; - } - - // convert note from markdown to html - var htmlNote = note; - try - { - var pipeline = new MarkdownPipelineBuilder().UseAdvancedExtensions().Build(); - htmlNote = Markdig.Markdown.ToHtml(note, pipeline); - } - catch (Exception ex) - { - _logger.LogError(ex, "Error converting markdown to HTML for Freshdesk reply. Ticket Id: {0}. Note: {1}", - ticketId, note); - htmlNote = note; // fallback to the original note - } - - // clear out any new lines that Freshdesk doesn't like - if (_billingSettings.FreshDesk.RemoveNewlinesInReplies) - { - htmlNote = htmlNote.Replace(Environment.NewLine, string.Empty); - } - - var replyBody = new FreshdeskReplyRequestModel - { - Body = $"{_billingSettings.FreshDesk.AutoReplyGreeting}{htmlNote}{_billingSettings.FreshDesk.AutoReplySalutation}", - }; - - var replyRequest = new HttpRequestMessage(HttpMethod.Post, - string.Format("https://bitwarden.freshdesk.com/api/v2/tickets/{0}/reply", ticketId)) - { - Content = JsonContent.Create(replyBody), - }; - - var addReplyResponse = await CallFreshdeskApiAsync(replyRequest); - if (addReplyResponse.StatusCode != System.Net.HttpStatusCode.Created) - { - _logger.LogError("Error adding reply to Freshdesk ticket. Ticket Id: {0}. Status: {1}", - ticketId, addReplyResponse.ToString()); - } - } - - private async Task CallFreshdeskApiAsync(HttpRequestMessage request, int retriedCount = 0) - { - try - { - var freshdeskAuthkey = Convert.ToBase64String(Encoding.UTF8.GetBytes($"{_billingSettings.FreshDesk.ApiKey}:X")); - var httpClient = _httpClientFactory.CreateClient("FreshdeskApi"); - request.Headers.Add("Authorization", $"Basic {freshdeskAuthkey}"); - var response = await httpClient.SendAsync(request); - if (response.StatusCode != System.Net.HttpStatusCode.TooManyRequests || retriedCount > 3) - { - return response; - } - } - catch - { - if (retriedCount > 3) - { - throw; - } - } - await Task.Delay(30000 * (retriedCount + 1)); - return await CallFreshdeskApiAsync(request, retriedCount++); - } - - private async Task<(HttpResponseMessage, T)> CallOnyxApi(HttpRequestMessage request) - { - var httpClient = _httpClientFactory.CreateClient("OnyxApi"); - var response = await httpClient.SendAsync(request); - - if (response.StatusCode != System.Net.HttpStatusCode.OK) - { - _logger.LogError("Error calling Onyx AI API. Status code: {0}. Response {1}", - response.StatusCode, JsonSerializer.Serialize(response)); - return (null, default); - } - var responseStr = await response.Content.ReadAsStringAsync(); - var responseJson = JsonSerializer.Deserialize(responseStr, options: new JsonSerializerOptions - { - PropertyNameCaseInsensitive = true, - }); - - return (response, responseJson); - } - - private TAttribute GetAttribute(Enum enumValue) where TAttribute : Attribute - { - return enumValue.GetType().GetMember(enumValue.ToString()).First().GetCustomAttribute(); - } -} diff --git a/src/Billing/Controllers/FreshsalesController.cs b/src/Billing/Controllers/FreshsalesController.cs deleted file mode 100644 index be5a9ddb16..0000000000 --- a/src/Billing/Controllers/FreshsalesController.cs +++ /dev/null @@ -1,247 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Net.Http.Headers; -using System.Text.Json.Serialization; -using Bit.Core.Billing.Enums; -using Bit.Core.Repositories; -using Bit.Core.Settings; -using Bit.Core.Utilities; -using Microsoft.AspNetCore.Mvc; -using Microsoft.Extensions.Options; - -namespace Bit.Billing.Controllers; - -[Route("freshsales")] -public class FreshsalesController : Controller -{ - private readonly IUserRepository _userRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly ILogger _logger; - private readonly GlobalSettings _globalSettings; - - private readonly string _freshsalesApiKey; - - private readonly HttpClient _httpClient; - - public FreshsalesController(IUserRepository userRepository, - IOrganizationRepository organizationRepository, - IOptions billingSettings, - ILogger logger, - GlobalSettings globalSettings) - { - _userRepository = userRepository; - _organizationRepository = organizationRepository; - _logger = logger; - _globalSettings = globalSettings; - - _httpClient = new HttpClient - { - BaseAddress = new Uri("https://bitwarden.freshsales.io/api/") - }; - - _freshsalesApiKey = billingSettings.Value.FreshsalesApiKey; - - _httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue( - "Token", - $"token={_freshsalesApiKey}"); - } - - - [HttpPost("webhook")] - public async Task PostWebhook([FromHeader(Name = "Authorization")] string key, - [FromBody] CustomWebhookRequestModel request, - CancellationToken cancellationToken) - { - if (string.IsNullOrWhiteSpace(key) || !CoreHelpers.FixedTimeEquals(_freshsalesApiKey, key)) - { - return Unauthorized(); - } - - try - { - var leadResponse = await _httpClient.GetFromJsonAsync>( - $"leads/{request.LeadId}", - cancellationToken); - - var lead = leadResponse.Lead; - - var primaryEmail = lead.Emails - .Where(e => e.IsPrimary) - .FirstOrDefault(); - - if (primaryEmail == null) - { - return BadRequest(new { Message = "Lead has not primary email." }); - } - - var user = await _userRepository.GetByEmailAsync(primaryEmail.Value); - - if (user == null) - { - return NoContent(); - } - - var newTags = new HashSet(); - - if (user.Premium) - { - newTags.Add("Premium"); - } - - var noteItems = new List - { - $"User, {user.Email}: {_globalSettings.BaseServiceUri.Admin}/users/edit/{user.Id}" - }; - - var orgs = await _organizationRepository.GetManyByUserIdAsync(user.Id); - - foreach (var org in orgs) - { - noteItems.Add($"Org, {org.DisplayName()}: {_globalSettings.BaseServiceUri.Admin}/organizations/edit/{org.Id}"); - if (TryGetPlanName(org.PlanType, out var planName)) - { - newTags.Add($"Org: {planName}"); - } - } - - if (newTags.Any()) - { - var allTags = newTags.Concat(lead.Tags); - var updateLeadResponse = await _httpClient.PutAsJsonAsync( - $"leads/{request.LeadId}", - CreateWrapper(new { tags = allTags }), - cancellationToken); - updateLeadResponse.EnsureSuccessStatusCode(); - } - - var createNoteResponse = await _httpClient.PostAsJsonAsync( - "notes", - CreateNoteRequestModel(request.LeadId, string.Join('\n', noteItems)), cancellationToken); - createNoteResponse.EnsureSuccessStatusCode(); - return NoContent(); - } - catch (Exception ex) - { - Console.WriteLine(ex); - _logger.LogError(ex, "Error processing freshsales webhook"); - return BadRequest(new { ex.Message }); - } - } - - private static LeadWrapper CreateWrapper(T lead) - { - return new LeadWrapper - { - Lead = lead, - }; - } - - private static CreateNoteRequestModel CreateNoteRequestModel(long leadId, string content) - { - return new CreateNoteRequestModel - { - Note = new EditNoteModel - { - Description = content, - TargetableType = "Lead", - TargetableId = leadId, - }, - }; - } - - private static bool TryGetPlanName(PlanType planType, out string planName) - { - switch (planType) - { - case PlanType.Free: - planName = "Free"; - return true; - case PlanType.FamiliesAnnually: - case PlanType.FamiliesAnnually2019: - planName = "Families"; - return true; - case PlanType.TeamsAnnually: - case PlanType.TeamsAnnually2023: - case PlanType.TeamsAnnually2020: - case PlanType.TeamsAnnually2019: - case PlanType.TeamsMonthly: - case PlanType.TeamsMonthly2023: - case PlanType.TeamsMonthly2020: - case PlanType.TeamsMonthly2019: - case PlanType.TeamsStarter: - case PlanType.TeamsStarter2023: - planName = "Teams"; - return true; - case PlanType.EnterpriseAnnually: - case PlanType.EnterpriseAnnually2023: - case PlanType.EnterpriseAnnually2020: - case PlanType.EnterpriseAnnually2019: - case PlanType.EnterpriseMonthly: - case PlanType.EnterpriseMonthly2023: - case PlanType.EnterpriseMonthly2020: - case PlanType.EnterpriseMonthly2019: - planName = "Enterprise"; - return true; - case PlanType.Custom: - planName = "Custom"; - return true; - default: - planName = null; - return false; - } - } -} - -public class CustomWebhookRequestModel -{ - [JsonPropertyName("leadId")] - public long LeadId { get; set; } -} - -public class LeadWrapper -{ - [JsonPropertyName("lead")] - public T Lead { get; set; } - - public static LeadWrapper Create(TItem lead) - { - return new LeadWrapper - { - Lead = lead, - }; - } -} - -public class FreshsalesLeadModel -{ - public string[] Tags { get; set; } - public FreshsalesEmailModel[] Emails { get; set; } -} - -public class FreshsalesEmailModel -{ - [JsonPropertyName("value")] - public string Value { get; set; } - - [JsonPropertyName("is_primary")] - public bool IsPrimary { get; set; } -} - -public class CreateNoteRequestModel -{ - [JsonPropertyName("note")] - public EditNoteModel Note { get; set; } -} - -public class EditNoteModel -{ - [JsonPropertyName("description")] - public string Description { get; set; } - - [JsonPropertyName("targetable_type")] - public string TargetableType { get; set; } - - [JsonPropertyName("targetable_id")] - public long TargetableId { get; set; } -} diff --git a/src/Billing/Controllers/JobsController.cs b/src/Billing/Controllers/JobsController.cs new file mode 100644 index 0000000000..6a5e8e5531 --- /dev/null +++ b/src/Billing/Controllers/JobsController.cs @@ -0,0 +1,36 @@ +using Bit.Billing.Jobs; +using Bit.Core.Utilities; +using Microsoft.AspNetCore.Mvc; + +namespace Bit.Billing.Controllers; + +[Route("jobs")] +[SelfHosted(NotSelfHostedOnly = true)] +[RequireLowerEnvironment] +public class JobsController( + JobsHostedService jobsHostedService) : Controller +{ + [HttpPost("run/{jobName}")] + public async Task RunJobAsync(string jobName) + { + if (jobName == nameof(ReconcileAdditionalStorageJob)) + { + await jobsHostedService.RunJobAdHocAsync(); + return Ok(new { message = $"Job {jobName} scheduled successfully" }); + } + + return BadRequest(new { error = $"Unknown job name: {jobName}" }); + } + + [HttpPost("stop/{jobName}")] + public async Task StopJobAsync(string jobName) + { + if (jobName == nameof(ReconcileAdditionalStorageJob)) + { + await jobsHostedService.InterruptAdHocJobAsync(); + return Ok(new { message = $"Job {jobName} queued for cancellation" }); + } + + return BadRequest(new { error = $"Unknown job name: {jobName}" }); + } +} diff --git a/src/Billing/Controllers/PayPalController.cs b/src/Billing/Controllers/PayPalController.cs index 8039680fd5..70023b6bdb 100644 --- a/src/Billing/Controllers/PayPalController.cs +++ b/src/Billing/Controllers/PayPalController.cs @@ -23,7 +23,7 @@ public class PayPalController : Controller private readonly ILogger _logger; private readonly IMailService _mailService; private readonly IOrganizationRepository _organizationRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly ITransactionRepository _transactionRepository; private readonly IUserRepository _userRepository; private readonly IProviderRepository _providerRepository; @@ -34,7 +34,7 @@ public class PayPalController : Controller ILogger logger, IMailService mailService, IOrganizationRepository organizationRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, ITransactionRepository transactionRepository, IUserRepository userRepository, IProviderRepository providerRepository, diff --git a/src/Billing/Controllers/StripeController.cs b/src/Billing/Controllers/StripeController.cs index b60e0c56e4..18f2198119 100644 --- a/src/Billing/Controllers/StripeController.cs +++ b/src/Billing/Controllers/StripeController.cs @@ -120,9 +120,7 @@ public class StripeController : Controller return deliveryContainer.ApiVersion switch { - "2024-06-20" => HandleVersionWith(_billingSettings.StripeWebhookSecret20240620), - "2023-10-16" => HandleVersionWith(_billingSettings.StripeWebhookSecret20231016), - "2022-08-01" => HandleVersionWith(_billingSettings.StripeWebhookSecret), + "2025-08-27.basil" => HandleVersionWith(_billingSettings.StripeWebhookSecret20250827Basil), _ => HandleDefault(deliveryContainer.ApiVersion) }; diff --git a/src/Billing/Jobs/AliveJob.cs b/src/Billing/Jobs/AliveJob.cs index 42f64099ac..1769cc94e2 100644 --- a/src/Billing/Jobs/AliveJob.cs +++ b/src/Billing/Jobs/AliveJob.cs @@ -10,4 +10,13 @@ public class AliveJob(ILogger logger) : BaseJob(logger) _logger.LogInformation(Core.Constants.BypassFiltersEventId, null, "Billing service is alive!"); return Task.FromResult(0); } + + public static ITrigger GetTrigger() + { + return TriggerBuilder.Create() + .WithIdentity("EveryTopOfTheHourTrigger") + .StartNow() + .WithCronSchedule("0 0 * * * ?") + .Build(); + } } diff --git a/src/Billing/Jobs/JobsHostedService.cs b/src/Billing/Jobs/JobsHostedService.cs index a6e702c662..25c57044da 100644 --- a/src/Billing/Jobs/JobsHostedService.cs +++ b/src/Billing/Jobs/JobsHostedService.cs @@ -1,29 +1,27 @@ -using Bit.Core.Jobs; +using Bit.Core.Exceptions; +using Bit.Core.Jobs; using Bit.Core.Settings; using Quartz; namespace Bit.Billing.Jobs; -public class JobsHostedService : BaseJobsHostedService +public class JobsHostedService( + GlobalSettings globalSettings, + IServiceProvider serviceProvider, + ILogger logger, + ILogger listenerLogger, + ISchedulerFactory schedulerFactory) + : BaseJobsHostedService(globalSettings, serviceProvider, logger, listenerLogger) { - public JobsHostedService( - GlobalSettings globalSettings, - IServiceProvider serviceProvider, - ILogger logger, - ILogger listenerLogger) - : base(globalSettings, serviceProvider, logger, listenerLogger) { } + private List AdHocJobKeys { get; } = []; + private IScheduler? _adHocScheduler; public override async Task StartAsync(CancellationToken cancellationToken) { - var everyTopOfTheHourTrigger = TriggerBuilder.Create() - .WithIdentity("EveryTopOfTheHourTrigger") - .StartNow() - .WithCronSchedule("0 0 * * * ?") - .Build(); - Jobs = new List> { - new Tuple(typeof(AliveJob), everyTopOfTheHourTrigger) + new(typeof(AliveJob), AliveJob.GetTrigger()), + new(typeof(ReconcileAdditionalStorageJob), ReconcileAdditionalStorageJob.GetTrigger()) }; await base.StartAsync(cancellationToken); @@ -33,5 +31,54 @@ public class JobsHostedService : BaseJobsHostedService { services.AddTransient(); services.AddTransient(); + services.AddTransient(); + // add this service as a singleton so we can inject it where needed + services.AddSingleton(); + services.AddHostedService(sp => sp.GetRequiredService()); + } + + public async Task InterruptAdHocJobAsync(CancellationToken cancellationToken = default) where T : class, IJob + { + if (_adHocScheduler == null) + { + throw new InvalidOperationException("AdHocScheduler is null, cannot interrupt ad-hoc job."); + } + + var jobKey = AdHocJobKeys.FirstOrDefault(j => j.Name == typeof(T).ToString()); + if (jobKey == null) + { + throw new NotFoundException($"Cannot find job key: {typeof(T)}, not running?"); + } + logger.LogInformation("CANCELLING ad-hoc job with key: {JobKey}", jobKey); + AdHocJobKeys.Remove(jobKey); + await _adHocScheduler.Interrupt(jobKey, cancellationToken); + } + + public async Task RunJobAdHocAsync(CancellationToken cancellationToken = default) where T : class, IJob + { + _adHocScheduler ??= await schedulerFactory.GetScheduler(cancellationToken); + + var jobKey = new JobKey(typeof(T).ToString()); + + var currentlyExecuting = await _adHocScheduler.GetCurrentlyExecutingJobs(cancellationToken); + if (currentlyExecuting.Any(j => j.JobDetail.Key.Equals(jobKey))) + { + throw new InvalidOperationException($"Job {jobKey} is already running"); + } + + AdHocJobKeys.Add(jobKey); + + var job = JobBuilder.Create() + .WithIdentity(jobKey) + .Build(); + + var trigger = TriggerBuilder.Create() + .WithIdentity(typeof(T).ToString()) + .StartNow() + .Build(); + + logger.LogInformation("Scheduling ad-hoc job with key: {JobKey}", jobKey); + + await _adHocScheduler.ScheduleJob(job, trigger, cancellationToken); } } diff --git a/src/Billing/Jobs/ProviderOrganizationDisableJob.cs b/src/Billing/Jobs/ProviderOrganizationDisableJob.cs new file mode 100644 index 0000000000..5a48dd609f --- /dev/null +++ b/src/Billing/Jobs/ProviderOrganizationDisableJob.cs @@ -0,0 +1,88 @@ +// FIXME: Update this file to be null safe and then delete the line below +#nullable disable + +using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; +using Bit.Core.AdminConsole.Repositories; +using Quartz; + +namespace Bit.Billing.Jobs; + +public class ProviderOrganizationDisableJob( + IProviderOrganizationRepository providerOrganizationRepository, + IOrganizationDisableCommand organizationDisableCommand, + ILogger logger) + : IJob +{ + private const int MaxConcurrency = 5; + private const int MaxTimeoutMinutes = 10; + + public async Task Execute(IJobExecutionContext context) + { + var providerId = new Guid(context.MergedJobDataMap.GetString("providerId") ?? string.Empty); + var expirationDateString = context.MergedJobDataMap.GetString("expirationDate"); + DateTime? expirationDate = string.IsNullOrEmpty(expirationDateString) + ? null + : DateTime.Parse(expirationDateString); + + logger.LogInformation("Starting to disable organizations for provider {ProviderId}", providerId); + + var startTime = DateTime.UtcNow; + var totalProcessed = 0; + var totalErrors = 0; + + try + { + var providerOrganizations = await providerOrganizationRepository + .GetManyDetailsByProviderAsync(providerId); + + if (providerOrganizations == null || !providerOrganizations.Any()) + { + logger.LogInformation("No organizations found for provider {ProviderId}", providerId); + return; + } + + logger.LogInformation("Disabling {OrganizationCount} organizations for provider {ProviderId}", + providerOrganizations.Count, providerId); + + var semaphore = new SemaphoreSlim(MaxConcurrency, MaxConcurrency); + var tasks = providerOrganizations.Select(async po => + { + if (DateTime.UtcNow.Subtract(startTime).TotalMinutes > MaxTimeoutMinutes) + { + logger.LogWarning("Timeout reached while disabling organizations for provider {ProviderId}", providerId); + return false; + } + + await semaphore.WaitAsync(); + try + { + await organizationDisableCommand.DisableAsync(po.OrganizationId, expirationDate); + Interlocked.Increment(ref totalProcessed); + return true; + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to disable organization {OrganizationId} for provider {ProviderId}", + po.OrganizationId, providerId); + Interlocked.Increment(ref totalErrors); + return false; + } + finally + { + semaphore.Release(); + } + }); + + await Task.WhenAll(tasks); + + logger.LogInformation("Completed disabling organizations for provider {ProviderId}. Processed: {TotalProcessed}, Errors: {TotalErrors}", + providerId, totalProcessed, totalErrors); + } + catch (Exception ex) + { + logger.LogError(ex, "Error disabling organizations for provider {ProviderId}. Processed: {TotalProcessed}, Errors: {TotalErrors}", + providerId, totalProcessed, totalErrors); + throw; + } + } +} diff --git a/src/Billing/Jobs/ReconcileAdditionalStorageJob.cs b/src/Billing/Jobs/ReconcileAdditionalStorageJob.cs new file mode 100644 index 0000000000..312ed3122b --- /dev/null +++ b/src/Billing/Jobs/ReconcileAdditionalStorageJob.cs @@ -0,0 +1,193 @@ +using System.Globalization; +using System.Text.Json; +using Bit.Billing.Services; +using Bit.Core; +using Bit.Core.Billing.Constants; +using Bit.Core.Jobs; +using Bit.Core.Services; +using Quartz; +using Stripe; + +namespace Bit.Billing.Jobs; + +public class ReconcileAdditionalStorageJob( + IStripeFacade stripeFacade, + ILogger logger, + IFeatureService featureService) : BaseJob(logger) +{ + private const string _storageGbMonthlyPriceId = "storage-gb-monthly"; + private const string _storageGbAnnuallyPriceId = "storage-gb-annually"; + private const string _personalStorageGbAnnuallyPriceId = "personal-storage-gb-annually"; + private const int _storageGbToRemove = 4; + + protected override async Task ExecuteJobAsync(IJobExecutionContext context) + { + if (!featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob)) + { + logger.LogInformation("Skipping ReconcileAdditionalStorageJob, feature flag off."); + return; + } + + var liveMode = featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode); + + // Execution tracking + var subscriptionsFound = 0; + var subscriptionsUpdated = 0; + var subscriptionsWithErrors = 0; + var failures = new List(); + + logger.LogInformation("Starting ReconcileAdditionalStorageJob (live mode: {LiveMode})", liveMode); + + var priceIds = new[] { _storageGbMonthlyPriceId, _storageGbAnnuallyPriceId, _personalStorageGbAnnuallyPriceId }; + var stripeStatusesToProcess = new[] { StripeConstants.SubscriptionStatus.Active, StripeConstants.SubscriptionStatus.Trialing, StripeConstants.SubscriptionStatus.PastDue }; + + foreach (var priceId in priceIds) + { + var options = new SubscriptionListOptions { Limit = 100, Price = priceId }; + + await foreach (var subscription in stripeFacade.ListSubscriptionsAutoPagingAsync(options)) + { + if (context.CancellationToken.IsCancellationRequested) + { + logger.LogWarning( + "Job cancelled!! Exiting. Progress at time of cancellation: Subscriptions found: {SubscriptionsFound}, " + + "Updated: {SubscriptionsUpdated}, Errors: {SubscriptionsWithErrors}{Failures}", + subscriptionsFound, + liveMode + ? subscriptionsUpdated + : $"(In live mode, would have updated) {subscriptionsUpdated}", + subscriptionsWithErrors, + failures.Count > 0 + ? $", Failures: {Environment.NewLine}{string.Join(Environment.NewLine, failures)}" + : string.Empty + ); + return; + } + + if (subscription == null) + { + continue; + } + + if (!stripeStatusesToProcess.Contains(subscription.Status)) + { + logger.LogInformation("Skipping subscription with unsupported status: {SubscriptionId} - {Status}", subscription.Id, subscription.Status); + continue; + } + + logger.LogInformation("Processing subscription: {SubscriptionId}", subscription.Id); + subscriptionsFound++; + + if (subscription.Metadata?.TryGetValue(StripeConstants.MetadataKeys.StorageReconciled2025, out var dateString) == true) + { + if (DateTime.TryParse(dateString, null, DateTimeStyles.RoundtripKind, out var dateProcessed)) + { + logger.LogInformation("Skipping subscription {SubscriptionId} - already processed on {Date}", + subscription.Id, + dateProcessed.ToString("f")); + continue; + } + } + + var updateOptions = BuildSubscriptionUpdateOptions(subscription, priceId); + + if (updateOptions == null) + { + logger.LogInformation("Skipping subscription {SubscriptionId} - no updates needed", subscription.Id); + continue; + } + + subscriptionsUpdated++; + + if (!liveMode) + { + logger.LogInformation( + "Not live mode (dry-run): Would have updated subscription {SubscriptionId} with item changes: {NewLine}{UpdateOptions}", + subscription.Id, + Environment.NewLine, + JsonSerializer.Serialize(updateOptions)); + continue; + } + + try + { + await stripeFacade.UpdateSubscription(subscription.Id, updateOptions); + logger.LogInformation("Successfully updated subscription: {SubscriptionId}", subscription.Id); + } + catch (Exception ex) + { + subscriptionsWithErrors++; + failures.Add($"Subscription {subscription.Id}: {ex.Message}"); + logger.LogError(ex, "Failed to update subscription {SubscriptionId}: {ErrorMessage}", + subscription.Id, ex.Message); + } + } + } + + logger.LogInformation( + "ReconcileAdditionalStorageJob completed. Subscriptions found: {SubscriptionsFound}, " + + "Updated: {SubscriptionsUpdated}, Errors: {SubscriptionsWithErrors}{Failures}", + subscriptionsFound, + liveMode + ? subscriptionsUpdated + : $"(In live mode, would have updated) {subscriptionsUpdated}", + subscriptionsWithErrors, + failures.Count > 0 + ? $", Failures: {Environment.NewLine}{string.Join(Environment.NewLine, failures)}" + : string.Empty + ); + } + + private SubscriptionUpdateOptions? BuildSubscriptionUpdateOptions( + Subscription subscription, + string targetPriceId) + { + if (subscription.Items?.Data == null) + { + return null; + } + + var updateOptions = new SubscriptionUpdateOptions { ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations, Metadata = new Dictionary { [StripeConstants.MetadataKeys.StorageReconciled2025] = DateTime.UtcNow.ToString("o") }, Items = [] }; + + var hasUpdates = false; + + foreach (var item in subscription.Items.Data.Where(item => item?.Price?.Id == targetPriceId)) + { + hasUpdates = true; + var currentQuantity = item.Quantity; + + if (currentQuantity > _storageGbToRemove) + { + var newQuantity = currentQuantity - _storageGbToRemove; + logger.LogInformation( + "Subscription {SubscriptionId}: reducing quantity from {CurrentQuantity} to {NewQuantity} for price {PriceId}", + subscription.Id, + currentQuantity, + newQuantity, + item.Price.Id); + + updateOptions.Items.Add(new SubscriptionItemOptions { Id = item.Id, Quantity = newQuantity }); + } + else + { + logger.LogInformation("Subscription {SubscriptionId}: deleting storage item with quantity {CurrentQuantity} for price {PriceId}", + subscription.Id, + currentQuantity, + item.Price.Id); + + updateOptions.Items.Add(new SubscriptionItemOptions { Id = item.Id, Deleted = true }); + } + } + + return hasUpdates ? updateOptions : null; + } + + public static ITrigger GetTrigger() + { + return TriggerBuilder.Create() + .WithIdentity("EveryMorningTrigger") + .StartNow() + .WithCronSchedule("0 0 16 * * ?") // 10am CST daily; the pods execute in UTC time + .Build(); + } +} diff --git a/src/Billing/Jobs/SubscriptionCancellationJob.cs b/src/Billing/Jobs/SubscriptionCancellationJob.cs index 69b7bc876d..60b671df3d 100644 --- a/src/Billing/Jobs/SubscriptionCancellationJob.cs +++ b/src/Billing/Jobs/SubscriptionCancellationJob.cs @@ -1,16 +1,17 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Billing.Services; +using Bit.Billing.Services; +using Bit.Core.Billing.Constants; using Bit.Core.Repositories; using Quartz; using Stripe; namespace Bit.Billing.Jobs; +using static StripeConstants; + public class SubscriptionCancellationJob( IStripeFacade stripeFacade, - IOrganizationRepository organizationRepository) + IOrganizationRepository organizationRepository, + ILogger logger) : IJob { public async Task Execute(IJobExecutionContext context) @@ -21,20 +22,31 @@ public class SubscriptionCancellationJob( var organization = await organizationRepository.GetByIdAsync(organizationId); if (organization == null || organization.Enabled) { + logger.LogWarning("{Job} skipped for subscription ({SubscriptionID}) because organization is either null or enabled", nameof(SubscriptionCancellationJob), subscriptionId); // Organization was deleted or re-enabled by CS, skip cancellation return; } - var subscription = await stripeFacade.GetSubscription(subscriptionId); - if (subscription?.Status != "unpaid" || - subscription.LatestInvoice?.BillingReason is not ("subscription_cycle" or "subscription_create")) + var subscription = await stripeFacade.GetSubscription(subscriptionId, new SubscriptionGetOptions { + Expand = ["latest_invoice"] + }); + + if (subscription is not + { + Status: SubscriptionStatus.Unpaid, + LatestInvoice: { BillingReason: BillingReasons.SubscriptionCreate or BillingReasons.SubscriptionCycle } + }) + { + logger.LogWarning("{Job} skipped for subscription ({SubscriptionID}) because subscription is not unpaid or does not have a cancellable billing reason", nameof(SubscriptionCancellationJob), subscriptionId); return; } // Cancel the subscription await stripeFacade.CancelSubscription(subscriptionId, new SubscriptionCancelOptions()); + logger.LogInformation("{Job} cancelled subscription ({SubscriptionID})", nameof(SubscriptionCancellationJob), subscriptionId); + // Void any open invoices var options = new InvoiceListOptions { @@ -46,6 +58,7 @@ public class SubscriptionCancellationJob( foreach (var invoice in invoices) { await stripeFacade.VoidInvoice(invoice.Id); + logger.LogInformation("{Job} voided invoice ({InvoiceID}) for subscription ({SubscriptionID})", nameof(SubscriptionCancellationJob), invoice.Id, subscriptionId); } while (invoices.HasMore) @@ -55,6 +68,7 @@ public class SubscriptionCancellationJob( foreach (var invoice in invoices) { await stripeFacade.VoidInvoice(invoice.Id); + logger.LogInformation("{Job} voided invoice ({InvoiceID}) for subscription ({SubscriptionID})", nameof(SubscriptionCancellationJob), invoice.Id, subscriptionId); } } } diff --git a/src/Billing/Models/FreshdeskReplyRequestModel.cs b/src/Billing/Models/FreshdeskReplyRequestModel.cs deleted file mode 100644 index 3927039769..0000000000 --- a/src/Billing/Models/FreshdeskReplyRequestModel.cs +++ /dev/null @@ -1,9 +0,0 @@ -using System.Text.Json.Serialization; - -namespace Bit.Billing.Models; - -public class FreshdeskReplyRequestModel -{ - [JsonPropertyName("body")] - public required string Body { get; set; } -} diff --git a/src/Billing/Models/FreshdeskWebhookModel.cs b/src/Billing/Models/FreshdeskWebhookModel.cs deleted file mode 100644 index aac0e9339d..0000000000 --- a/src/Billing/Models/FreshdeskWebhookModel.cs +++ /dev/null @@ -1,24 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Text.Json.Serialization; - -namespace Bit.Billing.Models; - -public class FreshdeskWebhookModel -{ - [JsonPropertyName("ticket_id")] - public string TicketId { get; set; } - - [JsonPropertyName("ticket_contact_email")] - public string TicketContactEmail { get; set; } - - [JsonPropertyName("ticket_tags")] - public string TicketTags { get; set; } -} - -public class FreshdeskOnyxAiWebhookModel : FreshdeskWebhookModel -{ - [JsonPropertyName("ticket_description_text")] - public string TicketDescriptionText { get; set; } -} diff --git a/src/Billing/Models/OnyxAnswerWithCitationRequestModel.cs b/src/Billing/Models/OnyxAnswerWithCitationRequestModel.cs deleted file mode 100644 index ba3b89e297..0000000000 --- a/src/Billing/Models/OnyxAnswerWithCitationRequestModel.cs +++ /dev/null @@ -1,52 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - - -using System.Text.Json.Serialization; - -namespace Bit.Billing.Models; - -public class OnyxAnswerWithCitationRequestModel -{ - [JsonPropertyName("messages")] - public List Messages { get; set; } - - [JsonPropertyName("persona_id")] - public int PersonaId { get; set; } = 1; - - [JsonPropertyName("retrieval_options")] - public RetrievalOptions RetrievalOptions { get; set; } - - public OnyxAnswerWithCitationRequestModel(string message, int personaId = 1) - { - message = message.Replace(Environment.NewLine, " ").Replace('\r', ' ').Replace('\n', ' '); - Messages = new List() { new Message() { MessageText = message } }; - RetrievalOptions = new RetrievalOptions(); - PersonaId = personaId; - } -} - -public class Message -{ - [JsonPropertyName("message")] - public string MessageText { get; set; } - - [JsonPropertyName("sender")] - public string Sender { get; set; } = "user"; -} - -public class RetrievalOptions -{ - [JsonPropertyName("run_search")] - public string RunSearch { get; set; } = RetrievalOptionsRunSearch.Auto; - - [JsonPropertyName("real_time")] - public bool RealTime { get; set; } = true; -} - -public class RetrievalOptionsRunSearch -{ - public const string Always = "always"; - public const string Never = "never"; - public const string Auto = "auto"; -} diff --git a/src/Billing/Models/OnyxAnswerWithCitationResponseModel.cs b/src/Billing/Models/OnyxAnswerWithCitationResponseModel.cs deleted file mode 100644 index 5f67cd51d2..0000000000 --- a/src/Billing/Models/OnyxAnswerWithCitationResponseModel.cs +++ /dev/null @@ -1,33 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Text.Json.Serialization; - -namespace Bit.Billing.Models; - -public class OnyxAnswerWithCitationResponseModel -{ - [JsonPropertyName("answer")] - public string Answer { get; set; } - - [JsonPropertyName("rephrase")] - public string Rephrase { get; set; } - - [JsonPropertyName("citations")] - public List Citations { get; set; } - - [JsonPropertyName("llm_selected_doc_indices")] - public List LlmSelectedDocIndices { get; set; } - - [JsonPropertyName("error_msg")] - public string ErrorMsg { get; set; } -} - -public class Citation -{ - [JsonPropertyName("citation_num")] - public int CitationNum { get; set; } - - [JsonPropertyName("document_id")] - public string DocumentId { get; set; } -} diff --git a/src/Billing/Program.cs b/src/Billing/Program.cs index 3e005ce7fd..334dc49368 100644 --- a/src/Billing/Program.cs +++ b/src/Billing/Program.cs @@ -8,28 +8,12 @@ public class Program { Host .CreateDefaultBuilder(args) + .UseBitwardenSdk() .ConfigureWebHostDefaults(webBuilder => { webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, (e, globalSettings) => - { - var context = e.Properties["SourceContext"].ToString(); - if (context.StartsWith("\"Bit.Billing.Jobs") || context.StartsWith("\"Bit.Core.Jobs")) - { - return e.Level >= globalSettings.MinLogLevel.BillingSettings.Jobs; - } - - if (e.Properties.TryGetValue("RequestPath", out var requestPath) && - !string.IsNullOrWhiteSpace(requestPath?.ToString()) && - (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) - { - return false; - } - - return e.Level >= globalSettings.MinLogLevel.BillingSettings.Default; - })); }) + .AddSerilogFileLogging() .Build() .Run(); } diff --git a/src/Billing/Services/IStripeEventUtilityService.cs b/src/Billing/Services/IStripeEventUtilityService.cs index a5f536ad11..058f56c887 100644 --- a/src/Billing/Services/IStripeEventUtilityService.cs +++ b/src/Billing/Services/IStripeEventUtilityService.cs @@ -36,7 +36,7 @@ public interface IStripeEventUtilityService /// /// /// /// - Transaction FromChargeToTransaction(Charge charge, Guid? organizationId, Guid? userId, Guid? providerId); + Task FromChargeToTransactionAsync(Charge charge, Guid? organizationId, Guid? userId, Guid? providerId); /// /// Attempts to pay the specified invoice. If a customer is eligible, the invoice is paid using Braintree or Stripe. diff --git a/src/Billing/Services/IStripeFacade.cs b/src/Billing/Services/IStripeFacade.cs index 280a3aca3c..c7073b9cf9 100644 --- a/src/Billing/Services/IStripeFacade.cs +++ b/src/Billing/Services/IStripeFacade.cs @@ -20,6 +20,12 @@ public interface IStripeFacade RequestOptions requestOptions = null, CancellationToken cancellationToken = default); + IAsyncEnumerable GetCustomerCashBalanceTransactions( + string customerId, + CustomerCashBalanceTransactionListOptions customerCashBalanceTransactionListOptions = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default); + Task UpdateCustomer( string customerId, CustomerUpdateOptions customerUpdateOptions = null, @@ -78,6 +84,11 @@ public interface IStripeFacade RequestOptions requestOptions = null, CancellationToken cancellationToken = default); + IAsyncEnumerable ListSubscriptionsAutoPagingAsync( + SubscriptionListOptions options = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default); + Task GetSubscription( string subscriptionId, SubscriptionGetOptions subscriptionGetOptions = null, @@ -111,4 +122,10 @@ public interface IStripeFacade TestClockGetOptions testClockGetOptions = null, RequestOptions requestOptions = null, CancellationToken cancellationToken = default); + + Task GetCoupon( + string couponId, + CouponGetOptions couponGetOptions = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default); } diff --git a/src/Billing/Services/Implementations/ChargeRefundedHandler.cs b/src/Billing/Services/Implementations/ChargeRefundedHandler.cs index 905491b6c5..8cc3cb2ce6 100644 --- a/src/Billing/Services/Implementations/ChargeRefundedHandler.cs +++ b/src/Billing/Services/Implementations/ChargeRefundedHandler.cs @@ -38,7 +38,7 @@ public class ChargeRefundedHandler : IChargeRefundedHandler { // Attempt to create a transaction for the charge if it doesn't exist var (organizationId, userId, providerId) = await _stripeEventUtilityService.GetEntityIdsFromChargeAsync(charge); - var tx = _stripeEventUtilityService.FromChargeToTransaction(charge, organizationId, userId, providerId); + var tx = await _stripeEventUtilityService.FromChargeToTransactionAsync(charge, organizationId, userId, providerId); try { parentTransaction = await _transactionRepository.CreateAsync(tx); diff --git a/src/Billing/Services/Implementations/ChargeSucceededHandler.cs b/src/Billing/Services/Implementations/ChargeSucceededHandler.cs index bd8ea7def2..20c4dcfa98 100644 --- a/src/Billing/Services/Implementations/ChargeSucceededHandler.cs +++ b/src/Billing/Services/Implementations/ChargeSucceededHandler.cs @@ -46,7 +46,7 @@ public class ChargeSucceededHandler : IChargeSucceededHandler return; } - var transaction = _stripeEventUtilityService.FromChargeToTransaction(charge, organizationId, userId, providerId); + var transaction = await _stripeEventUtilityService.FromChargeToTransactionAsync(charge, organizationId, userId, providerId); if (!transaction.PaymentMethodType.HasValue) { _logger.LogWarning("Charge success from unsupported source/method. {ChargeId}", charge.Id); diff --git a/src/Billing/Services/Implementations/InvoiceCreatedHandler.cs b/src/Billing/Services/Implementations/InvoiceCreatedHandler.cs index 5bb098bec5..101b0e26b9 100644 --- a/src/Billing/Services/Implementations/InvoiceCreatedHandler.cs +++ b/src/Billing/Services/Implementations/InvoiceCreatedHandler.cs @@ -1,4 +1,5 @@ -using Event = Stripe.Event; +using Bit.Core.Billing.Constants; +using Event = Stripe.Event; namespace Bit.Billing.Services.Implementations; @@ -35,13 +36,13 @@ public class InvoiceCreatedHandler( if (usingPayPal && invoice is { AmountDue: > 0, - Paid: false, + Status: not StripeConstants.InvoiceStatus.Paid, CollectionMethod: "charge_automatically", BillingReason: "subscription_create" or "subscription_cycle" or "automatic_pending_invoice_item_invoice", - SubscriptionId: not null and not "" + Parent.SubscriptionDetails: not null }) { await stripeEventUtilityService.AttemptToPayInvoiceAsync(invoice); diff --git a/src/Billing/Services/Implementations/PaymentFailedHandler.cs b/src/Billing/Services/Implementations/PaymentFailedHandler.cs index acf6ca70c7..0da6d03e94 100644 --- a/src/Billing/Services/Implementations/PaymentFailedHandler.cs +++ b/src/Billing/Services/Implementations/PaymentFailedHandler.cs @@ -1,4 +1,5 @@ -using Stripe; +using Bit.Core.Billing.Constants; +using Stripe; using Event = Stripe.Event; namespace Bit.Billing.Services.Implementations; @@ -26,17 +27,20 @@ public class PaymentFailedHandler : IPaymentFailedHandler public async Task HandleAsync(Event parsedEvent) { var invoice = await _stripeEventService.GetInvoice(parsedEvent, true); - if (invoice.Paid || invoice.AttemptCount <= 1 || !ShouldAttemptToPayInvoice(invoice)) + if (invoice.Status == StripeConstants.InvoiceStatus.Paid || invoice.AttemptCount <= 1 || !ShouldAttemptToPayInvoice(invoice)) { return; } - var subscription = await _stripeFacade.GetSubscription(invoice.SubscriptionId); - // attempt count 4 = 11 days after initial failure - if (invoice.AttemptCount <= 3 || - !subscription.Items.Any(i => i.Price.Id is IStripeEventUtilityService.PremiumPlanId or IStripeEventUtilityService.PremiumPlanIdAppStore)) + if (invoice.Parent?.SubscriptionDetails != null) { - await _stripeEventUtilityService.AttemptToPayInvoiceAsync(invoice); + var subscription = await _stripeFacade.GetSubscription(invoice.Parent.SubscriptionDetails.SubscriptionId); + // attempt count 4 = 11 days after initial failure + if (invoice.AttemptCount <= 3 || + !subscription.Items.Any(i => i.Price.Id is IStripeEventUtilityService.PremiumPlanId or IStripeEventUtilityService.PremiumPlanIdAppStore)) + { + await _stripeEventUtilityService.AttemptToPayInvoiceAsync(invoice); + } } } @@ -44,9 +48,9 @@ public class PaymentFailedHandler : IPaymentFailedHandler invoice is { AmountDue: > 0, - Paid: false, + Status: not StripeConstants.InvoiceStatus.Paid, CollectionMethod: "charge_automatically", BillingReason: "subscription_cycle" or "automatic_pending_invoice_item_invoice", - SubscriptionId: not null + Parent.SubscriptionDetails: not null }; } diff --git a/src/Billing/Services/Implementations/PaymentSucceededHandler.cs b/src/Billing/Services/Implementations/PaymentSucceededHandler.cs index a10fa4b3d6..443227f7bf 100644 --- a/src/Billing/Services/Implementations/PaymentSucceededHandler.cs +++ b/src/Billing/Services/Implementations/PaymentSucceededHandler.cs @@ -1,7 +1,9 @@ using Bit.Billing.Constants; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Pricing; using Bit.Core.Repositories; using Bit.Core.Services; @@ -29,12 +31,17 @@ public class PaymentSucceededHandler( public async Task HandleAsync(Event parsedEvent) { var invoice = await stripeEventService.GetInvoice(parsedEvent, true); - if (!invoice.Paid || invoice.BillingReason != "subscription_create") + if (invoice.Status != StripeConstants.InvoiceStatus.Paid || invoice.BillingReason != "subscription_create") { return; } - var subscription = await stripeFacade.GetSubscription(invoice.SubscriptionId); + if (invoice.Parent?.SubscriptionDetails == null) + { + return; + } + + var subscription = await stripeFacade.GetSubscription(invoice.Parent.SubscriptionDetails.SubscriptionId); if (subscription?.Status != StripeSubscriptionStatus.Active) { return; @@ -96,7 +103,7 @@ public class PaymentSucceededHandler( return; } - await organizationEnableCommand.EnableAsync(organizationId.Value, subscription.CurrentPeriodEnd); + await organizationEnableCommand.EnableAsync(organizationId.Value, subscription.GetCurrentPeriodEnd()); organization = await organizationRepository.GetByIdAsync(organization.Id); await pushNotificationAdapter.NotifyEnabledChangedAsync(organization!); } @@ -107,7 +114,7 @@ public class PaymentSucceededHandler( return; } - await userService.EnablePremiumAsync(userId.Value, subscription.CurrentPeriodEnd); + await userService.EnablePremiumAsync(userId.Value, subscription.GetCurrentPeriodEnd()); } } } diff --git a/src/Billing/Services/Implementations/ProviderEventService.cs b/src/Billing/Services/Implementations/ProviderEventService.cs index 12716c5aa2..79c85cb48f 100644 --- a/src/Billing/Services/Implementations/ProviderEventService.cs +++ b/src/Billing/Services/Implementations/ProviderEventService.cs @@ -28,9 +28,14 @@ public class ProviderEventService( return; } - var invoice = await stripeEventService.GetInvoice(parsedEvent); + var invoice = await stripeEventService.GetInvoice(parsedEvent, true, ["discounts"]); - var metadata = (await stripeFacade.GetSubscription(invoice.SubscriptionId)).Metadata ?? new Dictionary(); + if (invoice.Parent is not { Type: "subscription_details" }) + { + return; + } + + var metadata = (await stripeFacade.GetSubscription(invoice.Parent.SubscriptionDetails.SubscriptionId)).Metadata ?? new Dictionary(); var hasProviderId = metadata.TryGetValue("providerId", out var providerId); @@ -68,7 +73,9 @@ public class ProviderEventService( var plan = await pricingClient.GetPlanOrThrow(organization.PlanType); - var discountedPercentage = (100 - (invoice.Discount?.Coupon?.PercentOff ?? 0)) / 100; + var totalPercentOff = invoice.Discounts?.Sum(discount => discount?.Coupon?.PercentOff ?? 0) ?? 0; + + var discountedPercentage = (100 - totalPercentOff) / 100; var discountedSeatPrice = plan.PasswordManager.ProviderPortalSeatPrice * discountedPercentage; @@ -96,7 +103,9 @@ public class ProviderEventService( var unassignedSeats = providerPlan.SeatMinimum - clientSeats ?? 0; - var discountedPercentage = (100 - (invoice.Discount?.Coupon?.PercentOff ?? 0)) / 100; + var totalPercentOff = invoice.Discounts?.Sum(discount => discount?.Coupon?.PercentOff ?? 0) ?? 0; + + var discountedPercentage = (100 - totalPercentOff) / 100; var discountedSeatPrice = plan.PasswordManager.ProviderPortalSeatPrice * discountedPercentage; diff --git a/src/Billing/Services/Implementations/SetupIntentSucceededHandler.cs b/src/Billing/Services/Implementations/SetupIntentSucceededHandler.cs index bc3fa1bd56..89e40f0e43 100644 --- a/src/Billing/Services/Implementations/SetupIntentSucceededHandler.cs +++ b/src/Billing/Services/Implementations/SetupIntentSucceededHandler.cs @@ -2,8 +2,8 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Caches; +using Bit.Core.Billing.Services; using Bit.Core.Repositories; -using Bit.Core.Services; using OneOf; using Stripe; using Event = Stripe.Event; @@ -59,10 +59,10 @@ public class SetupIntentSucceededHandler( return; } - await stripeAdapter.PaymentMethodAttachAsync(paymentMethod.Id, + await stripeAdapter.AttachPaymentMethodAsync(paymentMethod.Id, new PaymentMethodAttachOptions { Customer = customerId }); - await stripeAdapter.CustomerUpdateAsync(customerId, new CustomerUpdateOptions + await stripeAdapter.UpdateCustomerAsync(customerId, new CustomerUpdateOptions { InvoiceSettings = new CustomerInvoiceSettingsOptions { diff --git a/src/Billing/Services/Implementations/StripeEventUtilityService.cs b/src/Billing/Services/Implementations/StripeEventUtilityService.cs index 4c96bf977d..53512427c0 100644 --- a/src/Billing/Services/Implementations/StripeEventUtilityService.cs +++ b/src/Billing/Services/Implementations/StripeEventUtilityService.cs @@ -2,12 +2,13 @@ #nullable disable using Bit.Billing.Constants; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Models; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; -using Bit.Core.Utilities; using Braintree; using Stripe; using Customer = Stripe.Customer; @@ -87,25 +88,6 @@ public class StripeEventUtilityService : IStripeEventUtilityService /// public async Task<(Guid?, Guid?, Guid?)> GetEntityIdsFromChargeAsync(Charge charge) { - Guid? organizationId = null; - Guid? userId = null; - Guid? providerId = null; - - if (charge.InvoiceId != null) - { - var invoice = await _stripeFacade.GetInvoice(charge.InvoiceId); - if (invoice?.SubscriptionId != null) - { - var subscription = await _stripeFacade.GetSubscription(invoice.SubscriptionId); - (organizationId, userId, providerId) = GetIdsFromMetadata(subscription?.Metadata); - } - } - - if (organizationId.HasValue || userId.HasValue || providerId.HasValue) - { - return (organizationId, userId, providerId); - } - var subscriptions = await _stripeFacade.ListSubscriptions(new SubscriptionListOptions { Customer = charge.CustomerId @@ -118,7 +100,7 @@ public class StripeEventUtilityService : IStripeEventUtilityService continue; } - (organizationId, userId, providerId) = GetIdsFromMetadata(subscription.Metadata); + var (organizationId, userId, providerId) = GetIdsFromMetadata(subscription.Metadata); if (organizationId.HasValue || userId.HasValue || providerId.HasValue) { @@ -130,7 +112,7 @@ public class StripeEventUtilityService : IStripeEventUtilityService } public bool IsSponsoredSubscription(Subscription subscription) => - StaticStore.SponsoredPlans + SponsoredPlans.All .Any(p => subscription.Items .Any(i => i.Plan.Id == p.StripePlanId)); @@ -142,7 +124,7 @@ public class StripeEventUtilityService : IStripeEventUtilityService /// /// /// /// - public Transaction FromChargeToTransaction(Charge charge, Guid? organizationId, Guid? userId, Guid? providerId) + public async Task FromChargeToTransactionAsync(Charge charge, Guid? organizationId, Guid? userId, Guid? providerId) { var transaction = new Transaction { @@ -227,6 +209,24 @@ public class StripeEventUtilityService : IStripeEventUtilityService transaction.PaymentMethodType = PaymentMethodType.BankAccount; transaction.Details = $"ACH => {achCreditTransfer.BankName}, {achCreditTransfer.AccountNumber}"; } + else if (charge.PaymentMethodDetails.CustomerBalance != null) + { + var bankTransferType = await GetFundingBankTransferTypeAsync(charge); + + if (!string.IsNullOrEmpty(bankTransferType)) + { + transaction.PaymentMethodType = PaymentMethodType.BankAccount; + transaction.Details = bankTransferType switch + { + "eu_bank_transfer" => "EU Bank Transfer", + "gb_bank_transfer" => "GB Bank Transfer", + "jp_bank_transfer" => "JP Bank Transfer", + "mx_bank_transfer" => "MX Bank Transfer", + "us_bank_transfer" => "US Bank Transfer", + _ => "Bank Transfer" + }; + } + } break; } @@ -256,10 +256,10 @@ public class StripeEventUtilityService : IStripeEventUtilityService invoice is { AmountDue: > 0, - Paid: false, + Status: not StripeConstants.InvoiceStatus.Paid, CollectionMethod: "charge_automatically", BillingReason: "subscription_cycle" or "automatic_pending_invoice_item_invoice", - SubscriptionId: not null + Parent.SubscriptionDetails: not null }; private async Task AttemptToPayInvoiceWithBraintreeAsync(Invoice invoice, Customer customer) @@ -272,7 +272,13 @@ public class StripeEventUtilityService : IStripeEventUtilityService return false; } - var subscription = await _stripeFacade.GetSubscription(invoice.SubscriptionId); + if (invoice.Parent?.SubscriptionDetails == null) + { + _logger.LogWarning("Invoice parent was not a subscription."); + return false; + } + + var subscription = await _stripeFacade.GetSubscription(invoice.Parent.SubscriptionDetails.SubscriptionId); var (organizationId, userId, providerId) = GetIdsFromMetadata(subscription?.Metadata); if (!organizationId.HasValue && !userId.HasValue && !providerId.HasValue) { @@ -301,20 +307,13 @@ public class StripeEventUtilityService : IStripeEventUtilityService } var btInvoiceAmount = Math.Round(invoice.AmountDue / 100M, 2); - var existingTransactions = organizationId.HasValue - ? await _transactionRepository.GetManyByOrganizationIdAsync(organizationId.Value) - : userId.HasValue - ? await _transactionRepository.GetManyByUserIdAsync(userId.Value) - : await _transactionRepository.GetManyByProviderIdAsync(providerId.Value); - - var duplicateTimeSpan = TimeSpan.FromHours(24); - var now = DateTime.UtcNow; - var duplicateTransaction = existingTransactions? - .FirstOrDefault(t => (now - t.CreationDate) < duplicateTimeSpan); - if (duplicateTransaction != null) + // Check if this invoice already has a Braintree transaction ID to prevent duplicate charges + if (invoice.Metadata?.ContainsKey("btTransactionId") ?? false) { - _logger.LogWarning("There is already a recent PayPal transaction ({0}). " + - "Do not charge again to prevent possible duplicate.", duplicateTransaction.GatewayId); + _logger.LogWarning("Invoice {InvoiceId} already has a Braintree transaction ({TransactionId}). " + + "Do not charge again to prevent duplicate.", + invoice.Id, + invoice.Metadata["btTransactionId"]); return false; } @@ -425,4 +424,55 @@ public class StripeEventUtilityService : IStripeEventUtilityService throw; } } + + /// + /// Retrieves the bank transfer type that funded a charge paid via customer balance. + /// + /// The charge to analyze. + /// + /// The bank transfer type (e.g., "us_bank_transfer", "eu_bank_transfer") if the charge was funded + /// by a bank transfer via customer balance, otherwise null. + /// + private async Task GetFundingBankTransferTypeAsync(Charge charge) + { + if (charge is not + { + CustomerId: not null, + PaymentIntentId: not null, + PaymentMethodDetails: { Type: "customer_balance" } + }) + { + return null; + } + + var cashBalanceTransactions = _stripeFacade.GetCustomerCashBalanceTransactions(charge.CustomerId); + + string bankTransferType = null; + var matchingPaymentIntentFound = false; + + await foreach (var cashBalanceTransaction in cashBalanceTransactions) + { + switch (cashBalanceTransaction) + { + case { Type: "funded", Funded: not null }: + { + bankTransferType = cashBalanceTransaction.Funded.BankTransfer.Type; + break; + } + case { Type: "applied_to_payment", AppliedToPayment: not null } + when cashBalanceTransaction.AppliedToPayment.PaymentIntentId == charge.PaymentIntentId: + { + matchingPaymentIntentFound = true; + break; + } + } + + if (matchingPaymentIntentFound && !string.IsNullOrEmpty(bankTransferType)) + { + return bankTransferType; + } + } + + return null; + } } diff --git a/src/Billing/Services/Implementations/StripeFacade.cs b/src/Billing/Services/Implementations/StripeFacade.cs index eef7ce009e..49cde981cd 100644 --- a/src/Billing/Services/Implementations/StripeFacade.cs +++ b/src/Billing/Services/Implementations/StripeFacade.cs @@ -11,6 +11,7 @@ public class StripeFacade : IStripeFacade { private readonly ChargeService _chargeService = new(); private readonly CustomerService _customerService = new(); + private readonly CustomerCashBalanceTransactionService _customerCashBalanceTransactionService = new(); private readonly EventService _eventService = new(); private readonly InvoiceService _invoiceService = new(); private readonly PaymentMethodService _paymentMethodService = new(); @@ -18,6 +19,7 @@ public class StripeFacade : IStripeFacade private readonly DiscountService _discountService = new(); private readonly SetupIntentService _setupIntentService = new(); private readonly TestClockService _testClockService = new(); + private readonly CouponService _couponService = new(); public async Task GetCharge( string chargeId, @@ -40,6 +42,13 @@ public class StripeFacade : IStripeFacade CancellationToken cancellationToken = default) => await _customerService.GetAsync(customerId, customerGetOptions, requestOptions, cancellationToken); + public IAsyncEnumerable GetCustomerCashBalanceTransactions( + string customerId, + CustomerCashBalanceTransactionListOptions customerCashBalanceTransactionListOptions = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default) + => _customerCashBalanceTransactionService.ListAutoPagingAsync(customerId, customerCashBalanceTransactionListOptions, requestOptions, cancellationToken); + public async Task UpdateCustomer( string customerId, CustomerUpdateOptions customerUpdateOptions = null, @@ -98,6 +107,12 @@ public class StripeFacade : IStripeFacade CancellationToken cancellationToken = default) => await _subscriptionService.ListAsync(options, requestOptions, cancellationToken); + public IAsyncEnumerable ListSubscriptionsAutoPagingAsync( + SubscriptionListOptions options = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default) => + _subscriptionService.ListAutoPagingAsync(options, requestOptions, cancellationToken); + public async Task GetSubscription( string subscriptionId, SubscriptionGetOptions subscriptionGetOptions = null, @@ -137,4 +152,11 @@ public class StripeFacade : IStripeFacade RequestOptions requestOptions = null, CancellationToken cancellationToken = default) => _testClockService.GetAsync(testClockId, testClockGetOptions, requestOptions, cancellationToken); + + public Task GetCoupon( + string couponId, + CouponGetOptions couponGetOptions = null, + RequestOptions requestOptions = null, + CancellationToken cancellationToken = default) => + _couponService.GetAsync(couponId, couponGetOptions, requestOptions, cancellationToken); } diff --git a/src/Billing/Services/Implementations/SubscriptionDeletedHandler.cs b/src/Billing/Services/Implementations/SubscriptionDeletedHandler.cs index 465da86c3f..c204cc5026 100644 --- a/src/Billing/Services/Implementations/SubscriptionDeletedHandler.cs +++ b/src/Billing/Services/Implementations/SubscriptionDeletedHandler.cs @@ -1,6 +1,11 @@ using Bit.Billing.Constants; +using Bit.Billing.Jobs; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Services; +using Bit.Core.Billing.Extensions; using Bit.Core.Services; +using Quartz; using Event = Stripe.Event; namespace Bit.Billing.Services.Implementations; @@ -10,17 +15,26 @@ public class SubscriptionDeletedHandler : ISubscriptionDeletedHandler private readonly IUserService _userService; private readonly IStripeEventUtilityService _stripeEventUtilityService; private readonly IOrganizationDisableCommand _organizationDisableCommand; + private readonly IProviderRepository _providerRepository; + private readonly IProviderService _providerService; + private readonly ISchedulerFactory _schedulerFactory; public SubscriptionDeletedHandler( IStripeEventService stripeEventService, IUserService userService, IStripeEventUtilityService stripeEventUtilityService, - IOrganizationDisableCommand organizationDisableCommand) + IOrganizationDisableCommand organizationDisableCommand, + IProviderRepository providerRepository, + IProviderService providerService, + ISchedulerFactory schedulerFactory) { _stripeEventService = stripeEventService; _userService = userService; _stripeEventUtilityService = stripeEventUtilityService; _organizationDisableCommand = organizationDisableCommand; + _providerRepository = providerRepository; + _providerService = providerService; + _schedulerFactory = schedulerFactory; } /// @@ -50,11 +64,40 @@ public class SubscriptionDeletedHandler : ISubscriptionDeletedHandler return; } - await _organizationDisableCommand.DisableAsync(organizationId.Value, subscription.CurrentPeriodEnd); + await _organizationDisableCommand.DisableAsync(organizationId.Value, subscription.GetCurrentPeriodEnd()); + } + else if (providerId.HasValue) + { + var provider = await _providerRepository.GetByIdAsync(providerId.Value); + if (provider != null) + { + provider.Enabled = false; + await _providerService.UpdateAsync(provider); + + await QueueProviderOrganizationDisableJobAsync(providerId.Value, subscription.GetCurrentPeriodEnd()); + } } else if (userId.HasValue) { - await _userService.DisablePremiumAsync(userId.Value, subscription.CurrentPeriodEnd); + await _userService.DisablePremiumAsync(userId.Value, subscription.GetCurrentPeriodEnd()); } } + + private async Task QueueProviderOrganizationDisableJobAsync(Guid providerId, DateTime? expirationDate) + { + var scheduler = await _schedulerFactory.GetScheduler(); + + var job = JobBuilder.Create() + .WithIdentity($"disable-provider-orgs-{providerId}", "provider-management") + .UsingJobData("providerId", providerId.ToString()) + .UsingJobData("expirationDate", expirationDate?.ToString("O")) + .Build(); + + var trigger = TriggerBuilder.Create() + .WithIdentity($"disable-trigger-{providerId}", "provider-management") + .StartNow() + .Build(); + + await scheduler.ScheduleJob(job, trigger); + } } diff --git a/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs b/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs index 10630f78f4..c10368d8c0 100644 --- a/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs +++ b/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs @@ -1,10 +1,10 @@ -using System.Globalization; -using Bit.Billing.Constants; +using Bit.Billing.Constants; using Bit.Billing.Jobs; -using Bit.Core; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Pricing; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; @@ -82,12 +82,14 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler var subscription = await _stripeEventService.GetSubscription(parsedEvent, true, ["customer", "discounts", "latest_invoice", "test_clock"]); var (organizationId, userId, providerId) = _stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata); + var currentPeriodEnd = subscription.GetCurrentPeriodEnd(); + switch (subscription.Status) { case StripeSubscriptionStatus.Unpaid or StripeSubscriptionStatus.IncompleteExpired when organizationId.HasValue: { - await _organizationDisableCommand.DisableAsync(organizationId.Value, subscription.CurrentPeriodEnd); + await _organizationDisableCommand.DisableAsync(organizationId.Value, currentPeriodEnd); if (subscription.Status == StripeSubscriptionStatus.Unpaid && subscription.LatestInvoice is { BillingReason: "subscription_cycle" or "subscription_create" }) { @@ -107,14 +109,27 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler break; } - if (subscription.Status is StripeSubscriptionStatus.Unpaid && - subscription.Items.Any(i => i.Price.Id is IStripeEventUtilityService.PremiumPlanId or IStripeEventUtilityService.PremiumPlanIdAppStore)) + if (await IsPremiumSubscriptionAsync(subscription)) { await CancelSubscription(subscription.Id); await VoidOpenInvoices(subscription.Id); } - await _userService.DisablePremiumAsync(userId.Value, subscription.CurrentPeriodEnd); + await _userService.DisablePremiumAsync(userId.Value, currentPeriodEnd); + + break; + } + case StripeSubscriptionStatus.Incomplete when userId.HasValue: + { + // Handle Incomplete subscriptions for Premium users that have open invoices from failed payments + // This prevents duplicate subscriptions when users retry the subscription flow + if (await IsPremiumSubscriptionAsync(subscription) && + subscription.LatestInvoice is { Status: StripeInvoiceStatus.Open }) + { + await CancelSubscription(subscription.Id); + await VoidOpenInvoices(subscription.Id); + await _userService.DisablePremiumAsync(userId.Value, currentPeriodEnd); + } break; } @@ -130,11 +145,6 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler } case StripeSubscriptionStatus.Active when providerId.HasValue: { - var providerPortalTakeover = _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover); - if (!providerPortalTakeover) - { - break; - } var provider = await _providerRepository.GetByIdAsync(providerId.Value); if (provider != null) { @@ -154,7 +164,7 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler { if (userId.HasValue) { - await _userService.EnablePremiumAsync(userId.Value, subscription.CurrentPeriodEnd); + await _userService.EnablePremiumAsync(userId.Value, currentPeriodEnd); } break; } @@ -162,17 +172,17 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler if (organizationId.HasValue) { - await _organizationService.UpdateExpirationDateAsync(organizationId.Value, subscription.CurrentPeriodEnd); - if (_stripeEventUtilityService.IsSponsoredSubscription(subscription)) + await _organizationService.UpdateExpirationDateAsync(organizationId.Value, currentPeriodEnd); + if (_stripeEventUtilityService.IsSponsoredSubscription(subscription) && currentPeriodEnd.HasValue) { - await _organizationSponsorshipRenewCommand.UpdateExpirationDateAsync(organizationId.Value, subscription.CurrentPeriodEnd); + await _organizationSponsorshipRenewCommand.UpdateExpirationDateAsync(organizationId.Value, currentPeriodEnd.Value); } await RemovePasswordManagerCouponIfRemovingSecretsManagerTrialAsync(parsedEvent, subscription); } else if (userId.HasValue) { - await _userService.UpdatePremiumExpirationAsync(userId.Value, subscription.CurrentPeriodEnd); + await _userService.UpdatePremiumExpirationAsync(userId.Value, currentPeriodEnd); } } @@ -193,6 +203,13 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler } } + private async Task IsPremiumSubscriptionAsync(Subscription subscription) + { + var premiumPlans = await _pricingClient.ListPremiumPlans(); + var premiumPriceIds = premiumPlans.SelectMany(p => new[] { p.Seat.StripePriceId, p.Storage.StripePriceId }).ToHashSet(); + return subscription.Items.Any(i => premiumPriceIds.Contains(i.Price.Id)); + } + /// /// Checks if the provider subscription status has changed from a non-active to an active status type /// If the previous status is already active(active,past-due,trialing),canceled,or null, then this will return false. @@ -280,9 +297,8 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler ?.Coupon ?.Id == "sm-standalone"; - var subscriptionHasSecretsManagerTrial = subscription.Discount - ?.Coupon - ?.Id == "sm-standalone"; + var subscriptionHasSecretsManagerTrial = subscription.Discounts.Select(discount => discount.Coupon.Id) + .Contains(StripeConstants.CouponIDs.SecretsManagerStandalone); if (customerHasSecretsManagerTrial) { @@ -318,13 +334,6 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler Event parsedEvent, Subscription currentSubscription) { - var providerPortalTakeover = _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover); - - if (!providerPortalTakeover) - { - return; - } - var provider = await _providerRepository.GetByIdAsync(providerId); if (provider == null) { @@ -340,22 +349,17 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler { var previousSubscription = parsedEvent.Data.PreviousAttributes.ToObject() as Subscription; - var updateIsSubscriptionGoingUnpaid = previousSubscription is - { - Status: + if (previousSubscription is + { + Status: StripeSubscriptionStatus.Trialing or StripeSubscriptionStatus.Active or StripeSubscriptionStatus.PastDue - } && currentSubscription is - { - Status: StripeSubscriptionStatus.Unpaid, - LatestInvoice.BillingReason: "subscription_cycle" or "subscription_create" - }; - - var updateIsManualSuspensionViaMetadata = CheckForManualSuspensionViaMetadata( - previousSubscription, currentSubscription); - - if (updateIsSubscriptionGoingUnpaid || updateIsManualSuspensionViaMetadata) + } && currentSubscription is + { + Status: StripeSubscriptionStatus.Unpaid, + LatestInvoice.BillingReason: "subscription_cycle" or "subscription_create" + }) { if (currentSubscription.TestClock != null) { @@ -366,14 +370,6 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler var subscriptionUpdateOptions = new SubscriptionUpdateOptions { CancelAt = now.AddDays(7) }; - if (updateIsManualSuspensionViaMetadata) - { - subscriptionUpdateOptions.Metadata = new Dictionary - { - ["suspended_provider_via_webhook_at"] = DateTime.UtcNow.ToString(CultureInfo.InvariantCulture) - }; - } - await _stripeFacade.UpdateSubscription(currentSubscription.Id, subscriptionUpdateOptions); } } @@ -396,37 +392,4 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler } } } - - private static bool CheckForManualSuspensionViaMetadata( - Subscription? previousSubscription, - Subscription currentSubscription) - { - /* - * When metadata on a subscription is updated, we'll receive an event that has: - * Previous Metadata: { newlyAddedKey: null } - * Current Metadata: { newlyAddedKey: newlyAddedValue } - * - * As such, our check for a manual suspension must ensure that the 'previous_attributes' does contain the - * 'metadata' property, but also that the "suspend_provider" key in that metadata is set to null. - * - * If we don't do this and instead do a null coalescing check on 'previous_attributes?.metadata?.TryGetValue', - * we'll end up marking an event where 'previous_attributes.metadata' = null (which could be any subscription update - * that does not update the metadata) the same as a manual suspension. - */ - const string key = "suspend_provider"; - - if (previousSubscription is not { Metadata: not null } || - !previousSubscription.Metadata.TryGetValue(key, out var previousValue)) - { - return false; - } - - if (previousValue == null) - { - return !string.IsNullOrEmpty( - currentSubscription.Metadata.TryGetValue(key, out var currentValue) ? currentValue : null); - } - - return false; - } } diff --git a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs index e5675f7c0a..004828dc48 100644 --- a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs +++ b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs @@ -1,7 +1,5 @@ -// FIXME: Update this file to be null safe and then delete the line below - -#nullable disable - +using System.Globalization; +using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Repositories; @@ -10,14 +8,23 @@ using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Pricing; +using Bit.Core.Entities; +using Bit.Core.Models.Mail.Billing.Renewal.Families2019Renewal; +using Bit.Core.Models.Mail.Billing.Renewal.Families2020Renewal; +using Bit.Core.Models.Mail.Billing.Renewal.Premium; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; +using Bit.Core.Platform.Mail.Mailer; using Bit.Core.Repositories; using Bit.Core.Services; using Stripe; using Event = Stripe.Event; +using Plan = Bit.Core.Models.StaticStore.Plan; +using PremiumPlan = Bit.Core.Billing.Pricing.Premium.Plan; namespace Bit.Billing.Services.Implementations; +using static StripeConstants; + public class UpcomingInvoiceHandler( IGetPaymentMethodQuery getPaymentMethodQuery, ILogger logger, @@ -29,138 +36,460 @@ public class UpcomingInvoiceHandler( IStripeEventService stripeEventService, IStripeEventUtilityService stripeEventUtilityService, IUserRepository userRepository, - IValidateSponsorshipCommand validateSponsorshipCommand) + IValidateSponsorshipCommand validateSponsorshipCommand, + IMailer mailer, + IFeatureService featureService) : IUpcomingInvoiceHandler { public async Task HandleAsync(Event parsedEvent) { var invoice = await stripeEventService.GetInvoice(parsedEvent); - if (string.IsNullOrEmpty(invoice.SubscriptionId)) + var customer = + await stripeFacade.GetCustomer(invoice.CustomerId, + new CustomerGetOptions { Expand = ["subscriptions", "tax", "tax_ids"] }); + + var subscription = customer.Subscriptions.FirstOrDefault(); + + if (subscription == null) { - logger.LogInformation("Received 'invoice.upcoming' Event with ID '{eventId}' that did not include a Subscription ID", parsedEvent.Id); return; } - var subscription = await stripeFacade.GetSubscription(invoice.SubscriptionId, new SubscriptionGetOptions - { - Expand = ["customer.tax", "customer.tax_ids"] - }); - var (organizationId, userId, providerId) = stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata); if (organizationId.HasValue) { - var organization = await organizationRepository.GetByIdAsync(organizationId.Value); - - if (organization == null) - { - return; - } - - await AlignOrganizationTaxConcernsAsync(organization, subscription, parsedEvent.Id); - - var plan = await pricingClient.GetPlanOrThrow(organization.PlanType); - - if (!plan.IsAnnual) - { - return; - } - - if (stripeEventUtilityService.IsSponsoredSubscription(subscription)) - { - var sponsorshipIsValid = await validateSponsorshipCommand.ValidateSponsorshipAsync(organizationId.Value); - - if (!sponsorshipIsValid) - { - /* - * If the sponsorship is invalid, then the subscription was updated to use the regular families plan - * price. Given that this is the case, we need the new invoice amount - */ - invoice = await stripeFacade.GetInvoice(subscription.LatestInvoiceId); - } - } - - await SendUpcomingInvoiceEmailsAsync(new List { organization.BillingEmail }, invoice); - - /* - * TODO: https://bitwarden.atlassian.net/browse/PM-4862 - * Disabling this as part of a hot fix. It needs to check whether the organization - * belongs to a Reseller provider and only send an email to the organization owners if it does. - * It also requires a new email template as the current one contains too much billing information. - */ - - // var ownerEmails = await _organizationRepository.GetOwnerEmailAddressesById(organization.Id); - - // await SendEmails(ownerEmails); + await HandleOrganizationUpcomingInvoiceAsync( + organizationId.Value, + parsedEvent, + invoice, + customer, + subscription); } else if (userId.HasValue) { - var user = await userRepository.GetByIdAsync(userId.Value); - - if (user == null) - { - return; - } - - if (!subscription.AutomaticTax.Enabled && subscription.Customer.HasRecognizedTaxLocation()) - { - try - { - await stripeFacade.UpdateSubscription(subscription.Id, - new SubscriptionUpdateOptions - { - AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } - }); - } - catch (Exception exception) - { - logger.LogError( - exception, - "Failed to set user's ({UserID}) subscription to automatic tax while processing event with ID {EventID}", - user.Id, - parsedEvent.Id); - } - } - - if (user.Premium) - { - await SendUpcomingInvoiceEmailsAsync(new List { user.Email }, invoice); - } + await HandlePremiumUsersUpcomingInvoiceAsync( + userId.Value, + parsedEvent, + invoice, + customer, + subscription); } else if (providerId.HasValue) { - var provider = await providerRepository.GetByIdAsync(providerId.Value); + await HandleProviderUpcomingInvoiceAsync( + providerId.Value, + parsedEvent, + invoice, + customer, + subscription); + } + } - if (provider == null) + #region Organizations + + private async Task HandleOrganizationUpcomingInvoiceAsync( + Guid organizationId, + Event @event, + Invoice invoice, + Customer customer, + Subscription subscription) + { + var organization = await organizationRepository.GetByIdAsync(organizationId); + + if (organization == null) + { + logger.LogWarning("Could not find Organization ({OrganizationID}) for '{EventType}' event ({EventID})", + organizationId, @event.Type, @event.Id); + return; + } + + await AlignOrganizationTaxConcernsAsync(organization, subscription, customer, @event.Id); + + var plan = await pricingClient.GetPlanOrThrow(organization.PlanType); + + var milestone3 = featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3); + + var subscriptionAligned = await AlignOrganizationSubscriptionConcernsAsync( + organization, + @event, + subscription, + plan, + milestone3); + + /* + * Subscription alignment sends out a different version of our Upcoming Invoice email, so we don't need to continue + * with processing. + */ + if (subscriptionAligned) + { + return; + } + + // Don't send the upcoming invoice email unless the organization's on an annual plan. + if (!plan.IsAnnual) + { + return; + } + + if (stripeEventUtilityService.IsSponsoredSubscription(subscription)) + { + var sponsorshipIsValid = + await validateSponsorshipCommand.ValidateSponsorshipAsync(organizationId); + + if (!sponsorshipIsValid) + { + /* + * If the sponsorship is invalid, then the subscription was updated to use the regular families plan + * price. Given that this is the case, we need the new invoice amount + */ + invoice = await stripeFacade.GetInvoice(subscription.LatestInvoiceId); + } + } + + await SendUpcomingInvoiceEmailsAsync([organization.BillingEmail], invoice); + } + + private async Task AlignOrganizationTaxConcernsAsync( + Organization organization, + Subscription subscription, + Customer customer, + string eventId) + { + var nonUSBusinessUse = + organization.PlanType.GetProductTier() != ProductTierType.Families && + customer.Address.Country != Core.Constants.CountryAbbreviations.UnitedStates; + + if (nonUSBusinessUse && customer.TaxExempt != TaxExempt.Reverse) + { + try + { + await stripeFacade.UpdateCustomer(subscription.CustomerId, + new CustomerUpdateOptions { TaxExempt = TaxExempt.Reverse }); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to set organization's ({OrganizationID}) to reverse tax exemption while processing event with ID {EventID}", + organization.Id, + eventId); + } + } + + if (!subscription.AutomaticTax.Enabled) + { + try + { + await stripeFacade.UpdateSubscription(subscription.Id, + new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } + }); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to set organization's ({OrganizationID}) subscription to automatic tax while processing event with ID {EventID}", + organization.Id, + eventId); + } + } + } + + /// + /// Aligns the organization's subscription details with the specified plan and milestone requirements. + /// + /// The organization whose subscription is being updated. + /// The Stripe event associated with this operation. + /// The organization's subscription. + /// The organization's current plan. + /// A flag indicating whether the third milestone is enabled. + /// Whether the operation resulted in an updated subscription. + private async Task AlignOrganizationSubscriptionConcernsAsync( + Organization organization, + Event @event, + Subscription subscription, + Plan plan, + bool milestone3) + { + // currently these are the only plans that need aligned and both require the same flag and share most of the logic + if (!milestone3 || plan.Type is not (PlanType.FamiliesAnnually2019 or PlanType.FamiliesAnnually2025)) + { + return false; + } + + var passwordManagerItem = + subscription.Items.FirstOrDefault(item => item.Price.Id == plan.PasswordManager.StripePlanId); + + if (passwordManagerItem == null) + { + logger.LogWarning("Could not find Organization's ({OrganizationId}) password manager item while processing '{EventType}' event ({EventID})", + organization.Id, @event.Type, @event.Id); + return false; + } + + var familiesPlan = await pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually); + + organization.PlanType = familiesPlan.Type; + organization.Plan = familiesPlan.Name; + organization.UsersGetPremium = familiesPlan.UsersGetPremium; + organization.Seats = familiesPlan.PasswordManager.BaseSeats; + + var options = new SubscriptionUpdateOptions + { + Items = + [ + new SubscriptionItemOptions + { + Id = passwordManagerItem.Id, + Price = familiesPlan.PasswordManager.StripePlanId + } + ], + ProrationBehavior = ProrationBehavior.None + }; + + if (plan.Type == PlanType.FamiliesAnnually2019) + { + options.Discounts = + [ + new SubscriptionDiscountOptions { Coupon = CouponIDs.Milestone3SubscriptionDiscount } + ]; + + var premiumAccessAddOnItem = subscription.Items.FirstOrDefault(item => + item.Price.Id == plan.PasswordManager.StripePremiumAccessPlanId); + + if (premiumAccessAddOnItem != null) + { + options.Items.Add(new SubscriptionItemOptions + { + Id = premiumAccessAddOnItem.Id, + Deleted = true + }); + } + + var seatAddOnItem = subscription.Items.FirstOrDefault(item => item.Price.Id == "personal-org-seat-annually"); + + if (seatAddOnItem != null) + { + options.Items.Add(new SubscriptionItemOptions + { + Id = seatAddOnItem.Id, + Deleted = true + }); + } + } + + try + { + await organizationRepository.ReplaceAsync(organization); + await stripeFacade.UpdateSubscription(subscription.Id, options); + await SendFamiliesRenewalEmailAsync(organization, familiesPlan, plan); + return true; + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to align subscription concerns for Organization ({OrganizationID}) while processing '{EventType}' event ({EventID})", + organization.Id, + @event.Type, + @event.Id); + return false; + } + } + + #endregion + + #region Premium Users + + private async Task HandlePremiumUsersUpcomingInvoiceAsync( + Guid userId, + Event @event, + Invoice invoice, + Customer customer, + Subscription subscription) + { + var user = await userRepository.GetByIdAsync(userId); + + if (user == null) + { + logger.LogWarning("Could not find User ({UserID}) for '{EventType}' event ({EventID})", + userId, @event.Type, @event.Id); + return; + } + + await AlignPremiumUsersTaxConcernsAsync(user, @event, customer, subscription); + + var milestone2Feature = featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2); + if (milestone2Feature) + { + var subscriptionAligned = await AlignPremiumUsersSubscriptionConcernsAsync(user, @event, subscription); + + /* + * Subscription alignment sends out a different version of our Upcoming Invoice email, so we don't need to continue + * with processing. + */ + if (subscriptionAligned) { return; } - - await AlignProviderTaxConcernsAsync(provider, subscription, parsedEvent.Id); - - await SendProviderUpcomingInvoiceEmailsAsync(new List { provider.BillingEmail }, invoice, subscription, providerId.Value); } - } - private async Task SendUpcomingInvoiceEmailsAsync(IEnumerable emails, Invoice invoice) - { - var validEmails = emails.Where(e => !string.IsNullOrEmpty(e)); - - var items = invoice.Lines.Select(i => i.Description).ToList(); - - if (invoice.NextPaymentAttempt.HasValue && invoice.AmountDue > 0) + if (user.Premium) { - await mailService.SendInvoiceUpcoming( - validEmails, - invoice.AmountDue / 100M, - invoice.NextPaymentAttempt.Value, - items, - true); + await SendUpcomingInvoiceEmailsAsync(new List { user.Email }, invoice); } } - private async Task SendProviderUpcomingInvoiceEmailsAsync(IEnumerable emails, Invoice invoice, Subscription subscription, Guid providerId) + private async Task AlignPremiumUsersTaxConcernsAsync( + User user, + Event @event, + Customer customer, + Subscription subscription) + { + if (!subscription.AutomaticTax.Enabled && customer.HasRecognizedTaxLocation()) + { + try + { + await stripeFacade.UpdateSubscription(subscription.Id, + new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } + }); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to set user's ({UserID}) subscription to automatic tax while processing event with ID {EventID}", + user.Id, + @event.Id); + } + } + } + + private async Task AlignPremiumUsersSubscriptionConcernsAsync( + User user, + Event @event, + Subscription subscription) + { + var premiumItem = subscription.Items.FirstOrDefault(i => i.Price.Id == Prices.PremiumAnnually); + + if (premiumItem == null) + { + logger.LogWarning("Could not find User's ({UserID}) premium subscription item while processing '{EventType}' event ({EventID})", + user.Id, @event.Type, @event.Id); + return false; + } + + try + { + var plan = await pricingClient.GetAvailablePremiumPlan(); + await stripeFacade.UpdateSubscription(subscription.Id, + new SubscriptionUpdateOptions + { + Items = + [ + new SubscriptionItemOptions { Id = premiumItem.Id, Price = plan.Seat.StripePriceId } + ], + Discounts = + [ + new SubscriptionDiscountOptions { Coupon = CouponIDs.Milestone2SubscriptionDiscount } + ], + ProrationBehavior = ProrationBehavior.None + }); + await SendPremiumRenewalEmailAsync(user, plan); + return true; + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to update user's ({UserID}) subscription price id while processing event with ID {EventID}", + user.Id, + @event.Id); + return false; + } + } + + #endregion + + #region Providers + + private async Task HandleProviderUpcomingInvoiceAsync( + Guid providerId, + Event @event, + Invoice invoice, + Customer customer, + Subscription subscription) + { + var provider = await providerRepository.GetByIdAsync(providerId); + + if (provider == null) + { + logger.LogWarning("Could not find Provider ({ProviderID}) for '{EventType}' event ({EventID})", + providerId, @event.Type, @event.Id); + return; + } + + await AlignProviderTaxConcernsAsync(provider, subscription, customer, @event.Id); + + if (!string.IsNullOrEmpty(provider.BillingEmail)) + { + await SendProviderUpcomingInvoiceEmailsAsync(new List { provider.BillingEmail }, invoice, subscription, providerId); + } + } + + private async Task AlignProviderTaxConcernsAsync( + Provider provider, + Subscription subscription, + Customer customer, + string eventId) + { + if (customer.Address.Country != Core.Constants.CountryAbbreviations.UnitedStates && + customer.TaxExempt != TaxExempt.Reverse) + { + try + { + await stripeFacade.UpdateCustomer(subscription.CustomerId, + new CustomerUpdateOptions { TaxExempt = TaxExempt.Reverse }); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to set provider's ({ProviderID}) to reverse tax exemption while processing event with ID {EventID}", + provider.Id, + eventId); + } + } + + if (!subscription.AutomaticTax.Enabled) + { + try + { + await stripeFacade.UpdateSubscription(subscription.Id, + new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } + }); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to set provider's ({ProviderID}) subscription to automatic tax while processing event with ID {EventID}", + provider.Id, + eventId); + } + } + } + + private async Task SendProviderUpcomingInvoiceEmailsAsync(IEnumerable emails, Invoice invoice, + Subscription subscription, Guid providerId) { var validEmails = emails.Where(e => !string.IsNullOrEmpty(e)); @@ -196,94 +525,114 @@ public class UpcomingInvoiceHandler( } } - private async Task AlignOrganizationTaxConcernsAsync( + #endregion + + #region Shared + + private async Task SendUpcomingInvoiceEmailsAsync(IEnumerable emails, Invoice invoice) + { + var validEmails = emails.Where(e => !string.IsNullOrEmpty(e)); + + var items = invoice.Lines.Select(i => i.Description).ToList(); + + if (invoice is { NextPaymentAttempt: not null, AmountDue: > 0 }) + { + await mailService.SendInvoiceUpcoming( + validEmails, + invoice.AmountDue / 100M, + invoice.NextPaymentAttempt.Value, + items, + true); + } + } + + private async Task SendFamiliesRenewalEmailAsync( Organization organization, - Subscription subscription, - string eventId) + Plan familiesPlan, + Plan planBeforeAlignment) { - var nonUSBusinessUse = - organization.PlanType.GetProductTier() != ProductTierType.Families && - subscription.Customer.Address.Country != Core.Constants.CountryAbbreviations.UnitedStates; - - if (nonUSBusinessUse && subscription.Customer.TaxExempt != StripeConstants.TaxExempt.Reverse) + await (planBeforeAlignment switch { - try - { - await stripeFacade.UpdateCustomer(subscription.CustomerId, - new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse }); - } - catch (Exception exception) - { - logger.LogError( - exception, - "Failed to set organization's ({OrganizationID}) to reverse tax exemption while processing event with ID {EventID}", - organization.Id, - eventId); - } - } - - if (!subscription.AutomaticTax.Enabled) - { - try - { - await stripeFacade.UpdateSubscription(subscription.Id, - new SubscriptionUpdateOptions - { - AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } - }); - } - catch (Exception exception) - { - logger.LogError( - exception, - "Failed to set organization's ({OrganizationID}) subscription to automatic tax while processing event with ID {EventID}", - organization.Id, - eventId); - } - } + { Type: PlanType.FamiliesAnnually2025 } => SendFamilies2020RenewalEmailAsync(organization, familiesPlan), + { Type: PlanType.FamiliesAnnually2019 } => SendFamilies2019RenewalEmailAsync(organization, familiesPlan), + _ => throw new InvalidOperationException("Unsupported families plan in SendFamiliesRenewalEmailAsync().") + }); } - private async Task AlignProviderTaxConcernsAsync( - Provider provider, - Subscription subscription, - string eventId) + private async Task SendFamilies2020RenewalEmailAsync(Organization organization, Plan familiesPlan) { - if (subscription.Customer.Address.Country != Core.Constants.CountryAbbreviations.UnitedStates && - subscription.Customer.TaxExempt != StripeConstants.TaxExempt.Reverse) + var email = new Families2020RenewalMail { - try + ToEmails = [organization.BillingEmail], + View = new Families2020RenewalMailView { - await stripeFacade.UpdateCustomer(subscription.CustomerId, - new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse }); - } - catch (Exception exception) - { - logger.LogError( - exception, - "Failed to set provider's ({ProviderID}) to reverse tax exemption while processing event with ID {EventID}", - provider.Id, - eventId); + MonthlyRenewalPrice = (familiesPlan.PasswordManager.BasePrice / 12).ToString("C", new CultureInfo("en-US")) } + }; + + await mailer.SendEmail(email); + } + + private async Task SendFamilies2019RenewalEmailAsync(Organization organization, Plan familiesPlan) + { + var coupon = await stripeFacade.GetCoupon(CouponIDs.Milestone3SubscriptionDiscount); + if (coupon == null) + { + throw new InvalidOperationException($"Coupon for sending families 2019 email id:{CouponIDs.Milestone3SubscriptionDiscount} not found"); } - if (!subscription.AutomaticTax.Enabled) + if (coupon.PercentOff == null) { - try - { - await stripeFacade.UpdateSubscription(subscription.Id, - new SubscriptionUpdateOptions - { - AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } - }); - } - catch (Exception exception) - { - logger.LogError( - exception, - "Failed to set provider's ({ProviderID}) subscription to automatic tax while processing event with ID {EventID}", - provider.Id, - eventId); - } + throw new InvalidOperationException($"coupon.PercentOff for sending families 2019 email id:{CouponIDs.Milestone3SubscriptionDiscount} is null"); } + + var discountedAnnualRenewalPrice = familiesPlan.PasswordManager.BasePrice * (100 - coupon.PercentOff.Value) / 100; + + var email = new Families2019RenewalMail + { + ToEmails = [organization.BillingEmail], + View = new Families2019RenewalMailView + { + BaseMonthlyRenewalPrice = (familiesPlan.PasswordManager.BasePrice / 12).ToString("C", new CultureInfo("en-US")), + BaseAnnualRenewalPrice = familiesPlan.PasswordManager.BasePrice.ToString("C", new CultureInfo("en-US")), + DiscountAmount = $"{coupon.PercentOff}%", + DiscountedAnnualRenewalPrice = discountedAnnualRenewalPrice.ToString("C", new CultureInfo("en-US")) + } + }; + + await mailer.SendEmail(email); } + + private async Task SendPremiumRenewalEmailAsync( + User user, + PremiumPlan premiumPlan) + { + var coupon = await stripeFacade.GetCoupon(CouponIDs.Milestone2SubscriptionDiscount); + if (coupon == null) + { + throw new InvalidOperationException($"Coupon for sending premium renewal email id:{CouponIDs.Milestone2SubscriptionDiscount} not found"); + } + + if (coupon.PercentOff == null) + { + throw new InvalidOperationException($"coupon.PercentOff for sending premium renewal email id:{CouponIDs.Milestone2SubscriptionDiscount} is null"); + } + + var discountedAnnualRenewalPrice = premiumPlan.Seat.Price * (100 - coupon.PercentOff.Value) / 100; + + var email = new PremiumRenewalMail + { + ToEmails = [user.Email], + View = new PremiumRenewalMailView + { + BaseMonthlyRenewalPrice = (premiumPlan.Seat.Price / 12).ToString("C", new CultureInfo("en-US")), + DiscountAmount = $"{coupon.PercentOff}%", + DiscountedMonthlyRenewalPrice = (discountedAnnualRenewalPrice / 12).ToString("C", new CultureInfo("en-US")) + } + }; + + await mailer.SendEmail(email); + } + + #endregion } diff --git a/src/Billing/Startup.cs b/src/Billing/Startup.cs index 5b464d5ef6..30f4f5f562 100644 --- a/src/Billing/Startup.cs +++ b/src/Billing/Startup.cs @@ -2,7 +2,6 @@ #nullable disable using System.Globalization; -using System.Net.Http.Headers; using Bit.Billing.Services; using Bit.Billing.Services.Implementations; using Bit.Commercial.Core.Utilities; @@ -10,7 +9,6 @@ using Bit.Core.Billing.Extensions; using Bit.Core.Context; using Bit.Core.SecretsManager.Repositories; using Bit.Core.SecretsManager.Repositories.Noop; -using Bit.Core.Settings; using Bit.Core.Utilities; using Bit.SharedWeb.Utilities; using Microsoft.Extensions.DependencyInjection.Extensions; @@ -51,9 +49,6 @@ public class Startup // Repositories services.AddDatabaseRepositories(globalSettings); - // BitPay Client - services.AddSingleton(); - // PayPal IPN Client services.AddHttpClient(); @@ -102,13 +97,6 @@ public class Startup // Authentication services.AddAuthentication(); - // Set up HttpClients - services.AddHttpClient("FreshdeskApi"); - services.AddHttpClient("OnyxApi", client => - { - client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", billingSettings.Onyx.ApiKey); - }); - services.AddScoped(); services.AddScoped(); services.AddScoped(); @@ -132,12 +120,8 @@ public class Startup public void Configure( IApplicationBuilder app, - IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, - GlobalSettings globalSettings) + IWebHostEnvironment env) { - app.UseSerilog(env, appLifetime, globalSettings); - // Add general security headers app.UseMiddleware(); diff --git a/src/Billing/appsettings.Development.json b/src/Billing/appsettings.Development.json index 7c4889c22f..77057fde7f 100644 --- a/src/Billing/appsettings.Development.json +++ b/src/Billing/appsettings.Development.json @@ -32,9 +32,5 @@ "connectionString": "UseDevelopmentStorage=true" } }, - "billingSettings": { - "onyx": { - "personaId": 68 - } - } + "pricingUri": "https://billingpricing.qa.bitwarden.pw" } diff --git a/src/Billing/appsettings.Production.json b/src/Billing/appsettings.Production.json index 4be5d51a52..819986181f 100644 --- a/src/Billing/appsettings.Production.json +++ b/src/Billing/appsettings.Production.json @@ -26,10 +26,7 @@ "payPal": { "production": true, "businessId": "4ZDA7DLUUJGMN" - }, - "onyx": { - "personaId": 7 - } + } }, "Logging": { "IncludeScopes": false, diff --git a/src/Billing/appsettings.json b/src/Billing/appsettings.json index 0074b5aafe..7093b6a923 100644 --- a/src/Billing/appsettings.json +++ b/src/Billing/appsettings.json @@ -30,9 +30,6 @@ "connectionString": "SECRET", "applicationCacheTopicName": "SECRET" }, - "sentry": { - "dsn": "SECRET" - }, "notificationHub": { "connectionString": "SECRET", "hubName": "SECRET" @@ -57,30 +54,13 @@ "billingSettings": { "jobsKey": "SECRET", "stripeWebhookKey": "SECRET", - "stripeWebhookSecret": "SECRET", - "stripeWebhookSecret20231016": "SECRET", - "stripeWebhookSecret20240620": "SECRET", + "stripeWebhookSecret20250827Basil": "SECRET", "bitPayWebhookKey": "SECRET", "appleWebhookKey": "SECRET", "payPal": { "production": false, "businessId": "AD3LAUZSNVPJY", "webhookKey": "SECRET" - }, - "freshdesk": { - "apiKey": "SECRET", - "webhookKey": "SECRET", - "region": "US", - "userFieldName": "cf_user", - "orgFieldName": "cf_org", - "removeNewlinesInReplies": true, - "autoReplyGreeting": "Greetings,

    Thank you for contacting Bitwarden. The reply below was generated by our AI agent based on your message:

    ", - "autoReplySalutation": "

    If this response doesn’t fully address your question, simply reply to this email and a member of our Customer Success team will be happy to assist you further.

    Best Regards,
    The Bitwarden Customer Success Team

    " - }, - "onyx": { - "apiKey": "SECRET", - "baseUrl": "https://cloud.onyx.app/api", - "personaId": 7 - } + } } } diff --git a/src/Core/AdminConsole/Entities/Organization.cs b/src/Core/AdminConsole/Entities/Organization.cs index 7933990e74..338b150de6 100644 --- a/src/Core/AdminConsole/Entities/Organization.cs +++ b/src/Core/AdminConsole/Entities/Organization.cs @@ -129,6 +129,16 @@ public class Organization : ITableObject, IStorableSubscriber, IRevisable ///
    public bool SyncSeats { get; set; } + /// + /// If set to true, user accounts created within the organization are automatically confirmed without requiring additional verification steps. + /// + public bool UseAutomaticUserConfirmation { get; set; } + + /// + /// If set to true, the organization has phishing protection enabled. + /// + public bool UsePhishingBlocker { get; set; } + public void SetNewId() { if (Id == default(Guid)) @@ -328,5 +338,7 @@ public class Organization : ITableObject, IStorableSubscriber, IRevisable UseRiskInsights = license.UseRiskInsights; UseOrganizationDomains = license.UseOrganizationDomains; UseAdminSponsoredFamilies = license.UseAdminSponsoredFamilies; + UseAutomaticUserConfirmation = license.UseAutomaticUserConfirmation; + UsePhishingBlocker = license.UsePhishingBlocker; } } diff --git a/src/Core/AdminConsole/Enums/PolicyType.cs b/src/Core/AdminConsole/Enums/PolicyType.cs index 3ac14d67f3..bd6daf7cdf 100644 --- a/src/Core/AdminConsole/Enums/PolicyType.cs +++ b/src/Core/AdminConsole/Enums/PolicyType.cs @@ -21,6 +21,7 @@ public enum PolicyType : byte UriMatchDefaults = 16, AutotypeDefaultSetting = 17, AutomaticUserConfirmation = 18, + BlockClaimedDomainAccountCreation = 19, } public static class PolicyTypeExtensions @@ -45,13 +46,14 @@ public static class PolicyTypeExtensions PolicyType.MaximumVaultTimeout => "Vault timeout", PolicyType.DisablePersonalVaultExport => "Remove individual vault export", PolicyType.ActivateAutofill => "Active auto-fill", - PolicyType.AutomaticAppLogIn => "Automatically log in users for allowed applications", + PolicyType.AutomaticAppLogIn => "Automatic login with SSO", PolicyType.FreeFamiliesSponsorshipPolicy => "Remove Free Bitwarden Families sponsorship", PolicyType.RemoveUnlockWithPin => "Remove unlock with PIN", PolicyType.RestrictedItemTypesPolicy => "Restricted item types", PolicyType.UriMatchDefaults => "URI match defaults", PolicyType.AutotypeDefaultSetting => "Autotype default setting", PolicyType.AutomaticUserConfirmation => "Automatically confirm invited users", + PolicyType.BlockClaimedDomainAccountCreation => "Block account creation for claimed domains", }; } } diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/DatadogIntegration.cs b/src/Core/AdminConsole/Models/Data/EventIntegrations/DatadogIntegration.cs deleted file mode 100644 index 8785a74896..0000000000 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/DatadogIntegration.cs +++ /dev/null @@ -1,3 +0,0 @@ -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; - -public record DatadogIntegration(string ApiKey, Uri Uri); diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationHandlerResult.cs b/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationHandlerResult.cs deleted file mode 100644 index 8db054561b..0000000000 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationHandlerResult.cs +++ /dev/null @@ -1,16 +0,0 @@ -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; - -public class IntegrationHandlerResult -{ - public IntegrationHandlerResult(bool success, IIntegrationMessage message) - { - Success = success; - Message = message; - } - - public bool Success { get; set; } = false; - public bool Retryable { get; set; } = false; - public IIntegrationMessage Message { get; set; } - public DateTime? DelayUntilDate { get; set; } - public string FailureReason { get; set; } = string.Empty; -} diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/SlackIntegration.cs b/src/Core/AdminConsole/Models/Data/EventIntegrations/SlackIntegration.cs deleted file mode 100644 index dc2733c889..0000000000 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/SlackIntegration.cs +++ /dev/null @@ -1,3 +0,0 @@ -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; - -public record SlackIntegration(string Token); diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/SlackIntegrationConfiguration.cs b/src/Core/AdminConsole/Models/Data/EventIntegrations/SlackIntegrationConfiguration.cs deleted file mode 100644 index 5b4fae0c76..0000000000 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/SlackIntegrationConfiguration.cs +++ /dev/null @@ -1,3 +0,0 @@ -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; - -public record SlackIntegrationConfiguration(string ChannelId); diff --git a/src/Core/AdminConsole/Models/Data/IProfileOrganizationDetails.cs b/src/Core/AdminConsole/Models/Data/IProfileOrganizationDetails.cs new file mode 100644 index 0000000000..0368678641 --- /dev/null +++ b/src/Core/AdminConsole/Models/Data/IProfileOrganizationDetails.cs @@ -0,0 +1,57 @@ +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.Billing.Enums; + +namespace Bit.Core.AdminConsole.Models.Data; + +/// +/// Interface defining common organization details properties shared between +/// regular organization users and provider organization users for profile endpoints. +/// +public interface IProfileOrganizationDetails +{ + Guid? UserId { get; set; } + Guid OrganizationId { get; set; } + string Name { get; set; } + bool Enabled { get; set; } + PlanType PlanType { get; set; } + bool UsePolicies { get; set; } + bool UseSso { get; set; } + bool UseKeyConnector { get; set; } + bool UseScim { get; set; } + bool UseGroups { get; set; } + bool UseDirectory { get; set; } + bool UseEvents { get; set; } + bool UseTotp { get; set; } + bool Use2fa { get; set; } + bool UseApi { get; set; } + bool UseResetPassword { get; set; } + bool SelfHost { get; set; } + bool UsersGetPremium { get; set; } + bool UseCustomPermissions { get; set; } + bool UseSecretsManager { get; set; } + int? Seats { get; set; } + short? MaxCollections { get; set; } + short? MaxStorageGb { get; set; } + string? Identifier { get; set; } + string? Key { get; set; } + string? ResetPasswordKey { get; set; } + string? PublicKey { get; set; } + string? PrivateKey { get; set; } + string? SsoExternalId { get; set; } + string? Permissions { get; set; } + Guid? ProviderId { get; set; } + string? ProviderName { get; set; } + ProviderType? ProviderType { get; set; } + bool? SsoEnabled { get; set; } + string? SsoConfig { get; set; } + bool UsePasswordManager { get; set; } + bool LimitCollectionCreation { get; set; } + bool LimitCollectionDeletion { get; set; } + bool AllowAdminAccessToAllCollectionItems { get; set; } + bool UseRiskInsights { get; set; } + bool LimitItemDeletion { get; set; } + bool UseAdminSponsoredFamilies { get; set; } + bool UseOrganizationDomains { get; set; } + bool UseAutomaticUserConfirmation { get; set; } + bool UsePhishingBlocker { get; set; } +} diff --git a/src/Core/AdminConsole/Models/Data/OrganizationUsers/AcceptedOrganizationUserToConfirm.cs b/src/Core/AdminConsole/Models/Data/OrganizationUsers/AcceptedOrganizationUserToConfirm.cs new file mode 100644 index 0000000000..0dc6d1c352 --- /dev/null +++ b/src/Core/AdminConsole/Models/Data/OrganizationUsers/AcceptedOrganizationUserToConfirm.cs @@ -0,0 +1,8 @@ +namespace Bit.Core.AdminConsole.Models.Data.OrganizationUsers; + +public record AcceptedOrganizationUserToConfirm +{ + public required Guid OrganizationUserId { get; init; } + public required Guid UserId { get; init; } + public required string Key { get; init; } +} diff --git a/src/Core/AdminConsole/Models/Data/Organizations/OrganizationAbility.cs b/src/Core/AdminConsole/Models/Data/Organizations/OrganizationAbility.cs index ae91f204e3..7c8389c103 100644 --- a/src/Core/AdminConsole/Models/Data/Organizations/OrganizationAbility.cs +++ b/src/Core/AdminConsole/Models/Data/Organizations/OrganizationAbility.cs @@ -28,6 +28,8 @@ public class OrganizationAbility UseRiskInsights = organization.UseRiskInsights; UseOrganizationDomains = organization.UseOrganizationDomains; UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies; + UseAutomaticUserConfirmation = organization.UseAutomaticUserConfirmation; + UsePhishingBlocker = organization.UsePhishingBlocker; } public Guid Id { get; set; } @@ -49,4 +51,6 @@ public class OrganizationAbility public bool UseRiskInsights { get; set; } public bool UseOrganizationDomains { get; set; } public bool UseAdminSponsoredFamilies { get; set; } + public bool UseAutomaticUserConfirmation { get; set; } + public bool UsePhishingBlocker { get; set; } } diff --git a/src/Core/AdminConsole/Models/Data/Organizations/OrganizationUsers/OrganizationUserOrganizationDetails.cs b/src/Core/AdminConsole/Models/Data/Organizations/OrganizationUsers/OrganizationUserOrganizationDetails.cs index b7e573c4e6..00b9280337 100644 --- a/src/Core/AdminConsole/Models/Data/Organizations/OrganizationUsers/OrganizationUserOrganizationDetails.cs +++ b/src/Core/AdminConsole/Models/Data/Organizations/OrganizationUsers/OrganizationUserOrganizationDetails.cs @@ -1,20 +1,18 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Text.Json.Serialization; +using System.Text.Json.Serialization; using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Models.Data; using Bit.Core.Billing.Enums; using Bit.Core.Utilities; namespace Bit.Core.Models.Data.Organizations.OrganizationUsers; -public class OrganizationUserOrganizationDetails +public class OrganizationUserOrganizationDetails : IProfileOrganizationDetails { public Guid OrganizationId { get; set; } public Guid? UserId { get; set; } public Guid OrganizationUserId { get; set; } [JsonConverter(typeof(HtmlEncodingStringConverter))] - public string Name { get; set; } + public string Name { get; set; } = null!; public bool UsePolicies { get; set; } public bool UseSso { get; set; } public bool UseKeyConnector { get; set; } @@ -33,24 +31,24 @@ public class OrganizationUserOrganizationDetails public int? Seats { get; set; } public short? MaxCollections { get; set; } public short? MaxStorageGb { get; set; } - public string Key { get; set; } + public string? Key { get; set; } public Enums.OrganizationUserStatusType Status { get; set; } public Enums.OrganizationUserType Type { get; set; } public bool Enabled { get; set; } public PlanType PlanType { get; set; } - public string SsoExternalId { get; set; } - public string Identifier { get; set; } - public string Permissions { get; set; } - public string ResetPasswordKey { get; set; } - public string PublicKey { get; set; } - public string PrivateKey { get; set; } + public string? SsoExternalId { get; set; } + public string? Identifier { get; set; } + public string? Permissions { get; set; } + public string? ResetPasswordKey { get; set; } + public string? PublicKey { get; set; } + public string? PrivateKey { get; set; } public Guid? ProviderId { get; set; } [JsonConverter(typeof(HtmlEncodingStringConverter))] - public string ProviderName { get; set; } + public string? ProviderName { get; set; } public ProviderType? ProviderType { get; set; } - public string FamilySponsorshipFriendlyName { get; set; } + public string? FamilySponsorshipFriendlyName { get; set; } public bool? SsoEnabled { get; set; } - public string SsoConfig { get; set; } + public string? SsoConfig { get; set; } public DateTime? FamilySponsorshipLastSyncDate { get; set; } public DateTime? FamilySponsorshipValidUntil { get; set; } public bool? FamilySponsorshipToDelete { get; set; } @@ -66,4 +64,6 @@ public class OrganizationUserOrganizationDetails public bool UseOrganizationDomains { get; set; } public bool UseAdminSponsoredFamilies { get; set; } public bool? IsAdminInitiated { get; set; } + public bool UseAutomaticUserConfirmation { get; set; } + public bool UsePhishingBlocker { get; set; } } diff --git a/src/Core/AdminConsole/Models/Data/Organizations/OrganizationUsers/OrganizationUserUserDetails.cs b/src/Core/AdminConsole/Models/Data/Organizations/OrganizationUsers/OrganizationUserUserDetails.cs index 6d182e197f..00ba706a41 100644 --- a/src/Core/AdminConsole/Models/Data/Organizations/OrganizationUsers/OrganizationUserUserDetails.cs +++ b/src/Core/AdminConsole/Models/Data/Organizations/OrganizationUsers/OrganizationUserUserDetails.cs @@ -20,6 +20,12 @@ public class OrganizationUserUserDetails : IExternal, ITwoFactorProvidersUser, I public string Email { get; set; } public string AvatarColor { get; set; } public string TwoFactorProviders { get; set; } + /// + /// Indicates whether the user has a personal premium subscription. + /// Does not include premium access from organizations - + /// do not use this to check whether the user can access premium features. + /// Null when the organization user is in Invited status (UserId is null). + /// public bool? Premium { get; set; } public OrganizationUserStatusType Status { get; set; } public OrganizationUserType Type { get; set; } @@ -63,11 +69,6 @@ public class OrganizationUserUserDetails : IExternal, ITwoFactorProvidersUser, I return UserId; } - public bool GetPremium() - { - return Premium.GetValueOrDefault(false); - } - public Permissions GetPermissions() { return string.IsNullOrWhiteSpace(Permissions) ? null diff --git a/src/Core/AdminConsole/Models/Data/Organizations/SelfHostedOrganizationDetails.cs b/src/Core/AdminConsole/Models/Data/Organizations/SelfHostedOrganizationDetails.cs index 84ff164943..484320c271 100644 --- a/src/Core/AdminConsole/Models/Data/Organizations/SelfHostedOrganizationDetails.cs +++ b/src/Core/AdminConsole/Models/Data/Organizations/SelfHostedOrganizationDetails.cs @@ -154,6 +154,7 @@ public class SelfHostedOrganizationDetails : Organization Status = Status, UseRiskInsights = UseRiskInsights, UseAdminSponsoredFamilies = UseAdminSponsoredFamilies, + UsePhishingBlocker = UsePhishingBlocker, }; } } diff --git a/src/Core/AdminConsole/Models/Data/Provider/ProviderUserOrganizationDetails.cs b/src/Core/AdminConsole/Models/Data/Provider/ProviderUserOrganizationDetails.cs index 04281d098e..dcec028dcc 100644 --- a/src/Core/AdminConsole/Models/Data/Provider/ProviderUserOrganizationDetails.cs +++ b/src/Core/AdminConsole/Models/Data/Provider/ProviderUserOrganizationDetails.cs @@ -1,19 +1,16 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Text.Json.Serialization; +using System.Text.Json.Serialization; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.Billing.Enums; using Bit.Core.Utilities; namespace Bit.Core.AdminConsole.Models.Data.Provider; -public class ProviderUserOrganizationDetails +public class ProviderUserOrganizationDetails : IProfileOrganizationDetails { public Guid OrganizationId { get; set; } public Guid? UserId { get; set; } [JsonConverter(typeof(HtmlEncodingStringConverter))] - public string Name { get; set; } + public string Name { get; set; } = null!; public bool UsePolicies { get; set; } public bool UseSso { get; set; } public bool UseKeyConnector { get; set; } @@ -28,20 +25,22 @@ public class ProviderUserOrganizationDetails public bool SelfHost { get; set; } public bool UsersGetPremium { get; set; } public bool UseCustomPermissions { get; set; } + public bool UseSecretsManager { get; set; } + public bool UsePasswordManager { get; set; } public int? Seats { get; set; } public short? MaxCollections { get; set; } public short? MaxStorageGb { get; set; } - public string Key { get; set; } + public string? Key { get; set; } public ProviderUserStatusType Status { get; set; } public ProviderUserType Type { get; set; } public bool Enabled { get; set; } - public string Identifier { get; set; } - public string PublicKey { get; set; } - public string PrivateKey { get; set; } + public string? Identifier { get; set; } + public string? PublicKey { get; set; } + public string? PrivateKey { get; set; } public Guid? ProviderId { get; set; } public Guid? ProviderUserId { get; set; } [JsonConverter(typeof(HtmlEncodingStringConverter))] - public string ProviderName { get; set; } + public string? ProviderName { get; set; } public PlanType PlanType { get; set; } public bool LimitCollectionCreation { get; set; } public bool LimitCollectionDeletion { get; set; } @@ -50,5 +49,12 @@ public class ProviderUserOrganizationDetails public bool UseRiskInsights { get; set; } public bool UseOrganizationDomains { get; set; } public bool UseAdminSponsoredFamilies { get; set; } - public ProviderType ProviderType { get; set; } + public ProviderType? ProviderType { get; set; } + public bool UseAutomaticUserConfirmation { get; set; } + public bool? SsoEnabled { get; set; } + public string? SsoConfig { get; set; } + public string? SsoExternalId { get; set; } + public string? Permissions { get; set; } + public string? ResetPasswordKey { get; set; } + public bool UsePhishingBlocker { get; set; } } diff --git a/src/Core/AdminConsole/OrganizationAuth/UpdateOrganizationAuthRequestCommand.cs b/src/Core/AdminConsole/OrganizationAuth/UpdateOrganizationAuthRequestCommand.cs index af966a6e16..9c699a61cb 100644 --- a/src/Core/AdminConsole/OrganizationAuth/UpdateOrganizationAuthRequestCommand.cs +++ b/src/Core/AdminConsole/OrganizationAuth/UpdateOrganizationAuthRequestCommand.cs @@ -89,7 +89,7 @@ public class UpdateOrganizationAuthRequestCommand : IUpdateOrganizationAuthReque AuthRequestExpiresAfter = _globalSettings.PasswordlessAuth.AdminRequestExpiration } ); - processor.Process((Exception e) => _logger.LogError(e.Message)); + processor.Process((Exception e) => _logger.LogError("Error processing organization auth request: {Message}", e.Message)); await processor.Save((IEnumerable authRequests) => _authRequestRepository.UpdateManyAsync(authRequests)); await processor.SendPushNotifications((ar) => _pushNotificationService.PushAuthRequestResponseAsync(ar)); await processor.SendApprovalEmailsForProcessedRequests(SendApprovalEmail); @@ -114,7 +114,7 @@ public class UpdateOrganizationAuthRequestCommand : IUpdateOrganizationAuthReque // This should be impossible if (user == null) { - _logger.LogError($"User {authRequest.UserId} not found. Trusted device admin approval email not sent."); + _logger.LogError("User {UserId} not found. Trusted device admin approval email not sent.", authRequest.UserId); return; } diff --git a/src/Core/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommand.cs new file mode 100644 index 0000000000..5783301a0b --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommand.cs @@ -0,0 +1,79 @@ +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Platform.Push; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Microsoft.AspNetCore.Identity; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery; + +public class AdminRecoverAccountCommand(IOrganizationRepository organizationRepository, + IPolicyRepository policyRepository, + IUserRepository userRepository, + IMailService mailService, + IEventService eventService, + IPushNotificationService pushNotificationService, + IUserService userService, + TimeProvider timeProvider) : IAdminRecoverAccountCommand +{ + public async Task RecoverAccountAsync(Guid orgId, + OrganizationUser organizationUser, string newMasterPassword, string key) + { + // Org must be able to use reset password + var org = await organizationRepository.GetByIdAsync(orgId); + if (org == null || !org.UseResetPassword) + { + throw new BadRequestException("Organization does not allow password reset."); + } + + // Enterprise policy must be enabled + var resetPasswordPolicy = + await policyRepository.GetByOrganizationIdTypeAsync(orgId, PolicyType.ResetPassword); + if (resetPasswordPolicy == null || !resetPasswordPolicy.Enabled) + { + throw new BadRequestException("Organization does not have the password reset policy enabled."); + } + + // Org User must be confirmed and have a ResetPasswordKey + if (organizationUser == null || + organizationUser.Status != OrganizationUserStatusType.Confirmed || + organizationUser.OrganizationId != orgId || + string.IsNullOrEmpty(organizationUser.ResetPasswordKey) || + !organizationUser.UserId.HasValue) + { + throw new BadRequestException("Organization User not valid"); + } + + var user = await userService.GetUserByIdAsync(organizationUser.UserId.Value); + if (user == null) + { + throw new NotFoundException(); + } + + if (user.UsesKeyConnector) + { + throw new BadRequestException("Cannot reset password of a user with Key Connector."); + } + + var result = await userService.UpdatePasswordHash(user, newMasterPassword); + if (!result.Succeeded) + { + return result; + } + + user.RevisionDate = user.AccountRevisionDate = timeProvider.GetUtcNow().UtcDateTime; + user.LastPasswordChangeDate = user.RevisionDate; + user.ForcePasswordReset = true; + user.Key = key; + + await userRepository.ReplaceAsync(user); + await mailService.SendAdminResetPasswordEmailAsync(user.Email, user.Name, org.DisplayName()); + await eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_AdminResetPassword); + await pushNotificationService.PushLogOutAsync(user.Id); + + return IdentityResult.Success; + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/AccountRecovery/IAdminRecoverAccountCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/AccountRecovery/IAdminRecoverAccountCommand.cs new file mode 100644 index 0000000000..75babc643e --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/AccountRecovery/IAdminRecoverAccountCommand.cs @@ -0,0 +1,24 @@ +using Bit.Core.Entities; +using Bit.Core.Exceptions; +using Microsoft.AspNetCore.Identity; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery; + +/// +/// A command used to recover an organization user's account by an organization admin. +/// +public interface IAdminRecoverAccountCommand +{ + /// + /// Recovers an organization user's account by resetting their master password. + /// + /// The organization the user belongs to. + /// The organization user being recovered. + /// The user's new master password hash. + /// The user's new master-password-sealed user key. + /// An IdentityResult indicating success or failure. + /// When organization settings, policy, or user state is invalid. + /// When the user does not exist. + Task RecoverAccountAsync(Guid orgId, OrganizationUser organizationUser, + string newMasterPassword, string key); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Import/ImportOrganizationUsersAndGroupsCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Import/ImportOrganizationUsersAndGroupsCommand.cs index a78dd95260..b9bad6a346 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Import/ImportOrganizationUsersAndGroupsCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Import/ImportOrganizationUsersAndGroupsCommand.cs @@ -2,6 +2,7 @@ using Bit.Core.AdminConsole.Models.Business; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -18,7 +19,7 @@ public class ImportOrganizationUsersAndGroupsCommand : IImportOrganizationUsersA { private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IGroupRepository _groupRepository; private readonly IEventService _eventService; private readonly IOrganizationService _organizationService; @@ -27,7 +28,7 @@ public class ImportOrganizationUsersAndGroupsCommand : IImportOrganizationUsersA public ImportOrganizationUsersAndGroupsCommand(IOrganizationRepository organizationRepository, IOrganizationUserRepository organizationUserRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, IGroupRepository groupRepository, IEventService eventService, IOrganizationService organizationService) diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommand.cs index c03341bbc0..e6cc3da2a2 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommand.cs @@ -4,8 +4,8 @@ using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; -using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -24,7 +24,7 @@ public class VerifyOrganizationDomainCommand( IEventService eventService, IGlobalSettings globalSettings, ICurrentContext currentContext, - ISavePolicyCommand savePolicyCommand, + IVNextSavePolicyCommand vNextSavePolicyCommand, IMailService mailService, IOrganizationUserRepository organizationUserRepository, IOrganizationRepository organizationRepository, @@ -131,15 +131,19 @@ public class VerifyOrganizationDomainCommand( await SendVerifiedDomainUserEmailAsync(domain); } - private async Task EnableSingleOrganizationPolicyAsync(Guid organizationId, IActingUser actingUser) => - await savePolicyCommand.SaveAsync( - new PolicyUpdate - { - OrganizationId = organizationId, - Type = PolicyType.SingleOrg, - Enabled = true, - PerformedBy = actingUser - }); + private async Task EnableSingleOrganizationPolicyAsync(Guid organizationId, IActingUser actingUser) + { + var policyUpdate = new PolicyUpdate + { + OrganizationId = organizationId, + Type = PolicyType.SingleOrg, + Enabled = true, + PerformedBy = actingUser + }; + + var savePolicyModel = new SavePolicyModel(policyUpdate, actingUser); + await vNextSavePolicyCommand.SaveAsync(savePolicyModel); + } private async Task SendVerifiedDomainUserEmailAsync(OrganizationDomain domain) { diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AcceptOrgUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AcceptOrgUserCommand.cs index 63f177b3f3..50f194b578 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AcceptOrgUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AcceptOrgUserCommand.cs @@ -3,6 +3,7 @@ using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.Models.Business.Tokenables; @@ -34,6 +35,7 @@ public class AcceptOrgUserCommand : IAcceptOrgUserCommand private readonly IDataProtectorTokenFactory _orgUserInviteTokenDataFactory; private readonly IFeatureService _featureService; private readonly IPolicyRequirementQuery _policyRequirementQuery; + private readonly IAutomaticUserConfirmationPolicyEnforcementValidator _automaticUserConfirmationPolicyEnforcementValidator; public AcceptOrgUserCommand( IDataProtectionProvider dataProtectionProvider, @@ -46,7 +48,8 @@ public class AcceptOrgUserCommand : IAcceptOrgUserCommand ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, IDataProtectorTokenFactory orgUserInviteTokenDataFactory, IFeatureService featureService, - IPolicyRequirementQuery policyRequirementQuery) + IPolicyRequirementQuery policyRequirementQuery, + IAutomaticUserConfirmationPolicyEnforcementValidator automaticUserConfirmationPolicyEnforcementValidator) { // TODO: remove data protector when old token validation removed _dataProtector = dataProtectionProvider.CreateProtector(OrgUserInviteTokenable.DataProtectorPurpose); @@ -60,6 +63,7 @@ public class AcceptOrgUserCommand : IAcceptOrgUserCommand _orgUserInviteTokenDataFactory = orgUserInviteTokenDataFactory; _featureService = featureService; _policyRequirementQuery = policyRequirementQuery; + _automaticUserConfirmationPolicyEnforcementValidator = automaticUserConfirmationPolicyEnforcementValidator; } public async Task AcceptOrgUserByEmailTokenAsync(Guid organizationUserId, User user, string emailToken, @@ -186,13 +190,19 @@ public class AcceptOrgUserCommand : IAcceptOrgUserCommand } } - // Enforce Single Organization Policy of organization user is trying to join var allOrgUsers = await _organizationUserRepository.GetManyByUserAsync(user.Id); - var hasOtherOrgs = allOrgUsers.Any(ou => ou.OrganizationId != orgUser.OrganizationId); + + if (_featureService.IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers)) + { + await ValidateAutomaticUserConfirmationPolicyAsync(orgUser, allOrgUsers, user); + } + + // Enforce Single Organization Policy of organization user is trying to join var invitedSingleOrgPolicies = await _policyService.GetPoliciesApplicableToUserAsync(user.Id, PolicyType.SingleOrg, OrganizationUserStatusType.Invited); - if (hasOtherOrgs && invitedSingleOrgPolicies.Any(p => p.OrganizationId == orgUser.OrganizationId)) + if (allOrgUsers.Any(ou => ou.OrganizationId != orgUser.OrganizationId) + && invitedSingleOrgPolicies.Any(p => p.OrganizationId == orgUser.OrganizationId)) { throw new BadRequestException("You may not join this organization until you leave or remove all other organizations."); } @@ -255,4 +265,22 @@ public class AcceptOrgUserCommand : IAcceptOrgUserCommand } } } + + private async Task ValidateAutomaticUserConfirmationPolicyAsync(OrganizationUser orgUser, + ICollection allOrgUsers, User user) + { + var error = (await _automaticUserConfirmationPolicyEnforcementValidator.IsCompliantAsync( + new AutomaticUserConfirmationPolicyEnforcementRequest(orgUser.OrganizationId, + allOrgUsers.Append(orgUser), + user))) + .Match( + error => error.Message, + _ => string.Empty + ); + + if (!string.IsNullOrEmpty(error)) + { + throw new BadRequestException(error); + } + } } diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUserCommand.cs new file mode 100644 index 0000000000..67b5f0da80 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUserCommand.cs @@ -0,0 +1,186 @@ +using Bit.Core.AdminConsole.Models.Data.OrganizationUsers; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Models.Data; +using Bit.Core.Platform.Push; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Microsoft.Extensions.Logging; +using OneOf.Types; +using CommandResult = Bit.Core.AdminConsole.Utilities.v2.Results.CommandResult; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; + +public class AutomaticallyConfirmOrganizationUserCommand(IOrganizationUserRepository organizationUserRepository, + IOrganizationRepository organizationRepository, + IAutomaticallyConfirmOrganizationUsersValidator validator, + IEventService eventService, + IMailService mailService, + IUserRepository userRepository, + IPushRegistrationService pushRegistrationService, + IDeviceRepository deviceRepository, + IPushNotificationService pushNotificationService, + IPolicyRequirementQuery policyRequirementQuery, + ICollectionRepository collectionRepository, + TimeProvider timeProvider, + ILogger logger) : IAutomaticallyConfirmOrganizationUserCommand +{ + public async Task AutomaticallyConfirmOrganizationUserAsync(AutomaticallyConfirmOrganizationUserRequest request) + { + var validatorRequest = await RetrieveDataAsync(request); + + var validatedData = await validator.ValidateAsync(validatorRequest); + + return await validatedData.Match>( + error => Task.FromResult(new CommandResult(error)), + async _ => + { + var userToConfirm = new AcceptedOrganizationUserToConfirm + { + OrganizationUserId = validatedData.Request.OrganizationUser!.Id, + UserId = validatedData.Request.OrganizationUser.UserId!.Value, + Key = validatedData.Request.Key + }; + + // This operation is idempotent. If false, the user is already confirmed and no additional side effects are required. + if (!await organizationUserRepository.ConfirmOrganizationUserAsync(userToConfirm)) + { + return new None(); + } + + await CreateDefaultCollectionsAsync(validatedData.Request); + + await Task.WhenAll( + LogOrganizationUserConfirmedEventAsync(validatedData.Request), + SendConfirmedOrganizationUserEmailAsync(validatedData.Request), + SyncOrganizationKeysAsync(validatedData.Request) + ); + + return new None(); + } + ); + } + + private async Task SyncOrganizationKeysAsync(AutomaticallyConfirmOrganizationUserValidationRequest request) + { + await DeleteDeviceRegistrationAsync(request); + await PushSyncOrganizationKeysAsync(request); + } + + private async Task CreateDefaultCollectionsAsync(AutomaticallyConfirmOrganizationUserValidationRequest request) + { + try + { + if (!await ShouldCreateDefaultCollectionAsync(request)) + { + return; + } + + await collectionRepository.CreateAsync( + new Collection + { + OrganizationId = request.Organization!.Id, + Name = request.DefaultUserCollectionName, + Type = CollectionType.DefaultUserCollection + }, + groups: null, + [new CollectionAccessSelection + { + Id = request.OrganizationUser!.Id, + Manage = true + }]); + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to create default collection for user."); + } + } + + /// + /// Determines whether a default collection should be created for an organization user during the confirmation process. + /// + /// + /// The validation request containing information about the user, organization, and collection settings. + /// + /// The result is a boolean value indicating whether a default collection should be created. + private async Task ShouldCreateDefaultCollectionAsync(AutomaticallyConfirmOrganizationUserValidationRequest request) => + !string.IsNullOrWhiteSpace(request.DefaultUserCollectionName) + && (await policyRequirementQuery.GetAsync(request.OrganizationUser!.UserId!.Value)) + .RequiresDefaultCollectionOnConfirm(request.Organization!.Id); + + private async Task PushSyncOrganizationKeysAsync(AutomaticallyConfirmOrganizationUserValidationRequest request) + { + try + { + await pushNotificationService.PushSyncOrgKeysAsync(request.OrganizationUser!.UserId!.Value); + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to push organization keys."); + } + } + + private async Task LogOrganizationUserConfirmedEventAsync(AutomaticallyConfirmOrganizationUserValidationRequest request) + { + try + { + await eventService.LogOrganizationUserEventAsync(request.OrganizationUser, + EventType.OrganizationUser_AutomaticallyConfirmed, + timeProvider.GetUtcNow().UtcDateTime); + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to log OrganizationUser_AutomaticallyConfirmed event."); + } + } + + private async Task SendConfirmedOrganizationUserEmailAsync(AutomaticallyConfirmOrganizationUserValidationRequest request) + { + try + { + var user = await userRepository.GetByIdAsync(request.OrganizationUser!.UserId!.Value); + + await mailService.SendOrganizationConfirmedEmailAsync(request.Organization!.Name, + user!.Email, + request.OrganizationUser.AccessSecretsManager); + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to send OrganizationUserConfirmed."); + } + } + + private async Task DeleteDeviceRegistrationAsync(AutomaticallyConfirmOrganizationUserValidationRequest request) + { + try + { + var devices = (await deviceRepository.GetManyByUserIdAsync(request.OrganizationUser!.UserId!.Value)) + .Where(d => !string.IsNullOrWhiteSpace(d.PushToken)) + .Select(d => d.Id.ToString()); + + await pushRegistrationService.DeleteUserRegistrationOrganizationAsync(devices, request.Organization!.Id.ToString()); + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to delete device registration."); + } + } + + private async Task RetrieveDataAsync( + AutomaticallyConfirmOrganizationUserRequest request) + { + return new AutomaticallyConfirmOrganizationUserValidationRequest + { + OrganizationUserId = request.OrganizationUserId, + OrganizationId = request.OrganizationId, + Key = request.Key, + DefaultUserCollectionName = request.DefaultUserCollectionName, + PerformedBy = request.PerformedBy, + OrganizationUser = await organizationUserRepository.GetByIdAsync(request.OrganizationUserId), + Organization = await organizationRepository.GetByIdAsync(request.OrganizationId) + }; + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUserRequest.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUserRequest.cs new file mode 100644 index 0000000000..fcc8dacf66 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUserRequest.cs @@ -0,0 +1,29 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.Entities; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; + +/// +/// Automatically Confirm User Command Request +/// +public record AutomaticallyConfirmOrganizationUserRequest +{ + public required Guid OrganizationUserId { get; init; } + public required Guid OrganizationId { get; init; } + public required string Key { get; init; } + public required string DefaultUserCollectionName { get; init; } + public required IActingUser PerformedBy { get; init; } +} + +/// +/// Automatically Confirm User Validation Request +/// +/// +/// This is used to hold retrieved data and pass it to the validator +/// +public record AutomaticallyConfirmOrganizationUserValidationRequest : AutomaticallyConfirmOrganizationUserRequest +{ + public OrganizationUser? OrganizationUser { get; set; } + public Organization? Organization { get; set; } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUsersValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUsersValidator.cs new file mode 100644 index 0000000000..3375120516 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUsersValidator.cs @@ -0,0 +1,125 @@ +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Utilities.v2; +using Bit.Core.AdminConsole.Utilities.v2.Validation; +using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; +using Bit.Core.Enums; +using Bit.Core.Repositories; +using Bit.Core.Services; +using static Bit.Core.AdminConsole.Utilities.v2.Validation.ValidationResultHelpers; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; + +public class AutomaticallyConfirmOrganizationUsersValidator( + IOrganizationUserRepository organizationUserRepository, + ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, + IPolicyRequirementQuery policyRequirementQuery, + IAutomaticUserConfirmationPolicyEnforcementValidator automaticUserConfirmationPolicyEnforcementValidator, + IUserService userService, + IPolicyRepository policyRepository) : IAutomaticallyConfirmOrganizationUsersValidator +{ + public async Task> ValidateAsync( + AutomaticallyConfirmOrganizationUserValidationRequest request) + { + // User must exist + if (request is { OrganizationUser: null } || request.OrganizationUser is { UserId: null }) + { + return Invalid(request, new UserNotFoundError()); + } + + // Organization must exist + if (request is { Organization: null }) + { + return Invalid(request, new OrganizationNotFound()); + } + + // User must belong to the organization + if (request.OrganizationUser.OrganizationId != request.Organization.Id) + { + return Invalid(request, new OrganizationUserIdIsInvalid()); + } + + // User must be accepted + if (request is { OrganizationUser.Status: not OrganizationUserStatusType.Accepted }) + { + return Invalid(request, new UserIsNotAccepted()); + } + + // User must be of type User + if (request is { OrganizationUser.Type: not OrganizationUserType.User }) + { + return Invalid(request, new UserIsNotUserType()); + } + + if (!await OrganizationHasAutomaticallyConfirmUsersPolicyEnabledAsync(request)) + { + return Invalid(request, new AutomaticallyConfirmUsersPolicyIsNotEnabled()); + } + + if (!await OrganizationUserConformsToTwoFactorRequiredPolicyAsync(request)) + { + return Invalid(request, new UserDoesNotHaveTwoFactorEnabled()); + } + + if (await OrganizationUserConformsToAutomaticUserConfirmationPolicyAsync(request) is { } error) + { + return Invalid(request, error); + } + + return Valid(request); + } + + private async Task OrganizationHasAutomaticallyConfirmUsersPolicyEnabledAsync(AutomaticallyConfirmOrganizationUserValidationRequest request) => + await policyRepository.GetByOrganizationIdTypeAsync(request.OrganizationId, PolicyType.AutomaticUserConfirmation) is { Enabled: true } + && request.Organization is { UseAutomaticUserConfirmation: true }; + + private async Task OrganizationUserConformsToTwoFactorRequiredPolicyAsync(AutomaticallyConfirmOrganizationUserValidationRequest request) + { + if ((await twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync([request.OrganizationUser!.UserId!.Value])) + .Any(x => x.userId == request.OrganizationUser.UserId && x.twoFactorIsEnabled)) + { + return true; + } + + return !(await policyRequirementQuery.GetAsync(request.OrganizationUser.UserId!.Value)) + .IsTwoFactorRequiredForOrganization(request.Organization!.Id); + } + + /// + /// Validates whether the specified organization user complies with the automatic user confirmation policy. + /// This includes checks across all organizations the user is associated with to ensure they meet the compliance criteria. + /// + /// We are not checking single organization policy compliance here because automatically confirm users policy enforces + /// a stricter version and applies to all users. If you are compliant with Auto Confirm, you'll be in compliance with + /// Single Org. + /// + /// + /// The request model encapsulates the current organization, the user being validated, and all organization users associated + /// with that user. + /// + /// + /// An if the user fails to meet the automatic user confirmation policy, or null if the validation succeeds. + /// + private async Task OrganizationUserConformsToAutomaticUserConfirmationPolicyAsync( + AutomaticallyConfirmOrganizationUserValidationRequest request) + { + var allOrganizationUsersForUser = await organizationUserRepository + .GetManyByUserAsync(request.OrganizationUser!.UserId!.Value); + + var user = await userService.GetUserByIdAsync(request.OrganizationUser!.UserId!.Value); + + return (await automaticUserConfirmationPolicyEnforcementValidator.IsCompliantAsync( + new AutomaticUserConfirmationPolicyEnforcementRequest( + request.OrganizationId, + allOrganizationUsersForUser, + user))) + .Match( + error => error, + _ => null + ); + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/Errors.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/Errors.cs new file mode 100644 index 0000000000..e65db00f73 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/Errors.cs @@ -0,0 +1,16 @@ +using Bit.Core.AdminConsole.Utilities.v2; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; + +public record OrganizationNotFound() : NotFoundError("Invalid organization"); +public record FailedToWriteToEventLog() : InternalError("Failed to write to event log"); +public record UserIsNotUserType() : BadRequestError("Only organization users with the User role can be automatically confirmed"); +public record UserIsNotAccepted() : BadRequestError("Cannot confirm user that has not accepted the invitation."); +public record OrganizationUserIdIsInvalid() : BadRequestError("Invalid organization user id."); +public record UserDoesNotHaveTwoFactorEnabled() : BadRequestError("User does not have two-step login enabled."); +public record UserCannotBelongToAnotherOrganization() : BadRequestError("Cannot confirm this member to the organization until they leave or remove all other organizations"); +public record OtherOrganizationDoesNotAllowOtherMembership() : BadRequestError("Cannot confirm this member to the organization because they are in another organization which forbids it."); +public record AutomaticallyConfirmUsersPolicyIsNotEnabled() : BadRequestError("Cannot confirm this member because the Automatically Confirm Users policy is not enabled."); +public record ProviderUsersCannotJoin() : BadRequestError("An organization the user is a part of has enabled Automatic User Confirmation policy, and it does not support provider users joining."); +public record UserCannotJoinProvider() : BadRequestError("An organization the user is a part of has enabled Automatic User Confirmation policy, and it does not support the user joining a provider."); +public record CurrentOrganizationUserIsNotPresentInRequest() : BadRequestError("The current organization user does not exist in the request."); diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/IAutomaticallyConfirmOrganizationUsersValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/IAutomaticallyConfirmOrganizationUsersValidator.cs new file mode 100644 index 0000000000..544b65b53f --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/IAutomaticallyConfirmOrganizationUsersValidator.cs @@ -0,0 +1,9 @@ +using Bit.Core.AdminConsole.Utilities.v2.Validation; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; + +public interface IAutomaticallyConfirmOrganizationUsersValidator +{ + Task> ValidateAsync( + AutomaticallyConfirmOrganizationUserValidationRequest request); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/README.md b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/README.md new file mode 100644 index 0000000000..063b2f6a5c --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/README.md @@ -0,0 +1,22 @@ +# Automatic User Confirmation + +Owned by: admin-console + +Automatic confirmation requests are server driven events that are sent to the admin's client where via a background service the confirmation will occur. The basic model +for the workflow is as follows: + +- The Api server sends an invite email to a user. +- The user accepts the invite request, which is sent back to the Api server +- The Api server sends a push-notification with the OrganizationId and UserId to a client admin session. +- The Client performs the key exchange in the background and POSTs the ConfirmRequest back to the Api server +- The Api server runs the OrgUser_Confirm sproc to confirm the user in the DB + +This Feature has the following security measures in place in order to achieve our security goals: + +- The single organization exemption for admins/owners is removed for this policy. + - This is enforced by preventing enabling the policy and organization plan feature if there are non-compliant users +- Emergency access is removed for all organization users +- Automatic confirmation will only apply to the User role (You cannot auto confirm admins/owners to an organization) +- The organization has no members with the Provider user type. + - This will also prevent the policy and organization plan feature from being enabled + - This will prevent sending organization invites to provider users diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs index 2fbe6be5c6..b6b49e93e9 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs @@ -4,6 +4,7 @@ using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; @@ -33,6 +34,7 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand private readonly IPolicyRequirementQuery _policyRequirementQuery; private readonly IFeatureService _featureService; private readonly ICollectionRepository _collectionRepository; + private readonly IAutomaticUserConfirmationPolicyEnforcementValidator _automaticUserConfirmationPolicyEnforcementValidator; public ConfirmOrganizationUserCommand( IOrganizationRepository organizationRepository, @@ -47,7 +49,8 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand IDeviceRepository deviceRepository, IPolicyRequirementQuery policyRequirementQuery, IFeatureService featureService, - ICollectionRepository collectionRepository) + ICollectionRepository collectionRepository, + IAutomaticUserConfirmationPolicyEnforcementValidator automaticUserConfirmationPolicyEnforcementValidator) { _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; @@ -62,6 +65,7 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand _policyRequirementQuery = policyRequirementQuery; _featureService = featureService; _collectionRepository = collectionRepository; + _automaticUserConfirmationPolicyEnforcementValidator = automaticUserConfirmationPolicyEnforcementValidator; } public async Task ConfirmUserAsync(Guid organizationId, Guid organizationUserId, string key, @@ -127,6 +131,7 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand var organization = await _organizationRepository.GetByIdAsync(organizationId); var allUsersOrgs = await _organizationUserRepository.GetManyByManyUsersAsync(validSelectedUserIds); + var users = await _userRepository.GetManyAsync(validSelectedUserIds); var usersTwoFactorEnabled = await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(validSelectedUserIds); @@ -188,6 +193,25 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand await ValidateTwoFactorAuthenticationPolicyAsync(user, organizationId, userTwoFactorEnabled); var hasOtherOrgs = userOrgs.Any(ou => ou.OrganizationId != organizationId); + + if (_featureService.IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers)) + { + var error = (await _automaticUserConfirmationPolicyEnforcementValidator.IsCompliantAsync( + new AutomaticUserConfirmationPolicyEnforcementRequest( + organizationId, + userOrgs, + user))) + .Match( + error => new BadRequestException(error.Message), + _ => null + ); + + if (error is not null) + { + throw error; + } + } + var singleOrgPolicies = await _policyService.GetPoliciesApplicableToUserAsync(user.Id, PolicyType.SingleOrg); var otherSingleOrgPolicies = singleOrgPolicies.Where(p => p.OrganizationId != organizationId); @@ -267,8 +291,7 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand return; } - var organizationDataOwnershipPolicy = - await _policyRequirementQuery.GetAsync(organizationUser.UserId!.Value); + var organizationDataOwnershipPolicy = await _policyRequirementQuery.GetAsync(organizationUser.UserId!.Value); if (!organizationDataOwnershipPolicy.RequiresDefaultCollectionOnConfirm(organizationUser.OrganizationId)) { return; @@ -311,8 +334,8 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand return; } - var policyEligibleOrganizationUserIds = - await _policyRequirementQuery.GetManyByOrganizationIdAsync(organizationId); + var policyEligibleOrganizationUserIds = await _policyRequirementQuery + .GetManyByOrganizationIdAsync(organizationId); var eligibleOrganizationUserIds = confirmedOrganizationUsers .Where(ou => policyEligibleOrganizationUserIds.Contains(ou.Id)) diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/DeleteClaimedOrganizationUserAccountCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/DeleteClaimedOrganizationUserAccountCommand.cs index 87c24c3ab4..c5c423f2bb 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/DeleteClaimedOrganizationUserAccountCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/DeleteClaimedOrganizationUserAccountCommand.cs @@ -1,4 +1,6 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.Utilities.v2.Results; +using Bit.Core.AdminConsole.Utilities.v2.Validation; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/DeleteClaimedOrganizationUserAccountValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/DeleteClaimedOrganizationUserAccountValidator.cs index 315d45ea69..71eff3ae69 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/DeleteClaimedOrganizationUserAccountValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/DeleteClaimedOrganizationUserAccountValidator.cs @@ -1,8 +1,9 @@ using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Utilities.v2.Validation; using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Repositories; -using static Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount.ValidationResultHelpers; +using static Bit.Core.AdminConsole.Utilities.v2.Validation.ValidationResultHelpers; namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/Errors.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/Errors.cs index 6c8f7ee00c..a76104cc88 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/Errors.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/Errors.cs @@ -1,15 +1,6 @@ -namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; +using Bit.Core.AdminConsole.Utilities.v2; -/// -/// A strongly typed error containing a reason that an action failed. -/// This is used for business logic validation and other expected errors, not exceptions. -/// -public abstract record Error(string Message); -/// -/// An type that maps to a NotFoundResult at the api layer. -/// -/// -public abstract record NotFoundError(string Message) : Error(Message); +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; public record UserNotFoundError() : NotFoundError("Invalid user."); public record UserNotClaimedError() : Error("Member is not claimed by the organization."); diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/IDeleteClaimedOrganizationUserAccountCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/IDeleteClaimedOrganizationUserAccountCommand.cs index 983a3a4f21..408d3e8bcd 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/IDeleteClaimedOrganizationUserAccountCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/IDeleteClaimedOrganizationUserAccountCommand.cs @@ -1,4 +1,6 @@ -namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; +using Bit.Core.AdminConsole.Utilities.v2.Results; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; public interface IDeleteClaimedOrganizationUserAccountCommand { diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/IDeleteClaimedOrganizationUserAccountValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/IDeleteClaimedOrganizationUserAccountValidator.cs index f1a2c71b1b..05e97e896a 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/IDeleteClaimedOrganizationUserAccountValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/IDeleteClaimedOrganizationUserAccountValidator.cs @@ -1,4 +1,6 @@ -namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; +using Bit.Core.AdminConsole.Utilities.v2.Validation; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; public interface IDeleteClaimedOrganizationUserAccountValidator { diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Interfaces/IAutomaticallyConfirmOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Interfaces/IAutomaticallyConfirmOrganizationUserCommand.cs new file mode 100644 index 0000000000..a1776416ae --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Interfaces/IAutomaticallyConfirmOrganizationUserCommand.cs @@ -0,0 +1,40 @@ +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; +using Bit.Core.AdminConsole.Utilities.v2.Results; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; + +/// +/// Command to automatically confirm an organization user. +/// +/// +/// The auto-confirm feature enables eligible client apps to confirm OrganizationUsers +/// automatically via push notifications, eliminating the need for manual administrator +/// intervention. Client apps receive a push notification, perform the required key exchange, +/// and submit an auto-confirm request to the server. This command processes those +/// client-initiated requests and should only be used in that specific context. +/// +public interface IAutomaticallyConfirmOrganizationUserCommand +{ + /// + /// Automatically confirms the organization user based on the provided request data. + /// + /// The request containing necessary information to confirm the organization user. + /// + /// This action has side effects. The side effects are + ///
      + ///
    • Creating an event log entry.
    • + ///
    • Syncing organization keys with the user.
    • + ///
    • Deleting any registered user devices for the organization.
    • + ///
    • Sending an email to the confirmed user.
    • + ///
    • Creating the default collection if applicable.
    • + ///
    + /// + /// Each of these actions is performed independently of each other and not guaranteed to be performed in any order. + /// Errors will be reported back for the actions that failed in a consolidated error message. + ///
    + /// + /// The result of the command. If there was an error, the result will contain a typed error describing the problem + /// that occurred. + /// + Task AutomaticallyConfirmOrganizationUserAsync(AutomaticallyConfirmOrganizationUserRequest request); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/BulkResendOrganizationInvitesCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/BulkResendOrganizationInvitesCommand.cs new file mode 100644 index 0000000000..c7c80bd937 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/BulkResendOrganizationInvitesCommand.cs @@ -0,0 +1,69 @@ +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models; +using Bit.Core.AdminConsole.Utilities.DebuggingInstruments; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Microsoft.Extensions.Logging; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; + +public class BulkResendOrganizationInvitesCommand : IBulkResendOrganizationInvitesCommand +{ + private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IOrganizationRepository _organizationRepository; + private readonly ISendOrganizationInvitesCommand _sendOrganizationInvitesCommand; + private readonly ILogger _logger; + + public BulkResendOrganizationInvitesCommand( + IOrganizationUserRepository organizationUserRepository, + IOrganizationRepository organizationRepository, + ISendOrganizationInvitesCommand sendOrganizationInvitesCommand, + ILogger logger) + { + _organizationUserRepository = organizationUserRepository; + _organizationRepository = organizationRepository; + _sendOrganizationInvitesCommand = sendOrganizationInvitesCommand; + _logger = logger; + } + + public async Task>> BulkResendInvitesAsync( + Guid organizationId, + Guid? invitingUserId, + IEnumerable organizationUsersId) + { + var orgUsers = await _organizationUserRepository.GetManyAsync(organizationUsersId); + _logger.LogUserInviteStateDiagnostics(orgUsers); + + var org = await _organizationRepository.GetByIdAsync(organizationId); + if (org == null) + { + throw new NotFoundException(); + } + + var validUsers = new List(); + var result = new List>(); + + foreach (var orgUser in orgUsers) + { + if (orgUser.Status != OrganizationUserStatusType.Invited || orgUser.OrganizationId != organizationId) + { + result.Add(Tuple.Create(orgUser, "User invalid.")); + } + else + { + validUsers.Add(orgUser); + } + } + + if (validUsers.Any()) + { + await _sendOrganizationInvitesCommand.SendInvitesAsync( + new SendInvitesRequest(validUsers, org)); + + result.AddRange(validUsers.Select(u => Tuple.Create(u, ""))); + } + + return result; + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/IBulkResendOrganizationInvitesCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/IBulkResendOrganizationInvitesCommand.cs new file mode 100644 index 0000000000..342a06fcf9 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/IBulkResendOrganizationInvitesCommand.cs @@ -0,0 +1,20 @@ +using Bit.Core.Entities; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; + +public interface IBulkResendOrganizationInvitesCommand +{ + /// + /// Resend invites to multiple organization users in bulk. + /// + /// The ID of the organization. + /// The ID of the user who is resending the invites. + /// The IDs of the organization users to resend invites to. + /// A tuple containing the OrganizationUser and an error message (empty string if successful) + Task>> BulkResendInvitesAsync( + Guid organizationId, + Guid? invitingUserId, + IEnumerable organizationUsersId); +} + + diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteOrganizationUserValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteOrganizationUserValidator.cs index f8bd988cab..2648a2e429 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteOrganizationUserValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteOrganizationUserValidator.cs @@ -2,10 +2,10 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.PasswordManager; using Bit.Core.AdminConsole.Utilities.Errors; using Bit.Core.AdminConsole.Utilities.Validation; +using Bit.Core.Billing.Services; using Bit.Core.Models.Business; using Bit.Core.OrganizationFeatures.OrganizationSubscriptions.Interface; using Bit.Core.Repositories; -using Bit.Core.Services; namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation; @@ -15,7 +15,7 @@ public class InviteOrganizationUsersValidator( IOrganizationRepository organizationRepository, IInviteUsersPasswordManagerValidator inviteUsersPasswordManagerValidator, IUpdateSecretsManagerSubscriptionCommand secretsManagerSubscriptionCommand, - IPaymentService paymentService) : IInviteUsersValidator + IStripePaymentService paymentService) : IInviteUsersValidator { public async Task> ValidateAsync( InviteOrganizationUsersValidationRequest request) diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/PasswordManager/InviteUsersPasswordManagerValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/PasswordManager/InviteUsersPasswordManagerValidator.cs index 67155fe91a..9ba2fd1596 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/PasswordManager/InviteUsersPasswordManagerValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/PasswordManager/InviteUsersPasswordManagerValidator.cs @@ -9,8 +9,8 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.V using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Utilities.Validation; +using Bit.Core.Billing.Services; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Core.Settings; namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.PasswordManager; @@ -22,7 +22,7 @@ public class InviteUsersPasswordManagerValidator( IInviteUsersEnvironmentValidator inviteUsersEnvironmentValidator, IInviteUsersOrganizationValidator inviteUsersOrganizationValidator, IProviderRepository providerRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, IOrganizationRepository organizationRepository ) : IInviteUsersPasswordManagerValidator { diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs index 651a9225b4..ec42c8b402 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs @@ -4,6 +4,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; @@ -29,7 +30,8 @@ public class RestoreOrganizationUserCommand( IUserRepository userRepository, IOrganizationService organizationService, IFeatureService featureService, - IPolicyRequirementQuery policyRequirementQuery) : IRestoreOrganizationUserCommand + IPolicyRequirementQuery policyRequirementQuery, + IAutomaticUserConfirmationPolicyEnforcementValidator automaticUserConfirmationPolicyEnforcementValidator) : IRestoreOrganizationUserCommand { public async Task RestoreUserAsync(OrganizationUser organizationUser, Guid? restoringUserId) { @@ -300,6 +302,25 @@ public class RestoreOrganizationUserCommand( { throw new BadRequestException(user.Email + " is not compliant with the two-step login policy"); } + + if (featureService.IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers)) + { + var validationResult = await automaticUserConfirmationPolicyEnforcementValidator.IsCompliantAsync( + new AutomaticUserConfirmationPolicyEnforcementRequest(orgUser.OrganizationId, + allOrgUsers, + user!)); + + var badRequestException = validationResult.Match( + error => new BadRequestException(user.Email + + " is not compliant with the automatic user confirmation policy: " + + error.Message), + _ => null); + + if (badRequestException is not null) + { + throw badRequestException; + } + } } private async Task IsTwoFactorRequiredForOrganizationAsync(Guid userId, Guid organizationId) diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Interfaces/IRevokeOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v1/IRevokeOrganizationUserCommand.cs similarity index 95% rename from src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Interfaces/IRevokeOrganizationUserCommand.cs rename to src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v1/IRevokeOrganizationUserCommand.cs index 01ad2f05d2..7b5541c3ce 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Interfaces/IRevokeOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v1/IRevokeOrganizationUserCommand.cs @@ -1,7 +1,7 @@ using Bit.Core.Entities; using Bit.Core.Enums; -namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v1; public interface IRevokeOrganizationUserCommand { diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v1/RevokeOrganizationUserCommand.cs similarity index 99% rename from src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeOrganizationUserCommand.cs rename to src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v1/RevokeOrganizationUserCommand.cs index f24e0ae265..7aa67f0813 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v1/RevokeOrganizationUserCommand.cs @@ -7,7 +7,7 @@ using Bit.Core.Platform.Push; using Bit.Core.Repositories; using Bit.Core.Services; -namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers; +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v1; public class RevokeOrganizationUserCommand( IEventService eventService, diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/Errors.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/Errors.cs new file mode 100644 index 0000000000..a30894c7d5 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/Errors.cs @@ -0,0 +1,8 @@ +using Bit.Core.AdminConsole.Utilities.v2; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; + +public record UserAlreadyRevoked() : BadRequestError("Already revoked."); +public record CannotRevokeYourself() : BadRequestError("You cannot revoke yourself."); +public record OnlyOwnersCanRevokeOwners() : BadRequestError("Only owners can revoke other owners."); +public record MustHaveConfirmedOwner() : BadRequestError("Organization must have at least one confirmed owner."); diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/IRevokeOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/IRevokeOrganizationUserCommand.cs new file mode 100644 index 0000000000..e6471ad891 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/IRevokeOrganizationUserCommand.cs @@ -0,0 +1,8 @@ +using Bit.Core.AdminConsole.Utilities.v2.Results; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; + +public interface IRevokeOrganizationUserCommand +{ + Task> RevokeUsersAsync(RevokeOrganizationUsersRequest request); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/IRevokeOrganizationUserValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/IRevokeOrganizationUserValidator.cs new file mode 100644 index 0000000000..1a5cfd2c46 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/IRevokeOrganizationUserValidator.cs @@ -0,0 +1,9 @@ +using Bit.Core.AdminConsole.Utilities.v2.Validation; +using Bit.Core.Entities; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; + +public interface IRevokeOrganizationUserValidator +{ + Task>> ValidateAsync(RevokeOrganizationUsersValidationRequest request); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUserCommand.cs new file mode 100644 index 0000000000..ca501277a7 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUserCommand.cs @@ -0,0 +1,114 @@ +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.AdminConsole.Utilities.v2.Results; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Platform.Push; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Microsoft.Extensions.Logging; +using OneOf.Types; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; + +public class RevokeOrganizationUserCommand( + IOrganizationUserRepository organizationUserRepository, + IEventService eventService, + IPushNotificationService pushNotificationService, + IRevokeOrganizationUserValidator validator, + TimeProvider timeProvider, + ILogger logger) + : IRevokeOrganizationUserCommand +{ + public async Task> RevokeUsersAsync(RevokeOrganizationUsersRequest request) + { + var validationRequest = await CreateValidationRequestsAsync(request); + + var results = await validator.ValidateAsync(validationRequest); + + var validUsers = results.Where(r => r.IsValid).Select(r => r.Request).ToList(); + + await RevokeValidUsersAsync(validUsers); + + await Task.WhenAll( + LogRevokedOrganizationUsersAsync(validUsers, request.PerformedBy), + SendPushNotificationsAsync(validUsers) + ); + + return results.Select(r => r.Match( + error => new BulkCommandResult(r.Request.Id, error), + _ => new BulkCommandResult(r.Request.Id, new None()) + )); + } + + private async Task CreateValidationRequestsAsync( + RevokeOrganizationUsersRequest request) + { + var organizationUserToRevoke = await organizationUserRepository + .GetManyAsync(request.OrganizationUserIdsToRevoke); + + return new RevokeOrganizationUsersValidationRequest( + request.OrganizationId, + request.OrganizationUserIdsToRevoke, + request.PerformedBy, + organizationUserToRevoke); + } + + private async Task RevokeValidUsersAsync(ICollection validUsers) + { + if (validUsers.Count == 0) + { + return; + } + + await organizationUserRepository.RevokeManyByIdAsync(validUsers.Select(u => u.Id)); + } + + private async Task LogRevokedOrganizationUsersAsync( + ICollection revokedUsers, + IActingUser actingUser) + { + if (revokedUsers.Count == 0) + { + return; + } + + var eventDate = timeProvider.GetUtcNow().UtcDateTime; + + if (actingUser is SystemUser { SystemUserType: not null }) + { + var revokeEventsWithSystem = revokedUsers + .Select(user => (user, EventType.OrganizationUser_Revoked, actingUser.SystemUserType!.Value, + (DateTime?)eventDate)) + .ToList(); + await eventService.LogOrganizationUserEventsAsync(revokeEventsWithSystem); + } + else + { + var revokeEvents = revokedUsers + .Select(user => (user, EventType.OrganizationUser_Revoked, (DateTime?)eventDate)) + .ToList(); + await eventService.LogOrganizationUserEventsAsync(revokeEvents); + } + } + + private async Task SendPushNotificationsAsync(ICollection revokedUsers) + { + var userIdsToNotify = revokedUsers + .Where(user => user.UserId.HasValue) + .Select(user => user.UserId!.Value) + .Distinct() + .ToList(); + + foreach (var userId in userIdsToNotify) + { + try + { + await pushNotificationService.PushSyncOrgKeysAsync(userId); + } + catch (Exception ex) + { + logger.LogWarning(ex, "Failed to send push notification for user {UserId}.", userId); + } + } + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUsersRequest.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUsersRequest.cs new file mode 100644 index 0000000000..56996ffb53 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUsersRequest.cs @@ -0,0 +1,17 @@ +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.Entities; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; + +public record RevokeOrganizationUsersRequest( + Guid OrganizationId, + ICollection OrganizationUserIdsToRevoke, + IActingUser PerformedBy +); + +public record RevokeOrganizationUsersValidationRequest( + Guid OrganizationId, + ICollection OrganizationUserIdsToRevoke, + IActingUser PerformedBy, + ICollection OrganizationUsersToRevoke +) : RevokeOrganizationUsersRequest(OrganizationId, OrganizationUserIdsToRevoke, PerformedBy); diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUsersValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUsersValidator.cs new file mode 100644 index 0000000000..d2f47ed713 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUsersValidator.cs @@ -0,0 +1,39 @@ +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.Utilities.v2.Validation; +using Bit.Core.Entities; +using Bit.Core.Enums; +using static Bit.Core.AdminConsole.Utilities.v2.Validation.ValidationResultHelpers; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; + +public class RevokeOrganizationUsersValidator(IHasConfirmedOwnersExceptQuery hasConfirmedOwnersExceptQuery) + : IRevokeOrganizationUserValidator +{ + public async Task>> ValidateAsync( + RevokeOrganizationUsersValidationRequest request) + { + var hasRemainingOwner = await hasConfirmedOwnersExceptQuery.HasConfirmedOwnersExceptAsync(request.OrganizationId, + request.OrganizationUsersToRevoke.Select(x => x.Id) // users excluded because they are going to be revoked + ); + + return request.OrganizationUsersToRevoke.Select(x => + { + return x switch + { + _ when request.PerformedBy is not SystemUser + && x.UserId is not null + && x.UserId == request.PerformedBy.UserId => + Invalid(x, new CannotRevokeYourself()), + { Status: OrganizationUserStatusType.Revoked } => + Invalid(x, new UserAlreadyRevoked()), + { Type: OrganizationUserType.Owner } when !hasRemainingOwner => + Invalid(x, new MustHaveConfirmedOwner()), + { Type: OrganizationUserType.Owner } when !request.PerformedBy.IsOrganizationOwnerOrProvider => + Invalid(x, new OnlyOwnersCanRevokeOwners()), + + _ => Valid(x) + }; + }).ToList(); + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/CloudOrganizationSignUpCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/CloudOrganizationSignUpCommand.cs index 8d8ab8cdfc..2aa09a5250 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Organizations/CloudOrganizationSignUpCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/CloudOrganizationSignUpCommand.cs @@ -3,11 +3,14 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Organizations.Services; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -33,7 +36,7 @@ public interface ICloudOrganizationSignUpCommand public class CloudOrganizationSignUpCommand( IOrganizationUserRepository organizationUserRepository, IOrganizationBillingService organizationBillingService, - IPaymentService paymentService, + IStripePaymentService paymentService, IPolicyService policyService, IOrganizationRepository organizationRepository, IOrganizationApiKeyRepository organizationApiKeyRepository, @@ -42,7 +45,9 @@ public class CloudOrganizationSignUpCommand( IPushNotificationService pushNotificationService, ICollectionRepository collectionRepository, IDeviceRepository deviceRepository, - IPricingClient pricingClient) : ICloudOrganizationSignUpCommand + IPricingClient pricingClient, + IPolicyRequirementQuery policyRequirementQuery, + IFeatureService featureService) : ICloudOrganizationSignUpCommand { public async Task SignUpOrganizationAsync(OrganizationSignup signup) { @@ -75,8 +80,7 @@ public class CloudOrganizationSignUpCommand( PlanType = plan!.Type, Seats = (short)(plan.PasswordManager.BaseSeats + signup.AdditionalSeats), MaxCollections = plan.PasswordManager.MaxCollections, - MaxStorageGb = !plan.PasswordManager.BaseStorageGb.HasValue ? - (short?)null : (short)(plan.PasswordManager.BaseStorageGb.Value + signup.AdditionalStorageGb), + MaxStorageGb = (short)(plan.PasswordManager.BaseStorageGb + signup.AdditionalStorageGb), UsePolicies = plan.HasPolicies, UseSso = plan.HasSso, UseGroups = plan.HasGroups, @@ -95,8 +99,8 @@ public class CloudOrganizationSignUpCommand( ReferenceData = signup.Owner.ReferenceData, Enabled = true, LicenseKey = CoreHelpers.SecureRandomString(20), - PublicKey = signup.PublicKey, - PrivateKey = signup.PrivateKey, + PublicKey = signup.Keys?.PublicKey, + PrivateKey = signup.Keys?.WrappedPrivateKey, CreationDate = DateTime.UtcNow, RevisionDate = DateTime.UtcNow, Status = OrganizationStatusType.Created, @@ -237,6 +241,17 @@ public class CloudOrganizationSignUpCommand( private async Task ValidateSignUpPoliciesAsync(Guid ownerId) { + if (featureService.IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers)) + { + var requirement = await policyRequirementQuery.GetAsync(ownerId); + + if (requirement.CannotCreateNewOrganization()) + { + throw new BadRequestException("You may not create an organization. You belong to an organization " + + "which has a policy that prohibits you from being a member of any other organization."); + } + } + var anySingleOrgPolicies = await policyService.AnyPoliciesApplicableToUserAsync(ownerId, PolicyType.SingleOrg); if (anySingleOrgPolicies) { diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/InitPendingOrganizationCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/InitPendingOrganizationCommand.cs index 6474914b48..da678ece71 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Organizations/InitPendingOrganizationCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/InitPendingOrganizationCommand.cs @@ -2,6 +2,8 @@ #nullable disable using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Entities; @@ -28,6 +30,8 @@ public class InitPendingOrganizationCommand : IInitPendingOrganizationCommand private readonly IGlobalSettings _globalSettings; private readonly IPolicyService _policyService; private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IFeatureService _featureService; + private readonly IPolicyRequirementQuery _policyRequirementQuery; public InitPendingOrganizationCommand( IOrganizationService organizationService, @@ -37,7 +41,9 @@ public class InitPendingOrganizationCommand : IInitPendingOrganizationCommand IDataProtectionProvider dataProtectionProvider, IGlobalSettings globalSettings, IPolicyService policyService, - IOrganizationUserRepository organizationUserRepository + IOrganizationUserRepository organizationUserRepository, + IFeatureService featureService, + IPolicyRequirementQuery policyRequirementQuery ) { _organizationService = organizationService; @@ -48,6 +54,8 @@ public class InitPendingOrganizationCommand : IInitPendingOrganizationCommand _globalSettings = globalSettings; _policyService = policyService; _organizationUserRepository = organizationUserRepository; + _featureService = featureService; + _policyRequirementQuery = policyRequirementQuery; } public async Task InitPendingOrganizationAsync(User user, Guid organizationId, Guid organizationUserId, string publicKey, string privateKey, string collectionName, string emailToken) @@ -113,6 +121,17 @@ public class InitPendingOrganizationCommand : IInitPendingOrganizationCommand private async Task ValidateSignUpPoliciesAsync(Guid ownerId) { + if (_featureService.IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers)) + { + var requirement = await _policyRequirementQuery.GetAsync(ownerId); + + if (requirement.CannotCreateNewOrganization()) + { + throw new BadRequestException("You may not create an organization. You belong to an organization " + + "which has a policy that prohibits you from being a member of any other organization."); + } + } + var anySingleOrgPolicies = await _policyService.AnyPoliciesApplicableToUserAsync(ownerId, PolicyType.SingleOrg); if (anySingleOrgPolicies) { diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/Interfaces/IOrganizationUpdateCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/Interfaces/IOrganizationUpdateCommand.cs new file mode 100644 index 0000000000..85fbcd2740 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/Interfaces/IOrganizationUpdateCommand.cs @@ -0,0 +1,15 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Update; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; + +public interface IOrganizationUpdateCommand +{ + /// + /// Updates an organization's information in the Bitwarden database and Stripe (if required). + /// Also optionally updates an organization's public-private keypair if it was not created with one. + /// On self-host, only the public-private keys will be updated because all other properties are fixed by the license file. + /// + /// The update request containing the details to be updated. + Task UpdateAsync(OrganizationUpdateRequest request); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/OrganizationDeleteCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/OrganizationDeleteCommand.cs index 6a81130402..f73c49c811 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Organizations/OrganizationDeleteCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/OrganizationDeleteCommand.cs @@ -2,6 +2,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Repositories; +using Bit.Core.Billing.Services; using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; @@ -12,13 +13,13 @@ public class OrganizationDeleteCommand : IOrganizationDeleteCommand { private readonly IApplicationCacheService _applicationCacheService; private readonly IOrganizationRepository _organizationRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly ISsoConfigRepository _ssoConfigRepository; public OrganizationDeleteCommand( IApplicationCacheService applicationCacheService, IOrganizationRepository organizationRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, ISsoConfigRepository ssoConfigRepository) { _applicationCacheService = applicationCacheService; diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/OrganizationExtensions.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/OrganizationExtensions.cs new file mode 100644 index 0000000000..bb8f985495 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/OrganizationExtensions.cs @@ -0,0 +1,28 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.KeyManagement.Models.Data; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Organizations; + +public static class OrganizationExtensions +{ + /// + /// Updates the organization public and private keys if provided and not already set. + /// This is legacy code for old organizations that were not created with a public/private keypair. + /// It is a soft migration that will silently migrate organizations when they perform certain actions, + /// e.g. change their details or upgrade their plan. + /// + public static void BackfillPublicPrivateKeys(this Organization organization, PublicKeyEncryptionKeyPairData? keyPair) + { + // Only backfill if both new keys are provided and both old keys are missing. + if (string.IsNullOrWhiteSpace(keyPair?.PublicKey) || + string.IsNullOrWhiteSpace(keyPair.WrappedPrivateKey) || + !string.IsNullOrWhiteSpace(organization.PublicKey) || + !string.IsNullOrWhiteSpace(organization.PrivateKey)) + { + return; + } + + organization.PublicKey = keyPair.PublicKey; + organization.PrivateKey = keyPair.WrappedPrivateKey; + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/ProviderClientOrganizationSignUpCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/ProviderClientOrganizationSignUpCommand.cs index 27e70fbe2d..c51ab2a5e0 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Organizations/ProviderClientOrganizationSignUpCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/ProviderClientOrganizationSignUpCommand.cs @@ -73,7 +73,7 @@ public class ProviderClientOrganizationSignUpCommand : IProviderClientOrganizati PlanType = plan!.Type, Seats = signup.AdditionalSeats, MaxCollections = plan.PasswordManager.MaxCollections, - MaxStorageGb = 1, + MaxStorageGb = plan.PasswordManager.BaseStorageGb, UsePolicies = plan.HasPolicies, UseSso = plan.HasSso, UseOrganizationDomains = plan.HasOrganizationDomains, @@ -93,8 +93,8 @@ public class ProviderClientOrganizationSignUpCommand : IProviderClientOrganizati ReferenceData = signup.Owner.ReferenceData, Enabled = true, LicenseKey = CoreHelpers.SecureRandomString(20), - PublicKey = signup.PublicKey, - PrivateKey = signup.PrivateKey, + PublicKey = signup.Keys?.PublicKey, + PrivateKey = signup.Keys?.WrappedPrivateKey, CreationDate = DateTime.UtcNow, RevisionDate = DateTime.UtcNow, Status = OrganizationStatusType.Created, diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/ResellerClientOrganizationSignUpCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/ResellerClientOrganizationSignUpCommand.cs index 446d7339ca..82260aa6a7 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Organizations/ResellerClientOrganizationSignUpCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/ResellerClientOrganizationSignUpCommand.cs @@ -1,6 +1,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Repositories; @@ -39,7 +40,7 @@ public class ResellerClientOrganizationSignUpCommand : IResellerClientOrganizati private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IEventService _eventService; private readonly ISendOrganizationInvitesCommand _sendOrganizationInvitesCommand; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; public ResellerClientOrganizationSignUpCommand( IOrganizationRepository organizationRepository, @@ -48,7 +49,7 @@ public class ResellerClientOrganizationSignUpCommand : IResellerClientOrganizati IOrganizationUserRepository organizationUserRepository, IEventService eventService, ISendOrganizationInvitesCommand sendOrganizationInvitesCommand, - IPaymentService paymentService) + IStripePaymentService paymentService) { _organizationRepository = organizationRepository; _organizationApiKeyRepository = organizationApiKeyRepository; diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/SelfHostedOrganizationSignUpCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/SelfHostedOrganizationSignUpCommand.cs index c52b7c10c9..9abce991c3 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Organizations/SelfHostedOrganizationSignUpCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/SelfHostedOrganizationSignUpCommand.cs @@ -2,6 +2,8 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Services; @@ -30,7 +32,9 @@ public class SelfHostedOrganizationSignUpCommand : ISelfHostedOrganizationSignUp private readonly ILicensingService _licensingService; private readonly IPolicyService _policyService; private readonly IGlobalSettings _globalSettings; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; + private readonly IFeatureService _featureService; + private readonly IPolicyRequirementQuery _policyRequirementQuery; public SelfHostedOrganizationSignUpCommand( IOrganizationRepository organizationRepository, @@ -44,7 +48,9 @@ public class SelfHostedOrganizationSignUpCommand : ISelfHostedOrganizationSignUp ILicensingService licensingService, IPolicyService policyService, IGlobalSettings globalSettings, - IPaymentService paymentService) + IStripePaymentService paymentService, + IFeatureService featureService, + IPolicyRequirementQuery policyRequirementQuery) { _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; @@ -58,6 +64,8 @@ public class SelfHostedOrganizationSignUpCommand : ISelfHostedOrganizationSignUp _policyService = policyService; _globalSettings = globalSettings; _paymentService = paymentService; + _featureService = featureService; + _policyRequirementQuery = policyRequirementQuery; } public async Task<(Organization organization, OrganizationUser? organizationUser)> SignUpAsync( @@ -103,6 +111,17 @@ public class SelfHostedOrganizationSignUpCommand : ISelfHostedOrganizationSignUp private async Task ValidateSignUpPoliciesAsync(Guid ownerId) { + if (_featureService.IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers)) + { + var requirement = await _policyRequirementQuery.GetAsync(ownerId); + + if (requirement.CannotCreateNewOrganization()) + { + throw new BadRequestException("You may not create an organization. You belong to an organization " + + "which has a policy that prohibits you from being a member of any other organization."); + } + } + var anySingleOrgPolicies = await _policyService.AnyPoliciesApplicableToUserAsync(ownerId, PolicyType.SingleOrg); if (anySingleOrgPolicies) { diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/Update/OrganizationUpdateCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/Update/OrganizationUpdateCommand.cs new file mode 100644 index 0000000000..5cfd2191b3 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/Update/OrganizationUpdateCommand.cs @@ -0,0 +1,89 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; +using Bit.Core.Billing.Organizations.Services; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Settings; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Update; + +public class OrganizationUpdateCommand( + IOrganizationService organizationService, + IOrganizationRepository organizationRepository, + IGlobalSettings globalSettings, + IOrganizationBillingService organizationBillingService +) : IOrganizationUpdateCommand +{ + public async Task UpdateAsync(OrganizationUpdateRequest request) + { + var organization = await organizationRepository.GetByIdAsync(request.OrganizationId); + if (organization == null) + { + throw new NotFoundException(); + } + + if (globalSettings.SelfHosted) + { + return await UpdateSelfHostedAsync(organization, request); + } + + return await UpdateCloudAsync(organization, request); + } + + private async Task UpdateCloudAsync(Organization organization, OrganizationUpdateRequest request) + { + // Store original values for comparison + var originalName = organization.Name; + var originalBillingEmail = organization.BillingEmail; + + // Apply updates to organization + // These values may or may not be sent by the client depending on the operation being performed. + // Skip any values not provided. + if (request.Name is not null) + { + organization.Name = request.Name; + } + + if (request.BillingEmail is not null) + { + organization.BillingEmail = request.BillingEmail.ToLowerInvariant().Trim(); + } + + organization.BackfillPublicPrivateKeys(request.Keys); + + await organizationService.ReplaceAndUpdateCacheAsync(organization, EventType.Organization_Updated); + + // Update billing information in Stripe if required + await UpdateBillingAsync(organization, originalName, originalBillingEmail); + + return organization; + } + + /// + /// Self-host cannot update the organization details because they are set by the license file. + /// However, this command does offer a soft migration pathway for organizations without public and private keys. + /// If we remove this migration code in the future, this command and endpoint can become cloud only. + /// + private async Task UpdateSelfHostedAsync(Organization organization, OrganizationUpdateRequest request) + { + organization.BackfillPublicPrivateKeys(request.Keys); + await organizationService.ReplaceAndUpdateCacheAsync(organization, EventType.Organization_Updated); + return organization; + } + + private async Task UpdateBillingAsync(Organization organization, string originalName, string? originalBillingEmail) + { + // Update Stripe if name or billing email changed + var shouldUpdateBilling = originalName != organization.Name || + originalBillingEmail != organization.BillingEmail; + + if (!shouldUpdateBilling) + { + return; + } + + await organizationBillingService.UpdateOrganizationNameAndEmail(organization); + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/Update/OrganizationUpdateRequest.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/Update/OrganizationUpdateRequest.cs new file mode 100644 index 0000000000..4695ee0ba7 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/Update/OrganizationUpdateRequest.cs @@ -0,0 +1,30 @@ +using Bit.Core.KeyManagement.Models.Data; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Update; + +/// +/// Request model for updating the name, billing email, and/or public-private keys for an organization (legacy migration code). +/// Any combination of these properties can be updated, so they are optional. If none are specified it will not update anything. +/// +public record OrganizationUpdateRequest +{ + /// + /// The ID of the organization to update. + /// + public required Guid OrganizationId { get; init; } + + /// + /// The new organization name to apply (optional, this is skipped if not provided). + /// + public string? Name { get; init; } + + /// + /// The new billing email address to apply (optional, this is skipped if not provided). + /// + public string? BillingEmail { get; init; } + + /// + /// The organization's public/private key pair to set (optional, only set if not already present on the organization). + /// + public PublicKeyEncryptionKeyPairData? Keys { get; init; } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Organizations/UpdateOrganizationSubscriptionCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Organizations/UpdateOrganizationSubscriptionCommand.cs index 450f425bdf..e4d5a94c4c 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Organizations/UpdateOrganizationSubscriptionCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Organizations/UpdateOrganizationSubscriptionCommand.cs @@ -1,12 +1,12 @@ using Bit.Core.AdminConsole.Models.Data.Organizations; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; +using Bit.Core.Billing.Services; using Bit.Core.Repositories; -using Bit.Core.Services; using Microsoft.Extensions.Logging; namespace Bit.Core.AdminConsole.OrganizationFeatures.Organizations; -public class UpdateOrganizationSubscriptionCommand(IPaymentService paymentService, +public class UpdateOrganizationSubscriptionCommand(IStripePaymentService paymentService, IOrganizationRepository repository, TimeProvider timeProvider, ILogger logger) : IUpdateOrganizationSubscriptionCommand diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Enforcement/AutoConfirm/AutomaticUserConfirmationPolicyEnforcementRequest.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Enforcement/AutoConfirm/AutomaticUserConfirmationPolicyEnforcementRequest.cs new file mode 100644 index 0000000000..962da4bef7 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Enforcement/AutoConfirm/AutomaticUserConfirmationPolicyEnforcementRequest.cs @@ -0,0 +1,44 @@ +using Bit.Core.Entities; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; + +/// +/// Request object for +/// +public record AutomaticUserConfirmationPolicyEnforcementRequest +{ + /// + /// Organization to be validated + /// + public Guid OrganizationId { get; } + + /// + /// All organization users that match the provided user. + /// + public ICollection AllOrganizationUsers { get; } + + /// + /// User associated with the organization user to be confirmed + /// + public User User { get; } + + /// + /// Request object for . + /// + /// + /// This record is used to encapsulate the data required for handling the automatic confirmation policy enforcement. + /// + /// The organization to be validated. + /// All organization users that match the provided user. + /// The user entity connecting all org users provided. + public AutomaticUserConfirmationPolicyEnforcementRequest( + Guid organizationId, + IEnumerable organizationUsers, + User user) + { + OrganizationId = organizationId; + AllOrganizationUsers = organizationUsers.ToArray(); + User = user; + } +} + diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Enforcement/AutoConfirm/AutomaticUserConfirmationPolicyEnforcementValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Enforcement/AutoConfirm/AutomaticUserConfirmationPolicyEnforcementValidator.cs new file mode 100644 index 0000000000..e5c980ea24 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Enforcement/AutoConfirm/AutomaticUserConfirmationPolicyEnforcementValidator.cs @@ -0,0 +1,50 @@ +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Utilities.v2.Validation; +using static Bit.Core.AdminConsole.Utilities.v2.Validation.ValidationResultHelpers; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; + +public class AutomaticUserConfirmationPolicyEnforcementValidator( + IPolicyRequirementQuery policyRequirementQuery, + IProviderUserRepository providerUserRepository) + : IAutomaticUserConfirmationPolicyEnforcementValidator +{ + public async Task> IsCompliantAsync( + AutomaticUserConfirmationPolicyEnforcementRequest request) + { + var automaticUserConfirmationPolicyRequirement = await policyRequirementQuery + .GetAsync(request.User.Id); + + var currentOrganizationUser = request.AllOrganizationUsers + .FirstOrDefault(x => x.OrganizationId == request.OrganizationId + // invited users do not have a userId but will have email + && (x.UserId == request.User.Id || x.Email == request.User.Email)); + + if (currentOrganizationUser is null) + { + return Invalid(request, new CurrentOrganizationUserIsNotPresentInRequest()); + } + + if (automaticUserConfirmationPolicyRequirement.IsEnabled(request.OrganizationId)) + { + if ((await providerUserRepository.GetManyByUserAsync(request.User.Id)).Count != 0) + { + return Invalid(request, new ProviderUsersCannotJoin()); + } + + if (request.AllOrganizationUsers.Count > 1) + { + return Invalid(request, new UserCannotBelongToAnotherOrganization()); + } + } + + if (automaticUserConfirmationPolicyRequirement.IsEnabledForOrganizationsOtherThan(currentOrganizationUser.OrganizationId)) + { + return Invalid(request, new OtherOrganizationDoesNotAllowOtherMembership()); + } + + return Valid(request); + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Enforcement/AutoConfirm/IAutomaticUserConfirmationPolicyEnforcementValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Enforcement/AutoConfirm/IAutomaticUserConfirmationPolicyEnforcementValidator.cs new file mode 100644 index 0000000000..7bc1664140 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Enforcement/AutoConfirm/IAutomaticUserConfirmationPolicyEnforcementValidator.cs @@ -0,0 +1,28 @@ +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; +using Bit.Core.AdminConsole.Utilities.v2.Validation; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; + +/// +/// Used to enforce the Automatic User Confirmation policy. It uses the to retrieve +/// the . It is used to check to make sure the given user is +/// valid for the Automatic User Confirmation policy. It also validates that the given user is not a provider +/// or a member of another organization regardless of status or type. +/// +public interface IAutomaticUserConfirmationPolicyEnforcementValidator +{ + + /// + /// Checks if the given user is compliant with the Automatic User Confirmation policy. + /// + /// To be compliant, a user must + /// - not be a member of a provider + /// - not be a member of another organization + /// + /// + /// + /// This uses the validation result pattern to avoid throwing exceptions. + /// + /// A validation result with the error message if applicable. + Task> IsCompliantAsync(AutomaticUserConfirmationPolicyEnforcementRequest request); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyValidator.cs index 6aef9f248b..d3df63b6ac 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyValidator.cs @@ -9,6 +9,10 @@ namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies; /// /// Defines behavior and functionality for a given PolicyType. /// +/// +/// All methods defined in this interface are for the PolicyService#SavePolicy method. This needs to be supported until +/// we successfully refactor policy validators over to policy validation handlers +/// public interface IPolicyValidator { /// diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs index e846e02e46..c1450c6ab5 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs @@ -1,6 +1,4 @@ -#nullable enable - -using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Repositories; @@ -20,7 +18,7 @@ public class PolicyRequirementQuery( throw new NotImplementedException("No Requirement Factory found for " + typeof(T)); } - var policyDetails = await GetPolicyDetails(userId); + var policyDetails = await GetPolicyDetails(userId, factory.PolicyType); var filteredPolicies = policyDetails .Where(p => p.PolicyType == factory.PolicyType) .Where(factory.Enforce); @@ -48,8 +46,8 @@ public class PolicyRequirementQuery( return eligibleOrganizationUserIds; } - private Task> GetPolicyDetails(Guid userId) - => policyRepository.GetPolicyDetailsByUserId(userId); + private async Task> GetPolicyDetails(Guid userId, PolicyType policyType) + => await policyRepository.GetPolicyDetailsByUserIdsAndPolicyType([userId], policyType); private async Task> GetOrganizationPolicyDetails(Guid organizationId, PolicyType policyType) => await policyRepository.GetPolicyDetailsByOrganizationIdAsync(organizationId, policyType); diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SavePolicyCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SavePolicyCommand.cs index e2bca930d1..57140317e3 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SavePolicyCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SavePolicyCommand.cs @@ -4,6 +4,8 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Enums; using Bit.Core.Exceptions; +using Bit.Core.Models; +using Bit.Core.Platform.Push; using Bit.Core.Services; namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations; @@ -16,19 +18,22 @@ public class SavePolicyCommand : ISavePolicyCommand private readonly IReadOnlyDictionary _policyValidators; private readonly TimeProvider _timeProvider; private readonly IPostSavePolicySideEffect _postSavePolicySideEffect; + private readonly IPushNotificationService _pushNotificationService; public SavePolicyCommand(IApplicationCacheService applicationCacheService, IEventService eventService, IPolicyRepository policyRepository, IEnumerable policyValidators, TimeProvider timeProvider, - IPostSavePolicySideEffect postSavePolicySideEffect) + IPostSavePolicySideEffect postSavePolicySideEffect, + IPushNotificationService pushNotificationService) { _applicationCacheService = applicationCacheService; _eventService = eventService; _policyRepository = policyRepository; _timeProvider = timeProvider; _postSavePolicySideEffect = postSavePolicySideEffect; + _pushNotificationService = pushNotificationService; var policyValidatorsDict = new Dictionary(); foreach (var policyValidator in policyValidators) @@ -75,6 +80,8 @@ public class SavePolicyCommand : ISavePolicyCommand await _policyRepository.UpsertAsync(policy); await _eventService.LogPolicyEventAsync(policy, EventType.Policy_Updated); + await PushPolicyUpdateToClients(policy.OrganizationId, policy); + return policy; } @@ -152,4 +159,17 @@ public class SavePolicyCommand : ISavePolicyCommand var currentPolicy = savedPoliciesDict.GetValueOrDefault(policyUpdate.Type); return (savedPoliciesDict, currentPolicy); } + + Task PushPolicyUpdateToClients(Guid organizationId, Policy policy) => this._pushNotificationService.PushAsync(new PushNotification + { + Type = PushType.PolicyChanged, + Target = NotificationTarget.Organization, + TargetId = organizationId, + ExcludeCurrentContext = false, + Payload = new SyncPolicyPushNotification + { + Policy = policy, + OrganizationId = organizationId + } + }); } diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/VNextSavePolicyCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/VNextSavePolicyCommand.cs new file mode 100644 index 0000000000..38e417d085 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/VNextSavePolicyCommand.cs @@ -0,0 +1,211 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Models; +using Bit.Core.Platform.Push; +using Bit.Core.Services; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations; + +public class VNextSavePolicyCommand( + IApplicationCacheService applicationCacheService, + IEventService eventService, + IPolicyRepository policyRepository, + IEnumerable policyUpdateEventHandlers, + TimeProvider timeProvider, + IPolicyEventHandlerFactory policyEventHandlerFactory, + IPushNotificationService pushNotificationService) + : IVNextSavePolicyCommand +{ + + public async Task SaveAsync(SavePolicyModel policyRequest) + { + var policyUpdateRequest = policyRequest.PolicyUpdate; + var organizationId = policyUpdateRequest.OrganizationId; + + await EnsureOrganizationCanUsePolicyAsync(organizationId); + + var savedPoliciesDict = await GetCurrentPolicyStateAsync(organizationId); + + var currentPolicy = savedPoliciesDict.GetValueOrDefault(policyUpdateRequest.Type); + + ValidatePolicyDependencies(policyUpdateRequest, currentPolicy, savedPoliciesDict); + + await ValidateTargetedPolicyAsync(policyRequest, currentPolicy); + + await ExecutePreUpsertSideEffectAsync(policyRequest, currentPolicy); + + var upsertedPolicy = await UpsertPolicyAsync(policyUpdateRequest); + + await eventService.LogPolicyEventAsync(upsertedPolicy, EventType.Policy_Updated); + + await ExecutePostUpsertSideEffectAsync(policyRequest, upsertedPolicy, currentPolicy); + + return upsertedPolicy; + } + + private async Task EnsureOrganizationCanUsePolicyAsync(Guid organizationId) + { + var org = await applicationCacheService.GetOrganizationAbilityAsync(organizationId); + if (org == null) + { + throw new BadRequestException("Organization not found"); + } + + if (!org.UsePolicies) + { + throw new BadRequestException("This organization cannot use policies."); + } + } + + private async Task UpsertPolicyAsync(PolicyUpdate policyUpdateRequest) + { + var policy = await policyRepository.GetByOrganizationIdTypeAsync(policyUpdateRequest.OrganizationId, policyUpdateRequest.Type) + ?? new Policy + { + OrganizationId = policyUpdateRequest.OrganizationId, + Type = policyUpdateRequest.Type, + CreationDate = timeProvider.GetUtcNow().UtcDateTime + }; + + policy.Enabled = policyUpdateRequest.Enabled; + policy.Data = policyUpdateRequest.Data; + policy.RevisionDate = timeProvider.GetUtcNow().UtcDateTime; + + await policyRepository.UpsertAsync(policy); + await PushPolicyUpdateToClients(policyUpdateRequest.OrganizationId, policy); + return policy; + } + + private async Task ValidateTargetedPolicyAsync(SavePolicyModel policyRequest, + Policy? currentPolicy) + { + await ExecutePolicyEventAsync( + policyRequest.PolicyUpdate.Type, + async validator => + { + var validationError = await validator.ValidateAsync(policyRequest, currentPolicy); + if (!string.IsNullOrEmpty(validationError)) + { + throw new BadRequestException(validationError); + } + }); + } + + private void ValidatePolicyDependencies( + PolicyUpdate policyUpdateRequest, + Policy? currentPolicy, + Dictionary savedPoliciesDict) + { + var isCurrentlyEnabled = currentPolicy?.Enabled == true; + var isBeingEnabled = policyUpdateRequest.Enabled && !isCurrentlyEnabled; + var isBeingDisabled = !policyUpdateRequest.Enabled && isCurrentlyEnabled; + + if (isBeingEnabled) + { + ValidateEnablingRequirements(policyUpdateRequest.Type, savedPoliciesDict); + } + else if (isBeingDisabled) + { + ValidateDisablingRequirements(policyUpdateRequest.Type, savedPoliciesDict); + } + } + + private void ValidateDisablingRequirements( + PolicyType policyType, + Dictionary savedPoliciesDict) + { + var dependentPolicyTypes = policyUpdateEventHandlers + .OfType() + .Where(otherValidator => otherValidator.RequiredPolicies.Contains(policyType)) + .Select(otherValidator => otherValidator.Type) + .Where(otherPolicyType => savedPoliciesDict.TryGetValue(otherPolicyType, out var savedPolicy) && + savedPolicy.Enabled) + .ToList(); + + switch (dependentPolicyTypes) + { + case { Count: 1 }: + throw new BadRequestException($"Turn off the {dependentPolicyTypes.First().GetName()} policy because it requires the {policyType.GetName()} policy."); + case { Count: > 1 }: + throw new BadRequestException($"Turn off all of the policies that require the {policyType.GetName()} policy."); + } + } + + private void ValidateEnablingRequirements( + PolicyType policyType, + Dictionary savedPoliciesDict) + { + var result = policyEventHandlerFactory.GetHandler(policyType); + + result.Switch( + validator => + { + var missingRequiredPolicyTypes = validator.RequiredPolicies + .Where(requiredPolicyType => savedPoliciesDict.GetValueOrDefault(requiredPolicyType) is not { Enabled: true }) + .ToList(); + + if (missingRequiredPolicyTypes.Count != 0) + { + throw new BadRequestException($"Turn on the {missingRequiredPolicyTypes.First().GetName()} policy because it is required for the {policyType.GetName()} policy."); + } + }, + _ => { /* Policy has no required dependencies */ }); + } + + private async Task ExecutePreUpsertSideEffectAsync( + SavePolicyModel policyRequest, + Policy? currentPolicy) + { + await ExecutePolicyEventAsync( + policyRequest.PolicyUpdate.Type, + handler => handler.ExecutePreUpsertSideEffectAsync(policyRequest, currentPolicy)); + } + private async Task ExecutePostUpsertSideEffectAsync( + SavePolicyModel policyRequest, + Policy postUpsertedPolicyState, + Policy? previousPolicyState) + { + await ExecutePolicyEventAsync( + policyRequest.PolicyUpdate.Type, + handler => handler.ExecutePostUpsertSideEffectAsync( + policyRequest, + postUpsertedPolicyState, + previousPolicyState)); + } + + private async Task ExecutePolicyEventAsync(PolicyType type, Func func) where T : IPolicyUpdateEvent + { + var handler = policyEventHandlerFactory.GetHandler(type); + + await handler.Match( + async h => await func(h), + _ => Task.CompletedTask + ); + } + + private async Task> GetCurrentPolicyStateAsync(Guid organizationId) + { + var savedPolicies = await policyRepository.GetManyByOrganizationIdAsync(organizationId); + // Note: policies may be missing from this dict if they have never been enabled + var savedPoliciesDict = savedPolicies.ToDictionary(p => p.Type); + return savedPoliciesDict; + } + + Task PushPolicyUpdateToClients(Guid organizationId, Policy policy) => pushNotificationService.PushAsync(new PushNotification + { + Type = PushType.PolicyChanged, + Target = NotificationTarget.Organization, + TargetId = organizationId, + ExcludeCurrentContext = false, + Payload = new SyncPolicyPushNotification + { + Policy = policy, + OrganizationId = organizationId + } + }); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/PolicyUpdate.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/PolicyUpdate.cs index d1a52f0080..cad786234c 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/PolicyUpdate.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/PolicyUpdate.cs @@ -16,6 +16,8 @@ public record PolicyUpdate public PolicyType Type { get; set; } public string? Data { get; set; } public bool Enabled { get; set; } + + [Obsolete("Please use SavePolicyModel.PerformedBy instead.")] public IActingUser? PerformedBy { get; set; } public T GetDataModel() where T : IPolicyDataModel, new() diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/SavePolicyModel.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/SavePolicyModel.cs index 7c8d5126e8..01168deea4 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/SavePolicyModel.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/SavePolicyModel.cs @@ -5,4 +5,18 @@ namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; public record SavePolicyModel(PolicyUpdate PolicyUpdate, IActingUser? PerformedBy, IPolicyMetadataModel Metadata) { + public SavePolicyModel(PolicyUpdate PolicyUpdate) + : this(PolicyUpdate, null, new EmptyMetadataModel()) + { + } + + public SavePolicyModel(PolicyUpdate PolicyUpdate, IActingUser performedBy) + : this(PolicyUpdate, performedBy, new EmptyMetadataModel()) + { + } + + public SavePolicyModel(PolicyUpdate PolicyUpdate, IPolicyMetadataModel metadata) + : this(PolicyUpdate, null, metadata) + { + } } diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/AutomaticUserConfirmationPolicyRequirement.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/AutomaticUserConfirmationPolicyRequirement.cs new file mode 100644 index 0000000000..3430f33a77 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/AutomaticUserConfirmationPolicyRequirement.cs @@ -0,0 +1,48 @@ +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.Enums; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; + +/// +/// Represents the enforcement status of the Automatic User Confirmation policy. +/// +/// +/// The Automatic User Confirmation policy is enforced against all types of users regardless of status or type. +/// +/// Users cannot: +///
      +///
    • Be a member of another organization (similar to Single Organization Policy)
    • +///
    • Cannot be a provider
    • +///
    +///
    +/// Collection of policy details that apply to this user id +public class AutomaticUserConfirmationPolicyRequirement(IEnumerable policyDetails) : IPolicyRequirement +{ + public bool CannotBeGrantedEmergencyAccess() => policyDetails.Any(); + + public bool CannotJoinProvider() => policyDetails.Any(); + + public bool CannotCreateProvider() => policyDetails.Any(); + + public bool CannotCreateNewOrganization() => policyDetails.Any(); + + public bool IsEnabled(Guid organizationId) => policyDetails.Any(p => p.OrganizationId == organizationId); + + public bool IsEnabledForOrganizationsOtherThan(Guid organizationId) => + policyDetails.Any(p => p.OrganizationId != organizationId); +} + +public class AutomaticUserConfirmationPolicyRequirementFactory : BasePolicyRequirementFactory +{ + public override PolicyType PolicyType => PolicyType.AutomaticUserConfirmation; + + protected override IEnumerable ExemptRoles => []; + + protected override IEnumerable ExemptStatuses => []; + + protected override bool ExemptProviders => false; + + public override AutomaticUserConfirmationPolicyRequirement Create(IEnumerable policyDetails) => + new(policyDetails); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/OrganizationDataOwnershipPolicyRequirement.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/OrganizationDataOwnershipPolicyRequirement.cs index 28d6614dcb..c9653053ea 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/OrganizationDataOwnershipPolicyRequirement.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/OrganizationDataOwnershipPolicyRequirement.cs @@ -72,6 +72,17 @@ public class OrganizationDataOwnershipPolicyRequirement : IPolicyRequirement { return _policyDetails.Any(p => p.OrganizationId == organizationId); } + + /// + /// Ignore storage limits if the organization has data ownership policy enabled. + /// Allows users to seamlessly migrate their data into the organization without being blocked by storage limits. + /// Organization admins will need to manage storage after migration should overages occur. + /// + public bool IgnoreStorageLimitsOnMigration(Guid organizationId) + { + return _policyDetails.Any(p => p.OrganizationId == organizationId && + p.OrganizationUserStatus == OrganizationUserStatusType.Confirmed); + } } public record DefaultCollectionRequest(Guid OrganizationUserId, bool ShouldCreateDefaultCollection) diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/SingleOrganizationPolicyRequirement.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/SingleOrganizationPolicyRequirement.cs new file mode 100644 index 0000000000..d1e1efafd9 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/SingleOrganizationPolicyRequirement.cs @@ -0,0 +1,21 @@ +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; + +public class SingleOrganizationPolicyRequirement(IEnumerable policyDetails) : IPolicyRequirement +{ + public bool IsSingleOrgEnabledForThisOrganization(Guid organizationId) => + policyDetails.Any(p => p.OrganizationId == organizationId); + + public bool IsSingleOrgEnabledForOrganizationsOtherThan(Guid organizationId) => + policyDetails.Any(p => p.OrganizationId != organizationId); +} + +public class SingleOrganizationPolicyRequirementFactory : BasePolicyRequirementFactory +{ + public override PolicyType PolicyType => PolicyType.SingleOrg; + + public override SingleOrganizationPolicyRequirement Create(IEnumerable policyDetails) => + new(policyDetails); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs index 5433d70410..f69935715d 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs @@ -1,5 +1,8 @@ -using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; using Bit.Core.AdminConsole.Services; using Bit.Core.AdminConsole.Services.Implementations; @@ -13,13 +16,19 @@ public static class PolicyServiceCollectionExtensions { services.AddScoped(); services.AddScoped(); + services.AddScoped(); services.AddScoped(); + services.AddScoped(); services.AddPolicyValidators(); services.AddPolicyRequirements(); services.AddPolicySideEffects(); + services.AddPolicyUpdateEvents(); + + services.AddScoped(); } + [Obsolete("Use AddPolicyUpdateEvents instead.")] private static void AddPolicyValidators(this IServiceCollection services) { services.AddScoped(); @@ -27,14 +36,32 @@ public static class PolicyServiceCollectionExtensions services.AddScoped(); services.AddScoped(); services.AddScoped(); + services.AddScoped(); services.AddScoped(); + services.AddScoped(); + services.AddScoped(); } + [Obsolete("Use AddPolicyUpdateEvents instead.")] private static void AddPolicySideEffects(this IServiceCollection services) { services.AddScoped(); } + private static void AddPolicyUpdateEvents(this IServiceCollection services) + { + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + } + private static void AddPolicyRequirements(this IServiceCollection services) { services.AddScoped, DisableSendPolicyRequirementFactory>(); @@ -44,5 +71,7 @@ public static class PolicyServiceCollectionExtensions services.AddScoped, RequireSsoPolicyRequirementFactory>(); services.AddScoped, RequireTwoFactorPolicyRequirementFactory>(); services.AddScoped, MasterPasswordPolicyRequirementFactory>(); + services.AddScoped, SingleOrganizationPolicyRequirementFactory>(); + services.AddScoped, AutomaticUserConfirmationPolicyRequirementFactory>(); } } diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IEnforceDependentPoliciesEvent.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IEnforceDependentPoliciesEvent.cs new file mode 100644 index 0000000000..0e2bdc3d69 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IEnforceDependentPoliciesEvent.cs @@ -0,0 +1,19 @@ +using Bit.Core.AdminConsole.Enums; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; + +/// +/// Represents all policies required to be enabled before the given policy can be enabled. +/// +/// +/// This interface is intended for policy event handlers that mandate the activation of other policies +/// as prerequisites for enabling the associated policy. +/// +public interface IEnforceDependentPoliciesEvent : IPolicyUpdateEvent +{ + /// + /// PolicyTypes that must be enabled before this policy can be enabled, if any. + /// These dependencies will be checked when this policy is enabled and when any required policy is disabled. + /// + public IEnumerable RequiredPolicies { get; } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IOnPolicyPostUpdateEvent.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IOnPolicyPostUpdateEvent.cs new file mode 100644 index 0000000000..08295bf7fb --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IOnPolicyPostUpdateEvent.cs @@ -0,0 +1,18 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; +public interface IOnPolicyPostUpdateEvent : IPolicyUpdateEvent +{ + /// + /// Performs side effects after a policy has been upserted. + /// For example, this can be used for cleanup tasks or notifications. + /// + /// The policy save request + /// The policy after it was upserted + /// The policy state before it was updated, if any + public Task ExecutePostUpsertSideEffectAsync( + SavePolicyModel policyRequest, + Policy postUpsertedPolicyState, + Policy? previousPolicyState); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IOnPolicyPreUpdateEvent.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IOnPolicyPreUpdateEvent.cs new file mode 100644 index 0000000000..4167a392e4 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IOnPolicyPreUpdateEvent.cs @@ -0,0 +1,23 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; + +/// +/// Represents all side effects that should be executed before a policy is upserted. +/// +/// +/// This should be added to policy handlers that need to perform side effects before policy upserts. +/// +public interface IOnPolicyPreUpdateEvent : IPolicyUpdateEvent +{ + /// + /// Performs side effects before a policy is upserted. + /// For example, this can be used to remove non-compliant users from the organization. + /// + /// The policy save request containing the policy update and metadata + /// The current policy, if any + public Task ExecutePreUpsertSideEffectAsync( + SavePolicyModel policyRequest, + Policy? currentPolicy); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IPolicyEventHandlerFactory.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IPolicyEventHandlerFactory.cs new file mode 100644 index 0000000000..f44ae867dd --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IPolicyEventHandlerFactory.cs @@ -0,0 +1,30 @@ +#nullable enable + +using Bit.Core.AdminConsole.Enums; +using OneOf; +using OneOf.Types; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; + +/// +/// Provides policy-specific event handlers used during the save workflow in . +/// +/// +/// Supported handlers: +/// - for dependency checks +/// - for custom validation +/// - for pre-save logic +/// - for post-save logic +/// +public interface IPolicyEventHandlerFactory +{ + /// + /// Gets the event handler for the given policy type and handler interface. + /// + /// Handler type implementing . + /// The policy type to resolve. + /// + /// — the handler if available, or None if not implemented. + /// + OneOf GetHandler(PolicyType policyType) where T : IPolicyUpdateEvent; +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IPolicyUpdateEvent.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IPolicyUpdateEvent.cs new file mode 100644 index 0000000000..a568658d4d --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IPolicyUpdateEvent.cs @@ -0,0 +1,17 @@ +using Bit.Core.AdminConsole.Enums; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; + +/// +/// Represents the policy to be upserted. +/// +/// +/// This is used for the VNextSavePolicyCommand. All policy handlers should implement this interface. +/// +public interface IPolicyUpdateEvent +{ + /// + /// The policy type that the associated handler will handle. + /// + public PolicyType Type { get; } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IPolicyValidationEvent.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IPolicyValidationEvent.cs new file mode 100644 index 0000000000..ee401ef813 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IPolicyValidationEvent.cs @@ -0,0 +1,24 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; + +/// +/// Represents all validations that need to be run to enable or disable the given policy. +/// +/// +/// This is used for the VNextSavePolicyCommand. This optional but should be implemented for all policies that have +/// certain requirements for the given organization. +/// +public interface IPolicyValidationEvent : IPolicyUpdateEvent +{ + /// + /// Performs any validations required to enable or disable the policy. + /// + /// The policy save request containing the policy update and metadata + /// The current policy, if any + public Task ValidateAsync( + SavePolicyModel policyRequest, + Policy? currentPolicy); + +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IVNextSavePolicyCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IVNextSavePolicyCommand.cs new file mode 100644 index 0000000000..93414539bb --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IVNextSavePolicyCommand.cs @@ -0,0 +1,34 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Microsoft.Azure.NotificationHubs.Messaging; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; + +/// +/// Handles creating or updating organization policies with validation and side effect execution. +/// +/// +/// Workflow: +/// 1. Validates organization can use policies +/// 2. Validates required and dependent policies +/// 3. Runs policy-specific validation () +/// 4. Executes pre-save logic () +/// 5. Saves the policy +/// 6. Logs the event +/// 7. Executes post-save logic () +/// +public interface IVNextSavePolicyCommand +{ + /// + /// Performs the necessary validations, saves the policy and any side effects + /// + /// Policy data, acting user, and metadata. + /// The saved policy with updated revision and applied changes. + /// + /// Thrown if: + /// - The organization can’t use policies + /// - Dependent policies are missing or block changes + /// - Custom validation fails + /// + Task SaveAsync(SavePolicyModel policyRequest); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/PolicyEventHandlerHandlerFactory.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/PolicyEventHandlerHandlerFactory.cs new file mode 100644 index 0000000000..b1abfb2aaf --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/PolicyEventHandlerHandlerFactory.cs @@ -0,0 +1,33 @@ + +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; +using OneOf; +using OneOf.Types; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents; + +public class PolicyEventHandlerHandlerFactory( + IEnumerable allEventHandlers) : IPolicyEventHandlerFactory +{ + public OneOf GetHandler(PolicyType policyType) where T : IPolicyUpdateEvent + { + var tEventHandlers = allEventHandlers.OfType().ToList(); + + var matchingHandlers = tEventHandlers.Where(h => h.Type == policyType).ToList(); + + if (matchingHandlers.Count > 1) + { + throw new InvalidOperationException( + $"Multiple {nameof(IPolicyUpdateEvent)} handlers of type {typeof(T).Name} found for {nameof(PolicyType)} {policyType}. " + + $"Expected one {typeof(T).Name} handler per {nameof(PolicyType)}."); + } + + var policyTEventHandler = matchingHandlers.SingleOrDefault(); + if (policyTEventHandler is null) + { + return new None(); + } + + return policyTEventHandler; + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/AutomaticUserConfirmationPolicyEventHandler.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/AutomaticUserConfirmationPolicyEventHandler.cs new file mode 100644 index 0000000000..86c94147f4 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/AutomaticUserConfirmationPolicyEventHandler.cs @@ -0,0 +1,94 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Enums; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; +using Bit.Core.Repositories; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; + +/// +/// Represents an event handler for the Automatic User Confirmation policy. +/// +/// This class validates that the following conditions are met: +///
      +///
    • The Single organization policy is enabled
    • +///
    • All organization users are compliant with the Single organization policy
    • +///
    • No provider users exist
    • +///
    +///
    +public class AutomaticUserConfirmationPolicyEventHandler( + IOrganizationUserRepository organizationUserRepository, + IProviderUserRepository providerUserRepository) + : IPolicyValidator, IPolicyValidationEvent, IEnforceDependentPoliciesEvent +{ + public PolicyType Type => PolicyType.AutomaticUserConfirmation; + + private const string _usersNotCompliantWithSingleOrgErrorMessage = + "All organization users must be compliant with the Single organization policy before enabling the Automatically confirm invited users policy. Please remove users who are members of multiple organizations."; + + private const string _providerUsersExistErrorMessage = + "The organization has users with the Provider user type. Please remove provider users before enabling the Automatically confirm invited users policy."; + + public IEnumerable RequiredPolicies => [PolicyType.SingleOrg]; + + public async Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) + { + var isNotEnablingPolicy = policyUpdate is not { Enabled: true }; + var policyAlreadyEnabled = currentPolicy is { Enabled: true }; + if (isNotEnablingPolicy || policyAlreadyEnabled) + { + return string.Empty; + } + + return await ValidateEnablingPolicyAsync(policyUpdate.OrganizationId); + } + + public async Task ValidateAsync(SavePolicyModel savePolicyModel, Policy? currentPolicy) => + await ValidateAsync(savePolicyModel.PolicyUpdate, currentPolicy); + + public Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => + Task.CompletedTask; + + private async Task ValidateEnablingPolicyAsync(Guid organizationId) + { + var organizationUsers = await organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId); + + var singleOrgValidationError = await ValidateUserComplianceWithSingleOrgAsync(organizationId, organizationUsers); + if (!string.IsNullOrWhiteSpace(singleOrgValidationError)) + { + return singleOrgValidationError; + } + + var providerValidationError = await ValidateNoProviderUsersAsync(organizationUsers); + if (!string.IsNullOrWhiteSpace(providerValidationError)) + { + return providerValidationError; + } + + return string.Empty; + } + + private async Task ValidateUserComplianceWithSingleOrgAsync(Guid organizationId, + ICollection organizationUsers) + { + var hasNonCompliantUser = (await organizationUserRepository.GetManyByManyUsersAsync( + organizationUsers.Select(ou => ou.UserId!.Value))) + .Any(uo => uo.OrganizationId != organizationId + && uo.Status != OrganizationUserStatusType.Invited); + + return hasNonCompliantUser ? _usersNotCompliantWithSingleOrgErrorMessage : string.Empty; + } + + private async Task ValidateNoProviderUsersAsync(ICollection organizationUsers) + { + var userIds = organizationUsers.Where(x => x.UserId is not null) + .Select(x => x.UserId!.Value); + + return (await providerUserRepository.GetManyByManyUsersAsync(userIds)).Count != 0 + ? _providerUsersExistErrorMessage + : string.Empty; + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/BlockClaimedDomainAccountCreationPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/BlockClaimedDomainAccountCreationPolicyValidator.cs new file mode 100644 index 0000000000..92ba11f5a6 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/BlockClaimedDomainAccountCreationPolicyValidator.cs @@ -0,0 +1,59 @@ +#nullable enable + +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; +using Bit.Core.Services; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; + +public class BlockClaimedDomainAccountCreationPolicyValidator : IPolicyValidator, IPolicyValidationEvent +{ + private readonly IOrganizationHasVerifiedDomainsQuery _organizationHasVerifiedDomainsQuery; + private readonly IFeatureService _featureService; + + public BlockClaimedDomainAccountCreationPolicyValidator( + IOrganizationHasVerifiedDomainsQuery organizationHasVerifiedDomainsQuery, + IFeatureService featureService) + { + _organizationHasVerifiedDomainsQuery = organizationHasVerifiedDomainsQuery; + _featureService = featureService; + } + + public PolicyType Type => PolicyType.BlockClaimedDomainAccountCreation; + + // No prerequisites - this policy stands alone + public IEnumerable RequiredPolicies => []; + + public async Task ValidateAsync(SavePolicyModel policyRequest, Policy? currentPolicy) + { + return await ValidateAsync(policyRequest.PolicyUpdate, currentPolicy); + } + + public async Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) + { + // Check if feature is enabled + if (!_featureService.IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation)) + { + return "This feature is not enabled"; + } + + // Only validate when trying to ENABLE the policy + if (policyUpdate is { Enabled: true }) + { + // Check if organization has at least one verified domain + if (!await _organizationHasVerifiedDomainsQuery.HasVerifiedDomainsAsync(policyUpdate.OrganizationId)) + { + return "You must claim at least one domain to turn on this policy"; + } + } + + // Disabling the policy is always allowed + return string.Empty; + } + + public Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) + => Task.CompletedTask; +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/FreeFamiliesForEnterprisePolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/FreeFamiliesForEnterprisePolicyValidator.cs index 57db4962e3..52a7e3e880 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/FreeFamiliesForEnterprisePolicyValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/FreeFamiliesForEnterprisePolicyValidator.cs @@ -3,6 +3,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; @@ -12,11 +13,16 @@ public class FreeFamiliesForEnterprisePolicyValidator( IOrganizationSponsorshipRepository organizationSponsorshipRepository, IMailService mailService, IOrganizationRepository organizationRepository) - : IPolicyValidator + : IPolicyValidator, IOnPolicyPreUpdateEvent { public PolicyType Type => PolicyType.FreeFamiliesSponsorshipPolicy; public IEnumerable RequiredPolicies => []; + public async Task ExecutePreUpsertSideEffectAsync(SavePolicyModel policyRequest, Policy? currentPolicy) + { + await OnSaveSideEffectsAsync(policyRequest.PolicyUpdate, currentPolicy); + } + public async Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) { if (currentPolicy is not { Enabled: true } && policyUpdate is { Enabled: true }) diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/MaximumVaultTimeoutPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/MaximumVaultTimeoutPolicyValidator.cs index bfd4dcfe0d..796ed286d8 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/MaximumVaultTimeoutPolicyValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/MaximumVaultTimeoutPolicyValidator.cs @@ -3,10 +3,11 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; -public class MaximumVaultTimeoutPolicyValidator : IPolicyValidator +public class MaximumVaultTimeoutPolicyValidator : IPolicyValidator, IEnforceDependentPoliciesEvent { public PolicyType Type => PolicyType.MaximumVaultTimeout; public IEnumerable RequiredPolicies => [PolicyType.SingleOrg]; diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidator.cs index f4ef6021a7..0bee2a55af 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidator.cs @@ -1,24 +1,32 @@  using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Repositories; using Bit.Core.Services; namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; -/// -/// Please do not extend or expand this validator. We're currently in the process of refactoring our policy validator pattern. -/// This is a stop-gap solution for post-policy-save side effects, but it is not the long-term solution. -/// public class OrganizationDataOwnershipPolicyValidator( IPolicyRepository policyRepository, ICollectionRepository collectionRepository, IEnumerable> factories, IFeatureService featureService) - : OrganizationPolicyValidator(policyRepository, factories), IPostSavePolicySideEffect + : OrganizationPolicyValidator(policyRepository, factories), IPostSavePolicySideEffect, IOnPolicyPostUpdateEvent { + public PolicyType Type => PolicyType.OrganizationDataOwnership; + + public async Task ExecutePostUpsertSideEffectAsync( + SavePolicyModel policyRequest, + Policy postUpsertedPolicyState, + Policy? previousPolicyState) + { + await ExecuteSideEffectsAsync(policyRequest, postUpsertedPolicyState, previousPolicyState); + } + public async Task ExecuteSideEffectsAsync( SavePolicyModel policyRequest, Policy postUpdatedPolicy, @@ -68,5 +76,4 @@ public class OrganizationDataOwnershipPolicyValidator( userOrgIds, defaultCollectionName); } - } diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/RequireSsoPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/RequireSsoPolicyValidator.cs index 2082d4305f..adc2a3865a 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/RequireSsoPolicyValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/RequireSsoPolicyValidator.cs @@ -3,12 +3,13 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Repositories; namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; -public class RequireSsoPolicyValidator : IPolicyValidator +public class RequireSsoPolicyValidator : IPolicyValidator, IPolicyValidationEvent, IEnforceDependentPoliciesEvent { private readonly ISsoConfigRepository _ssoConfigRepository; @@ -20,6 +21,11 @@ public class RequireSsoPolicyValidator : IPolicyValidator public PolicyType Type => PolicyType.RequireSso; public IEnumerable RequiredPolicies => [PolicyType.SingleOrg]; + public async Task ValidateAsync(SavePolicyModel policyRequest, Policy? currentPolicy) + { + return await ValidateAsync(policyRequest.PolicyUpdate, currentPolicy); + } + public async Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) { if (policyUpdate is not { Enabled: true }) diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/ResetPasswordPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/ResetPasswordPolicyValidator.cs index 1126c4b922..9033a38ad0 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/ResetPasswordPolicyValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/ResetPasswordPolicyValidator.cs @@ -4,12 +4,13 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Repositories; namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; -public class ResetPasswordPolicyValidator : IPolicyValidator +public class ResetPasswordPolicyValidator : IPolicyValidator, IPolicyValidationEvent, IEnforceDependentPoliciesEvent { private readonly ISsoConfigRepository _ssoConfigRepository; public PolicyType Type => PolicyType.ResetPassword; @@ -20,6 +21,11 @@ public class ResetPasswordPolicyValidator : IPolicyValidator _ssoConfigRepository = ssoConfigRepository; } + public async Task ValidateAsync(SavePolicyModel policyRequest, Policy? currentPolicy) + { + return await ValidateAsync(policyRequest.PolicyUpdate, currentPolicy); + } + public async Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) { if (policyUpdate is not { Enabled: true } || diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidator.cs index 49467eaae4..d24c61e258 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidator.cs @@ -1,12 +1,11 @@ -#nullable enable - -using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Repositories; using Bit.Core.Context; @@ -17,7 +16,7 @@ using Bit.Core.Services; namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; -public class SingleOrgPolicyValidator : IPolicyValidator +public class SingleOrgPolicyValidator : IPolicyValidator, IPolicyValidationEvent, IOnPolicyPreUpdateEvent { public PolicyType Type => PolicyType.SingleOrg; private const string OrganizationNotFoundErrorMessage = "Organization not found."; @@ -28,8 +27,6 @@ public class SingleOrgPolicyValidator : IPolicyValidator private readonly IOrganizationRepository _organizationRepository; private readonly ISsoConfigRepository _ssoConfigRepository; private readonly ICurrentContext _currentContext; - private readonly IFeatureService _featureService; - private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand; private readonly IOrganizationHasVerifiedDomainsQuery _organizationHasVerifiedDomainsQuery; private readonly IRevokeNonCompliantOrganizationUserCommand _revokeNonCompliantOrganizationUserCommand; @@ -39,8 +36,6 @@ public class SingleOrgPolicyValidator : IPolicyValidator IOrganizationRepository organizationRepository, ISsoConfigRepository ssoConfigRepository, ICurrentContext currentContext, - IFeatureService featureService, - IRemoveOrganizationUserCommand removeOrganizationUserCommand, IOrganizationHasVerifiedDomainsQuery organizationHasVerifiedDomainsQuery, IRevokeNonCompliantOrganizationUserCommand revokeNonCompliantOrganizationUserCommand) { @@ -49,14 +44,22 @@ public class SingleOrgPolicyValidator : IPolicyValidator _organizationRepository = organizationRepository; _ssoConfigRepository = ssoConfigRepository; _currentContext = currentContext; - _featureService = featureService; - _removeOrganizationUserCommand = removeOrganizationUserCommand; _organizationHasVerifiedDomainsQuery = organizationHasVerifiedDomainsQuery; _revokeNonCompliantOrganizationUserCommand = revokeNonCompliantOrganizationUserCommand; } public IEnumerable RequiredPolicies => []; + public async Task ValidateAsync(SavePolicyModel policyRequest, Policy? currentPolicy) + { + return await ValidateAsync(policyRequest.PolicyUpdate, currentPolicy); + } + + public async Task ExecutePreUpsertSideEffectAsync(SavePolicyModel policyRequest, Policy? currentPolicy) + { + await OnSaveSideEffectsAsync(policyRequest.PolicyUpdate, currentPolicy); + } + public async Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) { if (currentPolicy is not { Enabled: true } && policyUpdate is { Enabled: true }) diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidator.cs index 5ce72df6c1..7f3ebcccfb 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidator.cs @@ -6,6 +6,7 @@ using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; using Bit.Core.Context; using Bit.Core.Enums; @@ -16,7 +17,7 @@ using Bit.Core.Services; namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; -public class TwoFactorAuthenticationPolicyValidator : IPolicyValidator +public class TwoFactorAuthenticationPolicyValidator : IPolicyValidator, IOnPolicyPreUpdateEvent { private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IMailService _mailService; @@ -46,6 +47,11 @@ public class TwoFactorAuthenticationPolicyValidator : IPolicyValidator _revokeNonCompliantOrganizationUserCommand = revokeNonCompliantOrganizationUserCommand; } + public async Task ExecutePreUpsertSideEffectAsync(SavePolicyModel policyRequest, Policy? currentPolicy) + { + await OnSaveSideEffectsAsync(policyRequest.PolicyUpdate, currentPolicy); + } + public async Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) { if (currentPolicy is not { Enabled: true } && policyUpdate is { Enabled: true }) diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/UriMatchDefaultPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/UriMatchDefaultPolicyValidator.cs new file mode 100644 index 0000000000..5bffd944c9 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/UriMatchDefaultPolicyValidator.cs @@ -0,0 +1,14 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; + +public class UriMatchDefaultPolicyValidator : IPolicyValidator, IEnforceDependentPoliciesEvent +{ + public PolicyType Type => PolicyType.UriMatchDefaults; + public IEnumerable RequiredPolicies => [PolicyType.SingleOrg]; + public Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => Task.FromResult(""); + public Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => Task.CompletedTask; +} diff --git a/src/Core/AdminConsole/Repositories/IOrganizationIntegrationConfigurationRepository.cs b/src/Core/AdminConsole/Repositories/IOrganizationIntegrationConfigurationRepository.cs deleted file mode 100644 index 0a774cf395..0000000000 --- a/src/Core/AdminConsole/Repositories/IOrganizationIntegrationConfigurationRepository.cs +++ /dev/null @@ -1,17 +0,0 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.Enums; -using Bit.Core.Models.Data.Organizations; - -namespace Bit.Core.Repositories; - -public interface IOrganizationIntegrationConfigurationRepository : IRepository -{ - Task> GetConfigurationDetailsAsync( - Guid organizationId, - IntegrationType integrationType, - EventType eventType); - - Task> GetAllConfigurationDetailsAsync(); - - Task> GetManyByIntegrationAsync(Guid organizationIntegrationId); -} diff --git a/src/Core/AdminConsole/Repositories/IOrganizationIntegrationRepository.cs b/src/Core/AdminConsole/Repositories/IOrganizationIntegrationRepository.cs deleted file mode 100644 index 434c8ddee3..0000000000 --- a/src/Core/AdminConsole/Repositories/IOrganizationIntegrationRepository.cs +++ /dev/null @@ -1,8 +0,0 @@ -using Bit.Core.AdminConsole.Entities; - -namespace Bit.Core.Repositories; - -public interface IOrganizationIntegrationRepository : IRepository -{ - Task> GetManyByOrganizationAsync(Guid organizationId); -} diff --git a/src/Core/AdminConsole/Repositories/IOrganizationUserRepository.cs b/src/Core/AdminConsole/Repositories/IOrganizationUserRepository.cs index 37a830c92e..41622c24b7 100644 --- a/src/Core/AdminConsole/Repositories/IOrganizationUserRepository.cs +++ b/src/Core/AdminConsole/Repositories/IOrganizationUserRepository.cs @@ -1,4 +1,5 @@ using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.OrganizationUsers; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models; using Bit.Core.Entities; using Bit.Core.Enums; @@ -87,4 +88,24 @@ public interface IOrganizationUserRepository : IRepository> GetManyDetailsByRoleAsync(Guid organizationId, OrganizationUserType role); Task CreateManyAsync(IEnumerable organizationUserCollection); + + /// + /// It will only confirm if the user is in the `Accepted` state. + /// + /// This is an idempotent operation. + /// + /// Accepted OrganizationUser to confirm + /// True, if the user was updated. False, if not performed. + Task ConfirmOrganizationUserAsync(AcceptedOrganizationUserToConfirm organizationUserToConfirm); + + /// + /// Returns the OrganizationUserUserDetails if found. + /// + /// The id of the organization + /// The id of the User to fetch + /// OrganizationUserUserDetails of the specified user or null if not found + /// + /// Similar to GetByOrganizationAsync, but returns the user details. + /// + Task GetDetailsByOrganizationIdUserIdAsync(Guid organizationId, Guid userId); } diff --git a/src/Core/AdminConsole/Repositories/IPolicyRepository.cs b/src/Core/AdminConsole/Repositories/IPolicyRepository.cs index 9f5c7f3fc4..d479809b89 100644 --- a/src/Core/AdminConsole/Repositories/IPolicyRepository.cs +++ b/src/Core/AdminConsole/Repositories/IPolicyRepository.cs @@ -20,17 +20,6 @@ public interface IPolicyRepository : IRepository Task GetByOrganizationIdTypeAsync(Guid organizationId, PolicyType type); Task> GetManyByOrganizationIdAsync(Guid organizationId); Task> GetManyByUserIdAsync(Guid userId); - /// - /// Gets all PolicyDetails for a user for all policy types. - /// - /// - /// Each PolicyDetail represents an OrganizationUser and a Policy which *may* be enforced - /// against them. It only returns PolicyDetails for policies that are enabled and where the organization's plan - /// supports policies. It also excludes "revoked invited" users who are not subject to policy enforcement. - /// This is consumed by to create requirements for specific policy types. - /// You probably do not want to call it directly. - /// - Task> GetPolicyDetailsByUserId(Guid userId); /// /// Retrieves of the specified diff --git a/src/Core/AdminConsole/Repositories/IProviderUserRepository.cs b/src/Core/AdminConsole/Repositories/IProviderUserRepository.cs index 7bc4125778..0a640b7530 100644 --- a/src/Core/AdminConsole/Repositories/IProviderUserRepository.cs +++ b/src/Core/AdminConsole/Repositories/IProviderUserRepository.cs @@ -12,6 +12,7 @@ public interface IProviderUserRepository : IRepository Task GetCountByProviderAsync(Guid providerId, string email, bool onlyRegisteredUsers); Task> GetManyAsync(IEnumerable ids); Task> GetManyByUserAsync(Guid userId); + Task> GetManyByManyUsersAsync(IEnumerable userIds); Task GetByProviderUserAsync(Guid providerId, Guid userId); Task> GetManyByProviderAsync(Guid providerId, ProviderUserType? type = null); Task> GetManyDetailsByProviderAsync(Guid providerId, ProviderUserStatusType? status = null); diff --git a/src/Core/AdminConsole/Services/IEventIntegrationPublisher.cs b/src/Core/AdminConsole/Services/IEventIntegrationPublisher.cs deleted file mode 100644 index b80b518223..0000000000 --- a/src/Core/AdminConsole/Services/IEventIntegrationPublisher.cs +++ /dev/null @@ -1,9 +0,0 @@ -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; - -namespace Bit.Core.Services; - -public interface IEventIntegrationPublisher : IAsyncDisposable -{ - Task PublishAsync(IIntegrationMessage message); - Task PublishEventAsync(string body); -} diff --git a/src/Core/AdminConsole/Services/IIntegrationConfigurationDetailsCache.cs b/src/Core/AdminConsole/Services/IIntegrationConfigurationDetailsCache.cs deleted file mode 100644 index ad27429112..0000000000 --- a/src/Core/AdminConsole/Services/IIntegrationConfigurationDetailsCache.cs +++ /dev/null @@ -1,14 +0,0 @@ -#nullable enable - -using Bit.Core.Enums; -using Bit.Core.Models.Data.Organizations; - -namespace Bit.Core.Services; - -public interface IIntegrationConfigurationDetailsCache -{ - List GetConfigurationDetails( - Guid organizationId, - IntegrationType integrationType, - EventType eventType); -} diff --git a/src/Core/AdminConsole/Services/IIntegrationHandler.cs b/src/Core/AdminConsole/Services/IIntegrationHandler.cs deleted file mode 100644 index bb10dc01b9..0000000000 --- a/src/Core/AdminConsole/Services/IIntegrationHandler.cs +++ /dev/null @@ -1,74 +0,0 @@ -using System.Globalization; -using System.Net; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; - -namespace Bit.Core.Services; - -public interface IIntegrationHandler -{ - Task HandleAsync(string json); -} - -public interface IIntegrationHandler : IIntegrationHandler -{ - Task HandleAsync(IntegrationMessage message); -} - -public abstract class IntegrationHandlerBase : IIntegrationHandler -{ - public async Task HandleAsync(string json) - { - var message = IntegrationMessage.FromJson(json); - return await HandleAsync(message ?? throw new ArgumentException("IntegrationMessage was null when created from the provided JSON")); - } - - public abstract Task HandleAsync(IntegrationMessage message); - - protected IntegrationHandlerResult ResultFromHttpResponse( - HttpResponseMessage response, - IntegrationMessage message, - TimeProvider timeProvider) - { - var result = new IntegrationHandlerResult(success: response.IsSuccessStatusCode, message); - - if (response.IsSuccessStatusCode) return result; - - switch (response.StatusCode) - { - case HttpStatusCode.TooManyRequests: - case HttpStatusCode.RequestTimeout: - case HttpStatusCode.InternalServerError: - case HttpStatusCode.BadGateway: - case HttpStatusCode.ServiceUnavailable: - case HttpStatusCode.GatewayTimeout: - result.Retryable = true; - result.FailureReason = response.ReasonPhrase ?? $"Failure with status code: {(int)response.StatusCode}"; - - if (response.Headers.TryGetValues("Retry-After", out var values)) - { - var value = values.FirstOrDefault(); - if (int.TryParse(value, out var seconds)) - { - // Retry-after was specified in seconds. Adjust DelayUntilDate by the requested number of seconds. - result.DelayUntilDate = timeProvider.GetUtcNow().AddSeconds(seconds).UtcDateTime; - } - else if (DateTimeOffset.TryParseExact(value, - "r", // "r" is the round-trip format: RFC1123 - CultureInfo.InvariantCulture, - DateTimeStyles.AssumeUniversal | DateTimeStyles.AdjustToUniversal, - out var retryDate)) - { - // Retry-after was specified as a date. Adjust DelayUntilDate to the specified date. - result.DelayUntilDate = retryDate.UtcDateTime; - } - } - break; - default: - result.Retryable = false; - result.FailureReason = response.ReasonPhrase ?? $"Failure with status code {(int)response.StatusCode}"; - break; - } - - return result; - } -} diff --git a/src/Core/AdminConsole/Services/ISlackService.cs b/src/Core/AdminConsole/Services/ISlackService.cs deleted file mode 100644 index ff1e03f051..0000000000 --- a/src/Core/AdminConsole/Services/ISlackService.cs +++ /dev/null @@ -1,11 +0,0 @@ -namespace Bit.Core.Services; - -public interface ISlackService -{ - Task GetChannelIdAsync(string token, string channelName); - Task> GetChannelIdsAsync(string token, List channelNames); - Task GetDmChannelByEmailAsync(string token, string email); - string GetRedirectUrl(string callbackUrl, string state); - Task ObtainTokenViaOAuth(string code, string redirectUrl); - Task SendSlackMessageByChannelIdAsync(string token, string message, string channelId); -} diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationHandler.cs b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationHandler.cs deleted file mode 100644 index 0a8ab67554..0000000000 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationHandler.cs +++ /dev/null @@ -1,108 +0,0 @@ -using System.Text.Json; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; -using Bit.Core.AdminConsole.Utilities; -using Bit.Core.Enums; -using Bit.Core.Models.Data; -using Bit.Core.Repositories; -using Microsoft.Extensions.Logging; - -namespace Bit.Core.Services; - -public class EventIntegrationHandler( - IntegrationType integrationType, - IEventIntegrationPublisher eventIntegrationPublisher, - IIntegrationFilterService integrationFilterService, - IIntegrationConfigurationDetailsCache configurationCache, - IUserRepository userRepository, - IOrganizationRepository organizationRepository, - ILogger> logger) - : IEventMessageHandler -{ - public async Task HandleEventAsync(EventMessage eventMessage) - { - if (eventMessage.OrganizationId is not Guid organizationId) - { - return; - } - - var configurations = configurationCache.GetConfigurationDetails( - organizationId, - integrationType, - eventMessage.Type); - - foreach (var configuration in configurations) - { - try - { - if (configuration.Filters is string filterJson) - { - // Evaluate filters - if false, then discard and do not process - var filters = JsonSerializer.Deserialize(filterJson) - ?? throw new InvalidOperationException($"Failed to deserialize Filters to FilterGroup"); - if (!integrationFilterService.EvaluateFilterGroup(filters, eventMessage)) - { - continue; - } - } - - // Valid filter - assemble message and publish to Integration topic/exchange - var template = configuration.Template ?? string.Empty; - var context = await BuildContextAsync(eventMessage, template); - var renderedTemplate = IntegrationTemplateProcessor.ReplaceTokens(template, context); - var messageId = eventMessage.IdempotencyId ?? Guid.NewGuid(); - var config = configuration.MergedConfiguration.Deserialize() - ?? throw new InvalidOperationException($"Failed to deserialize to {typeof(T).Name} - bad Configuration"); - - var message = new IntegrationMessage - { - IntegrationType = integrationType, - MessageId = messageId.ToString(), - Configuration = config, - RenderedTemplate = renderedTemplate, - RetryCount = 0, - DelayUntilDate = null - }; - - await eventIntegrationPublisher.PublishAsync(message); - } - catch (Exception exception) - { - logger.LogError( - exception, - "Failed to publish Integration Message for {Type}, check Id {RecordId} for error in Configuration or Filters", - typeof(T).Name, - configuration.Id); - } - } - } - - public async Task HandleManyEventsAsync(IEnumerable eventMessages) - { - foreach (var eventMessage in eventMessages) - { - await HandleEventAsync(eventMessage); - } - } - - private async Task BuildContextAsync(EventMessage eventMessage, string template) - { - var context = new IntegrationTemplateContext(eventMessage); - - if (IntegrationTemplateProcessor.TemplateRequiresUser(template) && eventMessage.UserId.HasValue) - { - context.User = await userRepository.GetByIdAsync(eventMessage.UserId.Value); - } - - if (IntegrationTemplateProcessor.TemplateRequiresActingUser(template) && eventMessage.ActingUserId.HasValue) - { - context.ActingUser = await userRepository.GetByIdAsync(eventMessage.ActingUserId.Value); - } - - if (IntegrationTemplateProcessor.TemplateRequiresOrganization(template) && eventMessage.OrganizationId.HasValue) - { - context.Organization = await organizationRepository.GetByIdAsync(eventMessage.OrganizationId.Value); - } - - return context; - } -} diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventRouteService.cs b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventRouteService.cs deleted file mode 100644 index a542e75a7b..0000000000 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventRouteService.cs +++ /dev/null @@ -1,34 +0,0 @@ -using Bit.Core.Models.Data; -using Microsoft.Extensions.DependencyInjection; - -namespace Bit.Core.Services; - -public class EventRouteService( - [FromKeyedServices("broadcast")] IEventWriteService broadcastEventWriteService, - [FromKeyedServices("storage")] IEventWriteService storageEventWriteService, - IFeatureService _featureService) : IEventWriteService -{ - public async Task CreateAsync(IEvent e) - { - if (_featureService.IsEnabled(FeatureFlagKeys.EventBasedOrganizationIntegrations)) - { - await broadcastEventWriteService.CreateAsync(e); - } - else - { - await storageEventWriteService.CreateAsync(e); - } - } - - public async Task CreateManyAsync(IEnumerable e) - { - if (_featureService.IsEnabled(FeatureFlagKeys.EventBasedOrganizationIntegrations)) - { - await broadcastEventWriteService.CreateManyAsync(e); - } - else - { - await storageEventWriteService.CreateManyAsync(e); - } - } -} diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/IntegrationConfigurationDetailsCacheService.cs b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/IntegrationConfigurationDetailsCacheService.cs deleted file mode 100644 index a63efac62f..0000000000 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/IntegrationConfigurationDetailsCacheService.cs +++ /dev/null @@ -1,83 +0,0 @@ -using System.Diagnostics; -using Bit.Core.Enums; -using Bit.Core.Models.Data.Organizations; -using Bit.Core.Repositories; -using Bit.Core.Settings; -using Microsoft.Extensions.Hosting; -using Microsoft.Extensions.Logging; - -namespace Bit.Core.Services; - -public class IntegrationConfigurationDetailsCacheService : BackgroundService, IIntegrationConfigurationDetailsCache -{ - private readonly record struct IntegrationCacheKey(Guid OrganizationId, IntegrationType IntegrationType, EventType? EventType); - private readonly IOrganizationIntegrationConfigurationRepository _repository; - private readonly ILogger _logger; - private readonly TimeSpan _refreshInterval; - private Dictionary> _cache = new(); - - public IntegrationConfigurationDetailsCacheService( - IOrganizationIntegrationConfigurationRepository repository, - GlobalSettings globalSettings, - ILogger logger) - { - _repository = repository; - _logger = logger; - _refreshInterval = TimeSpan.FromMinutes(globalSettings.EventLogging.IntegrationCacheRefreshIntervalMinutes); - } - - public List GetConfigurationDetails( - Guid organizationId, - IntegrationType integrationType, - EventType eventType) - { - var specificKey = new IntegrationCacheKey(organizationId, integrationType, eventType); - var allEventsKey = new IntegrationCacheKey(organizationId, integrationType, null); - - var results = new List(); - - if (_cache.TryGetValue(specificKey, out var specificConfigs)) - { - results.AddRange(specificConfigs); - } - if (_cache.TryGetValue(allEventsKey, out var fallbackConfigs)) - { - results.AddRange(fallbackConfigs); - } - - return results; - } - - protected override async Task ExecuteAsync(CancellationToken stoppingToken) - { - await RefreshAsync(); - - var timer = new PeriodicTimer(_refreshInterval); - while (await timer.WaitForNextTickAsync(stoppingToken)) - { - await RefreshAsync(); - } - } - - internal async Task RefreshAsync() - { - var stopwatch = Stopwatch.StartNew(); - try - { - var newCache = (await _repository.GetAllConfigurationDetailsAsync()) - .GroupBy(x => new IntegrationCacheKey(x.OrganizationId, x.IntegrationType, x.EventType)) - .ToDictionary(g => g.Key, g => g.ToList()); - _cache = newCache; - - stopwatch.Stop(); - _logger.LogInformation( - "[IntegrationConfigurationDetailsCacheService] Refreshed successfully: {Count} entries in {Duration}ms", - newCache.Count, - stopwatch.Elapsed.TotalMilliseconds); - } - catch (Exception ex) - { - _logger.LogError("[IntegrationConfigurationDetailsCacheService] Refresh failed: {ex}", ex); - } - } -} diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/SlackIntegrationHandler.cs b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/SlackIntegrationHandler.cs deleted file mode 100644 index 2d29494afc..0000000000 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/SlackIntegrationHandler.cs +++ /dev/null @@ -1,19 +0,0 @@ -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; - -namespace Bit.Core.Services; - -public class SlackIntegrationHandler( - ISlackService slackService) - : IntegrationHandlerBase -{ - public override async Task HandleAsync(IntegrationMessage message) - { - await slackService.SendSlackMessageByChannelIdAsync( - message.Configuration.Token, - message.RenderedTemplate, - message.Configuration.ChannelId - ); - - return new IntegrationHandlerResult(success: true, message: message); - } -} diff --git a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs index 1b52ad8cff..e1fcbb970d 100644 --- a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs +++ b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs @@ -21,6 +21,7 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -47,7 +48,7 @@ public class OrganizationService : IOrganizationService private readonly IPushNotificationService _pushNotificationService; private readonly IEventService _eventService; private readonly IApplicationCacheService _applicationCacheService; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IPolicyRepository _policyRepository; private readonly IPolicyService _policyService; private readonly ISsoUserRepository _ssoUserRepository; @@ -74,7 +75,7 @@ public class OrganizationService : IOrganizationService IPushNotificationService pushNotificationService, IEventService eventService, IApplicationCacheService applicationCacheService, - IPaymentService paymentService, + IStripePaymentService paymentService, IPolicyRepository policyRepository, IPolicyService policyService, ISsoUserRepository ssoUserRepository, @@ -148,7 +149,7 @@ public class OrganizationService : IOrganizationService } var secret = await BillingHelpers.AdjustStorageAsync(_paymentService, organization, storageAdjustmentGb, - plan.PasswordManager.StripeStoragePlanId); + plan.PasswordManager.StripeStoragePlanId, plan.PasswordManager.BaseStorageGb); await ReplaceAndUpdateCacheAsync(organization); return secret; } @@ -358,7 +359,7 @@ public class OrganizationService : IOrganizationService { var newDisplayName = organization.DisplayName(); - await _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, + await _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, new CustomerUpdateOptions { Email = organization.BillingEmail, diff --git a/src/Core/AdminConsole/Services/OrganizationFactory.cs b/src/Core/AdminConsole/Services/OrganizationFactory.cs index afb3931ec4..0c64a27431 100644 --- a/src/Core/AdminConsole/Services/OrganizationFactory.cs +++ b/src/Core/AdminConsole/Services/OrganizationFactory.cs @@ -61,6 +61,8 @@ public static class OrganizationFactory claimsPrincipal.GetValue(OrganizationLicenseConstants.UseOrganizationDomains), UseAdminSponsoredFamilies = claimsPrincipal.GetValue(OrganizationLicenseConstants.UseAdminSponsoredFamilies), + UseAutomaticUserConfirmation = claimsPrincipal.GetValue(OrganizationLicenseConstants.UseAutomaticUserConfirmation), + UsePhishingBlocker = claimsPrincipal.GetValue(OrganizationLicenseConstants.UsePhishingBlocker), }; public static Organization Create( @@ -110,5 +112,7 @@ public static class OrganizationFactory UseRiskInsights = license.UseRiskInsights, UseOrganizationDomains = license.UseOrganizationDomains, UseAdminSponsoredFamilies = license.UseAdminSponsoredFamilies, + UseAutomaticUserConfirmation = license.UseAutomaticUserConfirmation, + UsePhishingBlocker = license.UsePhishingBlocker, }; } diff --git a/src/Core/AdminConsole/Utilities/IntegrationTemplateProcessor.cs b/src/Core/AdminConsole/Utilities/IntegrationTemplateProcessor.cs index b561e58a86..7fc8013c15 100644 --- a/src/Core/AdminConsole/Utilities/IntegrationTemplateProcessor.cs +++ b/src/Core/AdminConsole/Utilities/IntegrationTemplateProcessor.cs @@ -1,6 +1,4 @@ -#nullable enable - -using System.Text.RegularExpressions; +using System.Text.RegularExpressions; namespace Bit.Core.AdminConsole.Utilities; @@ -26,7 +24,7 @@ public static partial class IntegrationTemplateProcessor return match.Value; // Return unknown keys as keys - i.e. #Key# } - return property?.GetValue(values)?.ToString() ?? ""; + return property.GetValue(values)?.ToString() ?? string.Empty; }); } @@ -38,7 +36,8 @@ public static partial class IntegrationTemplateProcessor } return template.Contains("#UserName#", StringComparison.Ordinal) - || template.Contains("#UserEmail#", StringComparison.Ordinal); + || template.Contains("#UserEmail#", StringComparison.Ordinal) + || template.Contains("#UserType#", StringComparison.Ordinal); } public static bool TemplateRequiresActingUser(string template) @@ -49,7 +48,18 @@ public static partial class IntegrationTemplateProcessor } return template.Contains("#ActingUserName#", StringComparison.Ordinal) - || template.Contains("#ActingUserEmail#", StringComparison.Ordinal); + || template.Contains("#ActingUserEmail#", StringComparison.Ordinal) + || template.Contains("#ActingUserType#", StringComparison.Ordinal); + } + + public static bool TemplateRequiresGroup(string template) + { + if (string.IsNullOrEmpty(template)) + { + return false; + } + + return template.Contains("#GroupName#", StringComparison.Ordinal); } public static bool TemplateRequiresOrganization(string template) diff --git a/src/Core/AdminConsole/Utilities/PolicyDataValidator.cs b/src/Core/AdminConsole/Utilities/PolicyDataValidator.cs new file mode 100644 index 0000000000..84e63f2a20 --- /dev/null +++ b/src/Core/AdminConsole/Utilities/PolicyDataValidator.cs @@ -0,0 +1,81 @@ +using System.Text.Json; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.Exceptions; +using Bit.Core.Utilities; + +namespace Bit.Core.AdminConsole.Utilities; + +public static class PolicyDataValidator +{ + /// + /// Validates and serializes policy data based on the policy type. + /// + /// The policy data to validate + /// The type of policy + /// Serialized JSON string if data is valid, null if data is null or empty + /// Thrown when data validation fails + public static string? ValidateAndSerialize(Dictionary? data, PolicyType policyType) + { + if (data == null || data.Count == 0) + { + return null; + } + + try + { + var json = JsonSerializer.Serialize(data); + + switch (policyType) + { + case PolicyType.MasterPassword: + CoreHelpers.LoadClassFromJsonData(json); + break; + case PolicyType.SendOptions: + CoreHelpers.LoadClassFromJsonData(json); + break; + case PolicyType.ResetPassword: + CoreHelpers.LoadClassFromJsonData(json); + break; + } + + return json; + } + catch (JsonException ex) + { + var fieldInfo = !string.IsNullOrEmpty(ex.Path) ? $": field '{ex.Path}' has invalid type" : ""; + throw new BadRequestException($"Invalid data for {policyType} policy{fieldInfo}."); + } + } + + /// + /// Validates and deserializes policy metadata based on the policy type. + /// + /// The policy metadata to validate + /// The type of policy + /// Deserialized metadata model, or EmptyMetadataModel if metadata is null, empty, or validation fails + public static IPolicyMetadataModel ValidateAndDeserializeMetadata(Dictionary? metadata, PolicyType policyType) + { + if (metadata == null || metadata.Count == 0) + { + return new EmptyMetadataModel(); + } + + try + { + var json = JsonSerializer.Serialize(metadata); + + return policyType switch + { + PolicyType.OrganizationDataOwnership => + CoreHelpers.LoadClassFromJsonData(json), + _ => new EmptyMetadataModel() + }; + } + catch (JsonException) + { + return new EmptyMetadataModel(); + } + } +} diff --git a/src/Core/AdminConsole/Utilities/v2/Errors.cs b/src/Core/AdminConsole/Utilities/v2/Errors.cs new file mode 100644 index 0000000000..c1c66b2630 --- /dev/null +++ b/src/Core/AdminConsole/Utilities/v2/Errors.cs @@ -0,0 +1,15 @@ +namespace Bit.Core.AdminConsole.Utilities.v2; + +/// +/// A strongly typed error containing a reason that an action failed. +/// This is used for business logic validation and other expected errors, not exceptions. +/// +public abstract record Error(string Message); +/// +/// An type that maps to a NotFoundResult at the api layer. +/// +/// +public abstract record NotFoundError(string Message) : Error(Message); + +public abstract record BadRequestError(string Message) : Error(Message); +public abstract record InternalError(string Message) : Error(Message); diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/CommandResult.cs b/src/Core/AdminConsole/Utilities/v2/Results/CommandResult.cs similarity index 94% rename from src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/CommandResult.cs rename to src/Core/AdminConsole/Utilities/v2/Results/CommandResult.cs index fbb00a908a..fb1bd16b2d 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/CommandResult.cs +++ b/src/Core/AdminConsole/Utilities/v2/Results/CommandResult.cs @@ -1,7 +1,7 @@ using OneOf; using OneOf.Types; -namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; +namespace Bit.Core.AdminConsole.Utilities.v2.Results; /// /// Represents the result of a command. @@ -39,4 +39,3 @@ public record BulkCommandResult(Guid Id, CommandResult Result); /// A wrapper for with an ID, to identify the result in bulk operations. /// public record BulkCommandResult(Guid Id, CommandResult Result); - diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/ValidationResult.cs b/src/Core/AdminConsole/Utilities/v2/Validation/ValidationResult.cs similarity index 94% rename from src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/ValidationResult.cs rename to src/Core/AdminConsole/Utilities/v2/Validation/ValidationResult.cs index c84a0aeda1..e28eac9a1c 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccount/ValidationResult.cs +++ b/src/Core/AdminConsole/Utilities/v2/Validation/ValidationResult.cs @@ -1,7 +1,7 @@ using OneOf; using OneOf.Types; -namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; +namespace Bit.Core.AdminConsole.Utilities.v2.Validation; /// /// Represents the result of validating a request. diff --git a/src/Core/Auth/Attributes/MarketingInitiativeValidationAttribute.cs b/src/Core/Auth/Attributes/MarketingInitiativeValidationAttribute.cs new file mode 100644 index 0000000000..bcc4b851c0 --- /dev/null +++ b/src/Core/Auth/Attributes/MarketingInitiativeValidationAttribute.cs @@ -0,0 +1,29 @@ +using System.ComponentModel.DataAnnotations; +using Bit.Core.Auth.Models.Api.Request.Accounts; + +namespace Bit.Core.Auth.Attributes; + +public class MarketingInitiativeValidationAttribute : ValidationAttribute +{ + private static readonly string[] _acceptedValues = [MarketingInitiativeConstants.Premium]; + + public MarketingInitiativeValidationAttribute() + { + ErrorMessage = $"Marketing initiative type must be one of: {string.Join(", ", _acceptedValues)}"; + } + + public override bool IsValid(object? value) + { + if (value == null) + { + return true; + } + + if (value is not string str) + { + return false; + } + + return _acceptedValues.Contains(str); + } +} diff --git a/src/Core/Auth/Entities/AuthRequest.cs b/src/Core/Auth/Entities/AuthRequest.cs index 2117c575c0..38dc0534c1 100644 --- a/src/Core/Auth/Entities/AuthRequest.cs +++ b/src/Core/Auth/Entities/AuthRequest.cs @@ -49,11 +49,9 @@ public class AuthRequest : ITableObject public bool IsExpired() { - // TODO: PM-24252 - consider using TimeProvider for better mocking in tests return GetExpirationDate() < DateTime.UtcNow; } - // TODO: PM-24252 - this probably belongs in a service. public bool IsValidForAuthentication(Guid userId, string password) { diff --git a/src/Core/Auth/Identity/Policies.cs b/src/Core/Auth/Identity/Policies.cs index b2d94b0a6e..698a890006 100644 --- a/src/Core/Auth/Identity/Policies.cs +++ b/src/Core/Auth/Identity/Policies.cs @@ -5,12 +5,94 @@ public static class Policies /// /// Policy for managing access to the Send feature. /// - public const string Send = "Send"; // [Authorize(Policy = Policies.Send)] - public const string Application = "Application"; // [Authorize(Policy = Policies.Application)] - public const string Web = "Web"; // [Authorize(Policy = Policies.Web)] - public const string Push = "Push"; // [Authorize(Policy = Policies.Push)] + /// + /// + /// Can be used with the Authorize attribute, for example: + /// + /// [Authorize(Policy = Policies.Send)] + /// + /// + /// + public const string Send = "Send"; + + /// + /// Policy to manage access to general API endpoints. + /// + /// + /// + /// Can be used with the Authorize attribute, for example: + /// + /// [Authorize(Policy = Policies.Application)] + /// + /// + /// + public const string Application = "Application"; + + /// + /// Policy to manage access to API endpoints intended for use by the Web Vault and browser extension only. + /// + /// + /// + /// Can be used with the Authorize attribute, for example: + /// + /// [Authorize(Policy = Policies.Web)] + /// + /// + /// + public const string Web = "Web"; + + /// + /// Policy to restrict access to API endpoints for the Push feature. + /// + /// + /// + /// Can be used with the Authorize attribute, for example: + /// + /// [Authorize(Policy = Policies.Push)] + /// + /// + /// + public const string Push = "Push"; + + // TODO: This is unused public const string Licensing = "Licensing"; // [Authorize(Policy = Policies.Licensing)] - public const string Organization = "Organization"; // [Authorize(Policy = Policies.Organization)] - public const string Installation = "Installation"; // [Authorize(Policy = Policies.Installation)] - public const string Secrets = "Secrets"; // [Authorize(Policy = Policies.Secrets)] + + /// + /// Policy to restrict access to API endpoints related to the Organization features. + /// + /// + /// + /// Can be used with the Authorize attribute, for example: + /// + /// [Authorize(Policy = Policies.Licensing)] + /// + /// + /// + public const string Organization = "Organization"; + + /// + /// Policy to restrict access to API endpoints related to the setting up new installations. + /// + /// + /// + /// Can be used with the Authorize attribute, for example: + /// + /// [Authorize(Policy = Policies.Installation)] + /// + /// + /// + public const string Installation = "Installation"; + + /// + /// Policy to restrict access to API endpoints for Secrets Manager features. + /// + /// + /// + /// Can be used with the Authorize attribute, for example: + /// + /// [Authorize(Policy = Policies.Secrets)] + /// + /// + /// + public const string Secrets = "Secrets"; } diff --git a/src/Core/Auth/Identity/TokenProviders/EmailTokenProvider.cs b/src/Core/Auth/Identity/TokenProviders/EmailTokenProvider.cs index 70aba8ef75..f6ef3a5dd0 100644 --- a/src/Core/Auth/Identity/TokenProviders/EmailTokenProvider.cs +++ b/src/Core/Auth/Identity/TokenProviders/EmailTokenProvider.cs @@ -65,7 +65,7 @@ public class EmailTokenProvider : IUserTwoFactorTokenProvider } var code = Encoding.UTF8.GetString(cachedValue); - var valid = string.Equals(token, code); + var valid = CoreHelpers.FixedTimeEquals(token, code); if (valid) { await _distributedCache.RemoveAsync(cacheKey); diff --git a/src/Core/Auth/Identity/TokenProviders/OtpTokenProvider/OtpTokenProvider.cs b/src/Core/Auth/Identity/TokenProviders/OtpTokenProvider/OtpTokenProvider.cs index b6280e13fe..ae394f817e 100644 --- a/src/Core/Auth/Identity/TokenProviders/OtpTokenProvider/OtpTokenProvider.cs +++ b/src/Core/Auth/Identity/TokenProviders/OtpTokenProvider/OtpTokenProvider.cs @@ -64,7 +64,7 @@ public class OtpTokenProvider( } var code = Encoding.UTF8.GetString(cachedValue); - var valid = string.Equals(token, code); + var valid = CoreHelpers.FixedTimeEquals(token, code); if (valid) { await _distributedCache.RemoveAsync(cacheKey); diff --git a/src/Core/Auth/LoginFeatures/LoginServiceCollectionExtensions.cs b/src/Core/Auth/LoginFeatures/LoginServiceCollectionExtensions.cs deleted file mode 100644 index f8caad448b..0000000000 --- a/src/Core/Auth/LoginFeatures/LoginServiceCollectionExtensions.cs +++ /dev/null @@ -1,14 +0,0 @@ -using Bit.Core.Auth.LoginFeatures.PasswordlessLogin; -using Bit.Core.Auth.LoginFeatures.PasswordlessLogin.Interfaces; -using Microsoft.Extensions.DependencyInjection; - -namespace Bit.Core.Auth.LoginFeatures; - -public static class LoginServiceCollectionExtensions -{ - public static void AddLoginServices(this IServiceCollection services) - { - services.AddScoped(); - } -} - diff --git a/src/Core/Auth/LoginFeatures/PasswordlessLogin/Interfaces/IVerifyAuthRequest.cs b/src/Core/Auth/LoginFeatures/PasswordlessLogin/Interfaces/IVerifyAuthRequest.cs deleted file mode 100644 index e5da1b06d8..0000000000 --- a/src/Core/Auth/LoginFeatures/PasswordlessLogin/Interfaces/IVerifyAuthRequest.cs +++ /dev/null @@ -1,6 +0,0 @@ -namespace Bit.Core.Auth.LoginFeatures.PasswordlessLogin.Interfaces; - -public interface IVerifyAuthRequestCommand -{ - Task VerifyAuthRequestAsync(Guid authRequestId, string accessCode); -} diff --git a/src/Core/Auth/LoginFeatures/PasswordlessLogin/VerifyAuthRequest.cs b/src/Core/Auth/LoginFeatures/PasswordlessLogin/VerifyAuthRequest.cs deleted file mode 100644 index 7def7fea76..0000000000 --- a/src/Core/Auth/LoginFeatures/PasswordlessLogin/VerifyAuthRequest.cs +++ /dev/null @@ -1,25 +0,0 @@ -using Bit.Core.Auth.LoginFeatures.PasswordlessLogin.Interfaces; -using Bit.Core.Repositories; -using Bit.Core.Utilities; - -namespace Bit.Core.Auth.LoginFeatures.PasswordlessLogin; - -public class VerifyAuthRequestCommand : IVerifyAuthRequestCommand -{ - private readonly IAuthRequestRepository _authRequestRepository; - - public VerifyAuthRequestCommand(IAuthRequestRepository authRequestRepository) - { - _authRequestRepository = authRequestRepository; - } - - public async Task VerifyAuthRequestAsync(Guid authRequestId, string accessCode) - { - var authRequest = await _authRequestRepository.GetByIdAsync(authRequestId); - if (authRequest == null || !CoreHelpers.FixedTimeEquals(authRequest.AccessCode, accessCode)) - { - return false; - } - return true; - } -} diff --git a/src/Core/Auth/Models/Api/Request/Accounts/KeysRequestModel.cs b/src/Core/Auth/Models/Api/Request/Accounts/KeysRequestModel.cs index f89b67f3c5..85ddef44ce 100644 --- a/src/Core/Auth/Models/Api/Request/Accounts/KeysRequestModel.cs +++ b/src/Core/Auth/Models/Api/Request/Accounts/KeysRequestModel.cs @@ -3,17 +3,22 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Entities; +using Bit.Core.KeyManagement.Models.Api.Request; using Bit.Core.Utilities; namespace Bit.Core.Auth.Models.Api.Request.Accounts; public class KeysRequestModel { + [Obsolete("Use AccountKeys.AccountPublicKey instead")] [Required] public string PublicKey { get; set; } + [Obsolete("Use AccountKeys.UserKeyEncryptedAccountPrivateKey instead")] [Required] public string EncryptedPrivateKey { get; set; } + public AccountKeysRequestModel AccountKeys { get; set; } + [Obsolete("Use SetAccountKeysForUserCommand instead")] public User ToUser(User existingUser) { if (string.IsNullOrWhiteSpace(PublicKey) || string.IsNullOrWhiteSpace(EncryptedPrivateKey)) diff --git a/src/Core/Auth/Models/Api/Request/Accounts/MarketingInitiativeConstants.cs b/src/Core/Auth/Models/Api/Request/Accounts/MarketingInitiativeConstants.cs new file mode 100644 index 0000000000..ab2d252dc8 --- /dev/null +++ b/src/Core/Auth/Models/Api/Request/Accounts/MarketingInitiativeConstants.cs @@ -0,0 +1,10 @@ +namespace Bit.Core.Auth.Models.Api.Request.Accounts; + +public static class MarketingInitiativeConstants +{ + /// + /// Indicates that the user began the registration process on a marketing page designed + /// to streamline users who intend to setup a premium subscription after registration. + /// + public const string Premium = "premium"; +} diff --git a/src/Core/Auth/Models/Api/Request/Accounts/RegisterSendVerificationEmailRequestModel.cs b/src/Core/Auth/Models/Api/Request/Accounts/RegisterSendVerificationEmailRequestModel.cs index 75a4da081a..638565ecfe 100644 --- a/src/Core/Auth/Models/Api/Request/Accounts/RegisterSendVerificationEmailRequestModel.cs +++ b/src/Core/Auth/Models/Api/Request/Accounts/RegisterSendVerificationEmailRequestModel.cs @@ -1,5 +1,6 @@ #nullable enable using System.ComponentModel.DataAnnotations; +using Bit.Core.Auth.Attributes; using Bit.Core.Utilities; namespace Bit.Core.Auth.Models.Api.Request.Accounts; @@ -11,4 +12,6 @@ public class RegisterSendVerificationEmailRequestModel [StringLength(256)] public required string Email { get; set; } public bool ReceiveMarketingEmails { get; set; } + [MarketingInitiativeValidation] + public string? FromMarketing { get; set; } } diff --git a/src/Core/Auth/Models/Api/Response/UserDecryptionOptions.cs b/src/Core/Auth/Models/Api/Response/UserDecryptionOptions.cs index bd8542e8bf..aa8a298200 100644 --- a/src/Core/Auth/Models/Api/Response/UserDecryptionOptions.cs +++ b/src/Core/Auth/Models/Api/Response/UserDecryptionOptions.cs @@ -1,5 +1,5 @@ using System.Text.Json.Serialization; -using Bit.Core.KeyManagement.Models.Response; +using Bit.Core.KeyManagement.Models.Api.Response; using Bit.Core.Models.Api; namespace Bit.Core.Auth.Models.Api.Response; diff --git a/src/Core/Auth/Models/Business/Tokenables/OrgUserInviteTokenable.cs b/src/Core/Auth/Models/Business/Tokenables/OrgUserInviteTokenable.cs index f04a1181c4..5be7ed481f 100644 --- a/src/Core/Auth/Models/Business/Tokenables/OrgUserInviteTokenable.cs +++ b/src/Core/Auth/Models/Business/Tokenables/OrgUserInviteTokenable.cs @@ -1,7 +1,4 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Text.Json.Serialization; +using System.Text.Json.Serialization; using Bit.Core.Entities; using Bit.Core.Tokens; @@ -26,7 +23,7 @@ public class OrgUserInviteTokenable : ExpiringTokenable public string Identifier { get; set; } = TokenIdentifier; public Guid OrgUserId { get; set; } - public string OrgUserEmail { get; set; } + public string? OrgUserEmail { get; set; } [JsonConstructor] public OrgUserInviteTokenable() diff --git a/src/Core/Auth/Models/ITwoFactorProvidersUser.cs b/src/Core/Auth/Models/ITwoFactorProvidersUser.cs index 5cf137b76f..816d460572 100644 --- a/src/Core/Auth/Models/ITwoFactorProvidersUser.cs +++ b/src/Core/Auth/Models/ITwoFactorProvidersUser.cs @@ -1,14 +1,14 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.Auth.Enums; +using Bit.Core.Auth.Enums; using Bit.Core.Services; namespace Bit.Core.Auth.Models; +/// +/// An interface representing a user entity that supports two-factor providers +/// public interface ITwoFactorProvidersUser { - string TwoFactorProviders { get; } + string? TwoFactorProviders { get; } /// /// Get the two factor providers for the user. Currently it can be assumed providers are enabled /// if they exists in the dictionary. When two factor providers are disabled they are removed @@ -16,7 +16,10 @@ public interface ITwoFactorProvidersUser /// /// /// Dictionary of providers with the type enum as the key - Dictionary GetTwoFactorProviders(); + Dictionary? GetTwoFactorProviders(); + /// + /// The unique `UserId` of the user entity for which there are two-factor providers configured. + /// + /// The unique identifier for the user Guid? GetUserId(); - bool GetPremium(); } diff --git a/src/Core/Auth/Models/Mail/RegisterVerifyEmail.cs b/src/Core/Auth/Models/Mail/RegisterVerifyEmail.cs index fe42093111..5c0efeb73f 100644 --- a/src/Core/Auth/Models/Mail/RegisterVerifyEmail.cs +++ b/src/Core/Auth/Models/Mail/RegisterVerifyEmail.cs @@ -15,11 +15,13 @@ public class RegisterVerifyEmail : BaseMailModel // so we must land on a redirect connector which will redirect to the finish signup page. // Note 3: The use of a fragment to indicate the redirect url is to prevent the query string from being logged by // proxies and servers. It also helps reduce open redirect vulnerabilities. - public string Url => string.Format("{0}/redirect-connector.html#finish-signup?token={1}&email={2}&fromEmail=true", + public string Url => string.Format("{0}/redirect-connector.html#finish-signup?token={1}&email={2}&fromEmail=true{3}", WebVaultUrl, Token, - Email); + Email, + !string.IsNullOrEmpty(FromMarketing) ? $"&fromMarketing={FromMarketing}" : string.Empty); public string Token { get; set; } public string Email { get; set; } + public string FromMarketing { get; set; } } diff --git a/src/Core/Auth/Services/Implementations/SsoConfigService.cs b/src/Core/Auth/Services/Implementations/SsoConfigService.cs index fe8d9bdd6e..0cb8b68042 100644 --- a/src/Core/Auth/Services/Implementations/SsoConfigService.cs +++ b/src/Core/Auth/Services/Implementations/SsoConfigService.cs @@ -3,9 +3,10 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; -using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.Entities; using Bit.Core.Auth.Enums; @@ -24,7 +25,7 @@ public class SsoConfigService : ISsoConfigService private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IEventService _eventService; - private readonly ISavePolicyCommand _savePolicyCommand; + private readonly IVNextSavePolicyCommand _vNextSavePolicyCommand; public SsoConfigService( ISsoConfigRepository ssoConfigRepository, @@ -32,14 +33,14 @@ public class SsoConfigService : ISsoConfigService IOrganizationRepository organizationRepository, IOrganizationUserRepository organizationUserRepository, IEventService eventService, - ISavePolicyCommand savePolicyCommand) + IVNextSavePolicyCommand vNextSavePolicyCommand) { _ssoConfigRepository = ssoConfigRepository; _policyRepository = policyRepository; _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; _eventService = eventService; - _savePolicyCommand = savePolicyCommand; + _vNextSavePolicyCommand = vNextSavePolicyCommand; } public async Task SaveAsync(SsoConfig config, Organization organization) @@ -67,13 +68,12 @@ public class SsoConfigService : ISsoConfigService // Automatically enable account recovery, SSO required, and single org policies if trusted device encryption is selected if (config.GetData().MemberDecryptionType == MemberDecryptionType.TrustedDeviceEncryption) { - - await _savePolicyCommand.SaveAsync(new() + var singleOrgPolicy = new PolicyUpdate { OrganizationId = config.OrganizationId, Type = PolicyType.SingleOrg, Enabled = true - }); + }; var resetPasswordPolicy = new PolicyUpdate { @@ -82,14 +82,18 @@ public class SsoConfigService : ISsoConfigService Enabled = true, }; resetPasswordPolicy.SetDataModel(new ResetPasswordDataModel { AutoEnrollEnabled = true }); - await _savePolicyCommand.SaveAsync(resetPasswordPolicy); - await _savePolicyCommand.SaveAsync(new() + var requireSsoPolicy = new PolicyUpdate { OrganizationId = config.OrganizationId, Type = PolicyType.RequireSso, Enabled = true - }); + }; + + var performedBy = new SystemUser(EventSystemUser.Unknown); + await _vNextSavePolicyCommand.SaveAsync(new SavePolicyModel(singleOrgPolicy, performedBy)); + await _vNextSavePolicyCommand.SaveAsync(new SavePolicyModel(resetPasswordPolicy, performedBy)); + await _vNextSavePolicyCommand.SaveAsync(new SavePolicyModel(requireSsoPolicy, performedBy)); } await LogEventsAsync(config, oldConfig); diff --git a/src/Core/Auth/Sso/IUserSsoOrganizationIdentifierQuery.cs b/src/Core/Auth/Sso/IUserSsoOrganizationIdentifierQuery.cs new file mode 100644 index 0000000000..c932eb0c34 --- /dev/null +++ b/src/Core/Auth/Sso/IUserSsoOrganizationIdentifierQuery.cs @@ -0,0 +1,23 @@ +using Bit.Core.Entities; + +namespace Bit.Core.Auth.Sso; + +/// +/// Query to retrieve the SSO organization identifier that a user is a confirmed member of. +/// +public interface IUserSsoOrganizationIdentifierQuery +{ + /// + /// Retrieves the SSO organization identifier for a confirmed organization user. + /// If there is more than one organization a User is associated with, we return null. If there are more than one + /// organization there is no way to know which organization the user wishes to authenticate with. + /// Owners and Admins who are not subject to the SSO required policy cannot utilize this flow, since they may have + /// multiple organizations with different SSO configurations. + /// + /// The ID of the to retrieve the SSO organization for. _Not_ an . + /// + /// The organization identifier if the user is a confirmed member of an organization with SSO configured, + /// otherwise null + /// + Task GetSsoOrganizationIdentifierAsync(Guid userId); +} diff --git a/src/Core/Auth/Sso/UserSsoOrganizationIdentifierQuery.cs b/src/Core/Auth/Sso/UserSsoOrganizationIdentifierQuery.cs new file mode 100644 index 0000000000..c0751e1f1a --- /dev/null +++ b/src/Core/Auth/Sso/UserSsoOrganizationIdentifierQuery.cs @@ -0,0 +1,38 @@ +using Bit.Core.Enums; +using Bit.Core.Repositories; + +namespace Bit.Core.Auth.Sso; + +/// +/// TODO : PM-28846 review data structures as they relate to this query +/// Query to retrieve the SSO organization identifier that a user is a confirmed member of. +/// +public class UserSsoOrganizationIdentifierQuery( + IOrganizationUserRepository _organizationUserRepository, + IOrganizationRepository _organizationRepository) : IUserSsoOrganizationIdentifierQuery +{ + /// + public async Task GetSsoOrganizationIdentifierAsync(Guid userId) + { + // Get all confirmed organization memberships for the user + var organizationUsers = await _organizationUserRepository.GetManyByUserAsync(userId); + + // we can only confidently return the correct SsoOrganizationIdentifier if there is exactly one Organization. + // The user must also be in the Confirmed status. + var confirmedOrgUsers = organizationUsers.Where(ou => ou.Status == OrganizationUserStatusType.Confirmed); + if (confirmedOrgUsers.Count() != 1) + { + return null; + } + + var confirmedOrgUser = confirmedOrgUsers.Single(); + var organization = await _organizationRepository.GetByIdAsync(confirmedOrgUser.OrganizationId); + + if (organization == null) + { + return null; + } + + return organization.Identifier; + } +} diff --git a/src/Core/Auth/UserFeatures/Registration/IRegisterUserCommand.cs b/src/Core/Auth/UserFeatures/Registration/IRegisterUserCommand.cs index 62dd9dd293..97c2eabd3c 100644 --- a/src/Core/Auth/UserFeatures/Registration/IRegisterUserCommand.cs +++ b/src/Core/Auth/UserFeatures/Registration/IRegisterUserCommand.cs @@ -1,4 +1,5 @@ -using Bit.Core.Entities; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Entities; using Microsoft.AspNetCore.Identity; namespace Bit.Core.Auth.UserFeatures.Registration; @@ -14,6 +15,15 @@ public interface IRegisterUserCommand /// public Task RegisterUser(User user); + /// + /// Creates a new user, sends a welcome email, and raises the signup reference event. + /// This method is used by SSO auto-provisioned organization Users. + /// + /// The to create + /// The associated with the user + /// + Task RegisterSSOAutoProvisionedUserAsync(User user, Organization organization); + /// /// Creates a new user with a given master password hash, sends a welcome email (differs based on initiation path), /// and raises the signup reference event. Optionally accepts an org invite token and org user id to associate diff --git a/src/Core/Auth/UserFeatures/Registration/ISendVerificationEmailForRegistrationCommand.cs b/src/Core/Auth/UserFeatures/Registration/ISendVerificationEmailForRegistrationCommand.cs index b623b8cab3..2a224b9eb9 100644 --- a/src/Core/Auth/UserFeatures/Registration/ISendVerificationEmailForRegistrationCommand.cs +++ b/src/Core/Auth/UserFeatures/Registration/ISendVerificationEmailForRegistrationCommand.cs @@ -3,5 +3,5 @@ namespace Bit.Core.Auth.UserFeatures.Registration; public interface ISendVerificationEmailForRegistrationCommand { - public Task Run(string email, string? name, bool receiveMarketingEmails); + public Task Run(string email, string? name, bool receiveMarketingEmails, string? fromMarketing); } diff --git a/src/Core/Auth/UserFeatures/Registration/Implementations/RegisterUserCommand.cs b/src/Core/Auth/UserFeatures/Registration/Implementations/RegisterUserCommand.cs index 991be2b764..4a0e9c2cf5 100644 --- a/src/Core/Auth/UserFeatures/Registration/Implementations/RegisterUserCommand.cs +++ b/src/Core/Auth/UserFeatures/Registration/Implementations/RegisterUserCommand.cs @@ -1,11 +1,11 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - +using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models; using Bit.Core.Auth.Models.Business.Tokenables; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Extensions; using Bit.Core.Entities; using Bit.Core.Exceptions; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; @@ -16,15 +16,20 @@ using Bit.Core.Tokens; using Bit.Core.Utilities; using Microsoft.AspNetCore.DataProtection; using Microsoft.AspNetCore.Identity; +using Microsoft.Extensions.Logging; using Newtonsoft.Json; namespace Bit.Core.Auth.UserFeatures.Registration.Implementations; public class RegisterUserCommand : IRegisterUserCommand { + private readonly ILogger _logger; private readonly IGlobalSettings _globalSettings; private readonly IOrganizationUserRepository _organizationUserRepository; + private readonly IOrganizationRepository _organizationRepository; private readonly IPolicyRepository _policyRepository; + private readonly IOrganizationDomainRepository _organizationDomainRepository; + private readonly IFeatureService _featureService; private readonly IDataProtectorTokenFactory _orgUserInviteTokenDataFactory; private readonly IDataProtectorTokenFactory _registrationEmailVerificationTokenDataFactory; @@ -41,21 +46,28 @@ public class RegisterUserCommand : IRegisterUserCommand private readonly string _disabledUserRegistrationExceptionMsg = "Open registration has been disabled by the system administrator."; public RegisterUserCommand( - IGlobalSettings globalSettings, - IOrganizationUserRepository organizationUserRepository, - IPolicyRepository policyRepository, - IDataProtectionProvider dataProtectionProvider, - IDataProtectorTokenFactory orgUserInviteTokenDataFactory, - IDataProtectorTokenFactory registrationEmailVerificationTokenDataFactory, - IUserService userService, - IMailService mailService, - IValidateRedemptionTokenCommand validateRedemptionTokenCommand, - IDataProtectorTokenFactory emergencyAccessInviteTokenDataFactory - ) + ILogger logger, + IGlobalSettings globalSettings, + IOrganizationUserRepository organizationUserRepository, + IOrganizationRepository organizationRepository, + IPolicyRepository policyRepository, + IOrganizationDomainRepository organizationDomainRepository, + IFeatureService featureService, + IDataProtectionProvider dataProtectionProvider, + IDataProtectorTokenFactory orgUserInviteTokenDataFactory, + IDataProtectorTokenFactory registrationEmailVerificationTokenDataFactory, + IUserService userService, + IMailService mailService, + IValidateRedemptionTokenCommand validateRedemptionTokenCommand, + IDataProtectorTokenFactory emergencyAccessInviteTokenDataFactory) { + _logger = logger; _globalSettings = globalSettings; _organizationUserRepository = organizationUserRepository; + _organizationRepository = organizationRepository; _policyRepository = policyRepository; + _organizationDomainRepository = organizationDomainRepository; + _featureService = featureService; _organizationServiceDataProtector = dataProtectionProvider.CreateProtector( "OrganizationServiceDataProtector"); @@ -69,11 +81,13 @@ public class RegisterUserCommand : IRegisterUserCommand _emergencyAccessInviteTokenDataFactory = emergencyAccessInviteTokenDataFactory; _providerServiceDataProtector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); + _featureService = featureService; } - public async Task RegisterUser(User user) { + await ValidateEmailDomainNotBlockedAsync(user.Email); + var result = await _userService.CreateUserAsync(user); if (result == IdentityResult.Success) { @@ -83,11 +97,30 @@ public class RegisterUserCommand : IRegisterUserCommand return result; } + public async Task RegisterSSOAutoProvisionedUserAsync(User user, Organization organization) + { + // Validate that the email domain is not blocked by another organization's policy + await ValidateEmailDomainNotBlockedAsync(user.Email, organization.Id); + + var result = await _userService.CreateUserAsync(user); + if (result == IdentityResult.Success) + { + await SendWelcomeEmailAsync(user, organization); + } + + return result; + } + public async Task RegisterUserViaOrganizationInviteToken(User user, string masterPasswordHash, string orgInviteToken, Guid? orgUserId) { - ValidateOrgInviteToken(orgInviteToken, orgUserId, user); - await SetUserEmail2FaIfOrgPolicyEnabledAsync(orgUserId, user); + TryValidateOrgInviteToken(orgInviteToken, orgUserId, user); + var orgUser = await SetUserEmail2FaIfOrgPolicyEnabledAsync(orgUserId, user); + if (orgUser == null && orgUserId.HasValue) + { + throw new BadRequestException("Invalid organization user invitation."); + } + await ValidateEmailDomainNotBlockedAsync(user.Email, orgUser?.OrganizationId); user.ApiKey = CoreHelpers.SecureRandomString(30); @@ -97,16 +130,17 @@ public class RegisterUserCommand : IRegisterUserCommand } var result = await _userService.CreateUserAsync(user, masterPasswordHash); + var organization = await GetOrganizationUserOrganization(orgUserId ?? Guid.Empty, orgUser); if (result == IdentityResult.Success) { var sentWelcomeEmail = false; if (!string.IsNullOrEmpty(user.ReferenceData)) { - var referenceData = JsonConvert.DeserializeObject>(user.ReferenceData); + var referenceData = JsonConvert.DeserializeObject>(user.ReferenceData) ?? []; if (referenceData.TryGetValue("initiationPath", out var value)) { - var initiationPath = value.ToString(); - await SendAppropriateWelcomeEmailAsync(user, initiationPath); + var initiationPath = value.ToString() ?? string.Empty; + await SendAppropriateWelcomeEmailAsync(user, initiationPath, organization); sentWelcomeEmail = true; if (!string.IsNullOrEmpty(initiationPath)) { @@ -117,14 +151,22 @@ public class RegisterUserCommand : IRegisterUserCommand if (!sentWelcomeEmail) { - await _mailService.SendWelcomeEmailAsync(user); + await SendWelcomeEmailAsync(user, organization); } } return result; } - private void ValidateOrgInviteToken(string orgInviteToken, Guid? orgUserId, User user) + /// + /// This method attempts to validate the org invite token if provided. If the token is invalid an exception is thrown. + /// If there is no exception it is assumed the token is valid or not provided and open registration is allowed. + /// + /// The organization invite token. + /// The organization user ID. + /// The user being registered. + /// If validation fails then an exception is thrown. + private void TryValidateOrgInviteToken(string orgInviteToken, Guid? orgUserId, User user) { var orgInviteTokenProvided = !string.IsNullOrWhiteSpace(orgInviteToken); @@ -137,7 +179,6 @@ public class RegisterUserCommand : IRegisterUserCommand } // Token data is invalid - if (_globalSettings.DisableUserRegistration) { throw new BadRequestException(_disabledUserRegistrationExceptionMsg); @@ -147,7 +188,6 @@ public class RegisterUserCommand : IRegisterUserCommand } // no token data or missing token data - // Throw if open registration is disabled and there isn't an org invite token or an org user id // as you can't register without them. if (_globalSettings.DisableUserRegistration) @@ -171,12 +211,20 @@ public class RegisterUserCommand : IRegisterUserCommand // If both orgInviteToken && orgUserId are missing, then proceed with open registration } + /// + /// Validates the org invite token using the new tokenable logic first, then falls back to the old token validation logic for backwards compatibility. + /// Will set the out parameter organizationWelcomeEmailDetails if the new token is valid. If the token is invalid then no welcome email needs to be sent + /// so the out parameter is set to null. + /// + /// Invite token + /// Inviting Organization UserId + /// User email + /// true if the token is valid false otherwise private bool IsOrgInviteTokenValid(string orgInviteToken, Guid orgUserId, string userEmail) { // TODO: PM-4142 - remove old token validation logic once 3 releases of backwards compatibility are complete var newOrgInviteTokenValid = OrgUserInviteTokenable.ValidateOrgUserInviteStringToken( _orgUserInviteTokenDataFactory, orgInviteToken, orgUserId, userEmail); - return newOrgInviteTokenValid || CoreHelpers.UserInviteTokenIsValid( _organizationServiceDataProtector, orgInviteToken, userEmail, orgUserId, _globalSettings); } @@ -187,11 +235,12 @@ public class RegisterUserCommand : IRegisterUserCommand /// /// The optional org user id /// The newly created user object which could be modified - private async Task SetUserEmail2FaIfOrgPolicyEnabledAsync(Guid? orgUserId, User user) + /// The organization user if one exists for the provided org user id, null otherwise + private async Task SetUserEmail2FaIfOrgPolicyEnabledAsync(Guid? orgUserId, User user) { if (!orgUserId.HasValue) { - return; + return null; } var orgUser = await _organizationUserRepository.GetByIdAsync(orgUserId.Value); @@ -213,10 +262,11 @@ public class RegisterUserCommand : IRegisterUserCommand _userService.SetTwoFactorProvider(user, TwoFactorProviderType.Email); } } + return orgUser; } - private async Task SendAppropriateWelcomeEmailAsync(User user, string initiationPath) + private async Task SendAppropriateWelcomeEmailAsync(User user, string initiationPath, Organization? organization) { var isFromMarketingWebsite = initiationPath.Contains("Secrets Manager trial"); @@ -226,15 +276,15 @@ public class RegisterUserCommand : IRegisterUserCommand } else { - await _mailService.SendWelcomeEmailAsync(user); + await SendWelcomeEmailAsync(user, organization); } } public async Task RegisterUserViaEmailVerificationToken(User user, string masterPasswordHash, string emailVerificationToken) { - ValidateOpenRegistrationAllowed(); + await ValidateEmailDomainNotBlockedAsync(user.Email); var tokenable = ValidateRegistrationEmailVerificationTokenable(emailVerificationToken, user.Email); @@ -245,7 +295,7 @@ public class RegisterUserCommand : IRegisterUserCommand var result = await _userService.CreateUserAsync(user, masterPasswordHash); if (result == IdentityResult.Success) { - await _mailService.SendWelcomeEmailAsync(user); + await SendWelcomeEmailAsync(user); } return result; @@ -255,6 +305,7 @@ public class RegisterUserCommand : IRegisterUserCommand string orgSponsoredFreeFamilyPlanInviteToken) { ValidateOpenRegistrationAllowed(); + await ValidateEmailDomainNotBlockedAsync(user.Email); await ValidateOrgSponsoredFreeFamilyPlanInviteToken(orgSponsoredFreeFamilyPlanInviteToken, user.Email); user.EmailVerified = true; @@ -263,7 +314,7 @@ public class RegisterUserCommand : IRegisterUserCommand var result = await _userService.CreateUserAsync(user, masterPasswordHash); if (result == IdentityResult.Success) { - await _mailService.SendWelcomeEmailAsync(user); + await SendWelcomeEmailAsync(user); } return result; @@ -275,6 +326,7 @@ public class RegisterUserCommand : IRegisterUserCommand string acceptEmergencyAccessInviteToken, Guid acceptEmergencyAccessId) { ValidateOpenRegistrationAllowed(); + await ValidateEmailDomainNotBlockedAsync(user.Email); ValidateAcceptEmergencyAccessInviteToken(acceptEmergencyAccessInviteToken, acceptEmergencyAccessId, user.Email); user.EmailVerified = true; @@ -283,7 +335,7 @@ public class RegisterUserCommand : IRegisterUserCommand var result = await _userService.CreateUserAsync(user, masterPasswordHash); if (result == IdentityResult.Success) { - await _mailService.SendWelcomeEmailAsync(user); + await SendWelcomeEmailAsync(user); } return result; @@ -293,6 +345,7 @@ public class RegisterUserCommand : IRegisterUserCommand string providerInviteToken, Guid providerUserId) { ValidateOpenRegistrationAllowed(); + await ValidateEmailDomainNotBlockedAsync(user.Email); ValidateProviderInviteToken(providerInviteToken, providerUserId, user.Email); user.EmailVerified = true; @@ -301,7 +354,7 @@ public class RegisterUserCommand : IRegisterUserCommand var result = await _userService.CreateUserAsync(user, masterPasswordHash); if (result == IdentityResult.Success) { - await _mailService.SendWelcomeEmailAsync(user); + await SendWelcomeEmailAsync(user); } return result; @@ -357,4 +410,79 @@ public class RegisterUserCommand : IRegisterUserCommand return tokenable; } + + private async Task ValidateEmailDomainNotBlockedAsync(string email, Guid? excludeOrganizationId = null) + { + // Only check if feature flag is enabled + if (!_featureService.IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation)) + { + return; + } + + var emailDomain = EmailValidation.GetDomain(email); + + var isDomainBlocked = await _organizationDomainRepository.HasVerifiedDomainWithBlockClaimedDomainPolicyAsync( + emailDomain, excludeOrganizationId); + if (isDomainBlocked) + { + _logger.LogInformation( + "User registration blocked by domain claim policy. Domain: {Domain}, ExcludedOrgId: {ExcludedOrgId}", + emailDomain, + excludeOrganizationId); + throw new BadRequestException("This email address is claimed by an organization using Bitwarden."); + } + } + + /// + /// We send different welcome emails depending on whether the user is joining a free/family or an enterprise organization. If information to populate the + /// email isn't present we send the standard individual welcome email. + /// + /// Target user for the email + /// this value is nullable + /// + private async Task SendWelcomeEmailAsync(User user, Organization? organization = null) + { + // Check if feature is enabled + // TODO: Remove Feature flag: PM-28221 + if (!_featureService.IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates)) + { + await _mailService.SendWelcomeEmailAsync(user); + return; + } + + // Most emails are probably for non organization users so we default to that experience + if (organization == null) + { + await _mailService.SendIndividualUserWelcomeEmailAsync(user); + } + // We need to make sure that the organization email has the correct data to display otherwise we just send the standard welcome email + else if (!string.IsNullOrEmpty(organization.DisplayName())) + { + // If the organization is Free or Families plan, send families welcome email + if (organization.PlanType.GetProductTier() is ProductTierType.Free or ProductTierType.Families) + { + await _mailService.SendFreeOrgOrFamilyOrgUserWelcomeEmailAsync(user, organization.DisplayName()); + } + else + { + await _mailService.SendOrganizationUserWelcomeEmailAsync(user, organization.DisplayName()); + } + } + // If the organization data isn't present send the standard welcome email + else + { + await _mailService.SendIndividualUserWelcomeEmailAsync(user); + } + } + + private async Task GetOrganizationUserOrganization(Guid orgUserId, OrganizationUser? orgUser = null) + { + var organizationUser = orgUser ?? await _organizationUserRepository.GetByIdAsync(orgUserId); + if (organizationUser == null) + { + return null; + } + + return await _organizationRepository.GetByIdAsync(organizationUser.OrganizationId); + } } diff --git a/src/Core/Auth/UserFeatures/Registration/Implementations/SendVerificationEmailForRegistrationCommand.cs b/src/Core/Auth/UserFeatures/Registration/Implementations/SendVerificationEmailForRegistrationCommand.cs index 3f89e9ad0e..2e8587eee6 100644 --- a/src/Core/Auth/UserFeatures/Registration/Implementations/SendVerificationEmailForRegistrationCommand.cs +++ b/src/Core/Auth/UserFeatures/Registration/Implementations/SendVerificationEmailForRegistrationCommand.cs @@ -5,6 +5,8 @@ using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Tokens; +using Bit.Core.Utilities; +using Microsoft.Extensions.Logging; namespace Bit.Core.Auth.UserFeatures.Registration.Implementations; @@ -15,29 +17,34 @@ namespace Bit.Core.Auth.UserFeatures.Registration.Implementations; /// public class SendVerificationEmailForRegistrationCommand : ISendVerificationEmailForRegistrationCommand { - + private readonly ILogger _logger; private readonly IUserRepository _userRepository; private readonly GlobalSettings _globalSettings; private readonly IMailService _mailService; private readonly IDataProtectorTokenFactory _tokenDataFactory; private readonly IFeatureService _featureService; + private readonly IOrganizationDomainRepository _organizationDomainRepository; public SendVerificationEmailForRegistrationCommand( + ILogger logger, IUserRepository userRepository, GlobalSettings globalSettings, IMailService mailService, IDataProtectorTokenFactory tokenDataFactory, - IFeatureService featureService) + IFeatureService featureService, + IOrganizationDomainRepository organizationDomainRepository) { + _logger = logger; _userRepository = userRepository; _globalSettings = globalSettings; _mailService = mailService; _tokenDataFactory = tokenDataFactory; _featureService = featureService; + _organizationDomainRepository = organizationDomainRepository; } - public async Task Run(string email, string? name, bool receiveMarketingEmails) + public async Task Run(string email, string? name, bool receiveMarketingEmails, string? fromMarketing) { if (_globalSettings.DisableUserRegistration) { @@ -49,6 +56,20 @@ public class SendVerificationEmailForRegistrationCommand : ISendVerificationEmai throw new ArgumentNullException(nameof(email)); } + // Check if the email domain is blocked by an organization policy + if (_featureService.IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation)) + { + var emailDomain = EmailValidation.GetDomain(email); + + if (await _organizationDomainRepository.HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(emailDomain)) + { + _logger.LogInformation( + "User registration email verification blocked by domain claim policy. Domain: {Domain}", + emailDomain); + throw new BadRequestException("This email address is claimed by an organization using Bitwarden."); + } + } + // Check to see if the user already exists var user = await _userRepository.GetByEmailAsync(email); var userExists = user != null; @@ -71,7 +92,7 @@ public class SendVerificationEmailForRegistrationCommand : ISendVerificationEmai // If the user doesn't exist, create a new EmailVerificationTokenable and send the user // an email with a link to verify their email address var token = GenerateToken(email, name, receiveMarketingEmails); - await _mailService.SendRegistrationVerificationEmailAsync(email, token); + await _mailService.SendRegistrationVerificationEmailAsync(email, token, fromMarketing); } // User exists but we will return a 200 regardless of whether the email was sent or not; so return null diff --git a/src/Core/Auth/UserFeatures/TwoFactorAuth/TwoFactorIsEnabledQuery.cs b/src/Core/Auth/UserFeatures/TwoFactorAuth/TwoFactorIsEnabledQuery.cs index cc86d3d71d..e6c0c1444a 100644 --- a/src/Core/Auth/UserFeatures/TwoFactorAuth/TwoFactorIsEnabledQuery.cs +++ b/src/Core/Auth/UserFeatures/TwoFactorAuth/TwoFactorIsEnabledQuery.cs @@ -4,16 +4,37 @@ using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; +using Bit.Core.Billing.Premium.Queries; +using Bit.Core.Entities; +using Bit.Core.Exceptions; using Bit.Core.Repositories; +using Bit.Core.Services; namespace Bit.Core.Auth.UserFeatures.TwoFactorAuth; -public class TwoFactorIsEnabledQuery(IUserRepository userRepository) : ITwoFactorIsEnabledQuery +public class TwoFactorIsEnabledQuery : ITwoFactorIsEnabledQuery { - private readonly IUserRepository _userRepository = userRepository; + private readonly IUserRepository _userRepository; + private readonly IHasPremiumAccessQuery _hasPremiumAccessQuery; + private readonly IFeatureService _featureService; + + public TwoFactorIsEnabledQuery( + IUserRepository userRepository, + IHasPremiumAccessQuery hasPremiumAccessQuery, + IFeatureService featureService) + { + _userRepository = userRepository; + _hasPremiumAccessQuery = hasPremiumAccessQuery; + _featureService = featureService; + } public async Task> TwoFactorIsEnabledAsync(IEnumerable userIds) { + if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)) + { + return await TwoFactorIsEnabledVNextAsync(userIds); + } + var result = new List<(Guid userId, bool hasTwoFactor)>(); if (userIds == null || !userIds.Any()) { @@ -36,6 +57,11 @@ public class TwoFactorIsEnabledQuery(IUserRepository userRepository) : ITwoFacto public async Task> TwoFactorIsEnabledAsync(IEnumerable users) where T : ITwoFactorProvidersUser { + if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)) + { + return await TwoFactorIsEnabledVNextAsync(users); + } + var userIds = users .Select(u => u.GetUserId()) .Where(u => u.HasValue) @@ -71,13 +97,134 @@ public class TwoFactorIsEnabledQuery(IUserRepository userRepository) : ITwoFacto return false; } + if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)) + { + var userEntity = user as User ?? await _userRepository.GetByIdAsync(userId.Value); + if (userEntity == null) + { + throw new NotFoundException(); + } + + return await TwoFactorIsEnabledVNextAsync(userEntity); + } + return await TwoFactorEnabledAsync( - user.GetTwoFactorProviders(), - async () => - { - var calcUser = await _userRepository.GetCalculatedPremiumAsync(userId.Value); - return calcUser?.HasPremiumAccess ?? false; - }); + user.GetTwoFactorProviders(), + async () => + { + var calcUser = await _userRepository.GetCalculatedPremiumAsync(userId.Value); + return calcUser?.HasPremiumAccess ?? false; + }); + } + + private async Task> TwoFactorIsEnabledVNextAsync(IEnumerable userIds) + { + var result = new List<(Guid userId, bool hasTwoFactor)>(); + if (userIds == null || !userIds.Any()) + { + return result; + } + + var users = await _userRepository.GetManyAsync([.. userIds]); + + // Get enabled providers for each user + var usersTwoFactorProvidersMap = users.ToDictionary(u => u.Id, GetEnabledTwoFactorProviders); + + // Bulk fetch premium status only for users who need it (those with only premium providers) + var userIdsNeedingPremium = usersTwoFactorProvidersMap + .Where(kvp => kvp.Value.Any() && kvp.Value.All(TwoFactorProvider.RequiresPremium)) + .Select(kvp => kvp.Key) + .ToList(); + + var premiumStatusMap = userIdsNeedingPremium.Count > 0 + ? await _hasPremiumAccessQuery.HasPremiumAccessAsync(userIdsNeedingPremium) + : new Dictionary(); + + foreach (var user in users) + { + var userTwoFactorProviders = usersTwoFactorProvidersMap[user.Id]; + + if (!userTwoFactorProviders.Any()) + { + result.Add((user.Id, false)); + continue; + } + + // User has providers. If they're in the premium check map, verify premium status + var twoFactorIsEnabled = !premiumStatusMap.TryGetValue(user.Id, out var hasPremium) || hasPremium; + result.Add((user.Id, twoFactorIsEnabled)); + } + + return result; + } + + private async Task> TwoFactorIsEnabledVNextAsync(IEnumerable users) + where T : ITwoFactorProvidersUser + { + var userIds = users + .Select(u => u.GetUserId()) + .Where(u => u.HasValue) + .Select(u => u.Value) + .ToList(); + + var twoFactorResults = await TwoFactorIsEnabledVNextAsync(userIds); + + var result = new List<(T user, bool twoFactorIsEnabled)>(); + + foreach (var user in users) + { + var userId = user.GetUserId(); + if (userId.HasValue) + { + var hasTwoFactor = twoFactorResults.FirstOrDefault(res => res.userId == userId.Value).twoFactorIsEnabled; + result.Add((user, hasTwoFactor)); + } + else + { + result.Add((user, false)); + } + } + + return result; + } + + private async Task TwoFactorIsEnabledVNextAsync(User user) + { + var enabledProviders = GetEnabledTwoFactorProviders(user); + + if (!enabledProviders.Any()) + { + return false; + } + + // If all providers require premium, check if user has premium access + if (enabledProviders.All(TwoFactorProvider.RequiresPremium)) + { + return await _hasPremiumAccessQuery.HasPremiumAccessAsync(user.Id); + } + + // User has at least one non-premium provider + return true; + } + + /// + /// Gets all enabled two-factor provider types for a user. + /// + /// user with two factor providers + /// list of enabled provider types + private static IList GetEnabledTwoFactorProviders(User user) + { + var providers = user.GetTwoFactorProviders(); + + if (providers == null || providers.Count == 0) + { + return Array.Empty(); + } + + // TODO: PM-21210: In practice we don't save disabled providers to the database, worth looking into. + return (from provider in providers + where provider.Value?.Enabled ?? false + select provider.Key).ToList(); } /// diff --git a/src/Core/Auth/UserFeatures/UserServiceCollectionExtensions.cs b/src/Core/Auth/UserFeatures/UserServiceCollectionExtensions.cs index 53bd8bdba2..7c50f7f17b 100644 --- a/src/Core/Auth/UserFeatures/UserServiceCollectionExtensions.cs +++ b/src/Core/Auth/UserFeatures/UserServiceCollectionExtensions.cs @@ -1,5 +1,4 @@ - - +using Bit.Core.Auth.Sso; using Bit.Core.Auth.UserFeatures.DeviceTrust; using Bit.Core.Auth.UserFeatures.Registration; using Bit.Core.Auth.UserFeatures.Registration.Implementations; @@ -29,6 +28,7 @@ public static class UserServiceCollectionExtensions services.AddWebAuthnLoginCommands(); services.AddTdeOffboardingPasswordCommands(); services.AddTwoFactorQueries(); + services.AddSsoQueries(); } public static void AddDeviceTrustCommands(this IServiceCollection services) @@ -69,4 +69,9 @@ public static class UserServiceCollectionExtensions { services.AddScoped(); } + + private static void AddSsoQueries(this IServiceCollection services) + { + services.AddScoped(); + } } diff --git a/src/Core/Billing/Constants/BitPayConstants.cs b/src/Core/Billing/Constants/BitPayConstants.cs new file mode 100644 index 0000000000..a1b2ff6f5b --- /dev/null +++ b/src/Core/Billing/Constants/BitPayConstants.cs @@ -0,0 +1,14 @@ +namespace Bit.Core.Billing.Constants; + +public static class BitPayConstants +{ + public static class InvoiceStatuses + { + public const string Complete = "complete"; + } + + public static class PosDataKeys + { + public const string AccountCredit = "accountCredit:1"; + } +} diff --git a/src/Core/Billing/Constants/StripeConstants.cs b/src/Core/Billing/Constants/StripeConstants.cs index 131adfedf8..dc128127ae 100644 --- a/src/Core/Billing/Constants/StripeConstants.cs +++ b/src/Core/Billing/Constants/StripeConstants.cs @@ -12,6 +12,12 @@ public static class StripeConstants public const string UnrecognizedLocation = "unrecognized_location"; } + public static class BillingReasons + { + public const string SubscriptionCreate = "subscription_create"; + public const string SubscriptionCycle = "subscription_cycle"; + } + public static class CollectionMethod { public const string ChargeAutomatically = "charge_automatically"; @@ -22,6 +28,8 @@ public static class StripeConstants { public const string LegacyMSPDiscount = "msp-discount-35"; public const string SecretsManagerStandalone = "sm-standalone"; + public const string Milestone2SubscriptionDiscount = "milestone-2c"; + public const string Milestone3SubscriptionDiscount = "milestone-3"; public static class MSPDiscounts { @@ -63,6 +71,7 @@ public static class StripeConstants public const string Region = "region"; public const string RetiredBraintreeCustomerId = "btCustomerId_old"; public const string UserId = "userId"; + public const string StorageReconciled2025 = "storage_reconciled_2025"; } public static class PaymentBehavior diff --git a/src/Core/Billing/Enums/PlanType.cs b/src/Core/Billing/Enums/PlanType.cs index e88a73af16..0f910c4980 100644 --- a/src/Core/Billing/Enums/PlanType.cs +++ b/src/Core/Billing/Enums/PlanType.cs @@ -18,8 +18,8 @@ public enum PlanType : byte EnterpriseAnnually2019 = 5, [Display(Name = "Custom")] Custom = 6, - [Display(Name = "Families")] - FamiliesAnnually = 7, + [Display(Name = "Families 2025")] + FamiliesAnnually2025 = 7, [Display(Name = "Teams (Monthly) 2020")] TeamsMonthly2020 = 8, [Display(Name = "Teams (Annually) 2020")] @@ -48,4 +48,6 @@ public enum PlanType : byte EnterpriseAnnually = 20, [Display(Name = "Teams Starter")] TeamsStarter = 21, + [Display(Name = "Families")] + FamiliesAnnually = 22, } diff --git a/src/Core/Billing/Extensions/BillingExtensions.cs b/src/Core/Billing/Extensions/BillingExtensions.cs index 7f81bfd33f..2dae0c2025 100644 --- a/src/Core/Billing/Extensions/BillingExtensions.cs +++ b/src/Core/Billing/Extensions/BillingExtensions.cs @@ -15,7 +15,7 @@ public static class BillingExtensions => planType switch { PlanType.Custom or PlanType.Free => ProductTierType.Free, - PlanType.FamiliesAnnually or PlanType.FamiliesAnnually2019 => ProductTierType.Families, + PlanType.FamiliesAnnually or PlanType.FamiliesAnnually2025 or PlanType.FamiliesAnnually2019 => ProductTierType.Families, PlanType.TeamsStarter or PlanType.TeamsStarter2023 => ProductTierType.TeamsStarter, _ when planType.ToString().Contains("Teams") => ProductTierType.Teams, _ when planType.ToString().Contains("Enterprise") => ProductTierType.Enterprise, diff --git a/src/Core/Billing/Extensions/InvoiceExtensions.cs b/src/Core/Billing/Extensions/InvoiceExtensions.cs index bb9f7588bf..d62959c09a 100644 --- a/src/Core/Billing/Extensions/InvoiceExtensions.cs +++ b/src/Core/Billing/Extensions/InvoiceExtensions.cs @@ -64,10 +64,12 @@ public static class InvoiceExtensions } } + var tax = invoice.TotalTaxes?.Sum(invoiceTotalTax => invoiceTotalTax.Amount) ?? 0; + // Add fallback tax from invoice-level tax if present and not already included - if (invoice.Tax.HasValue && invoice.Tax.Value > 0) + if (tax > 0) { - var taxAmount = invoice.Tax.Value / 100m; + var taxAmount = tax / 100m; items.Add($"1 × Tax (at ${taxAmount:F2} / month)"); } diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs index 7aec422a4b..5ceefed603 100644 --- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs +++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs @@ -6,6 +6,7 @@ using Bit.Core.Billing.Organizations.Queries; using Bit.Core.Billing.Organizations.Services; using Bit.Core.Billing.Payment; using Bit.Core.Billing.Premium.Commands; +using Bit.Core.Billing.Premium.Queries; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Implementations; @@ -31,6 +32,8 @@ public static class ServiceCollectionExtensions services.AddPaymentOperations(); services.AddOrganizationLicenseCommandsQueries(); services.AddPremiumCommands(); + services.AddPremiumQueries(); + services.AddTransient(); services.AddTransient(); services.AddTransient(); services.AddTransient(); @@ -49,4 +52,9 @@ public static class ServiceCollectionExtensions services.AddScoped(); services.AddTransient(); } + + private static void AddPremiumQueries(this IServiceCollection services) + { + services.AddScoped(); + } } diff --git a/src/Core/Billing/Extensions/SubscriptionExtensions.cs b/src/Core/Billing/Extensions/SubscriptionExtensions.cs new file mode 100644 index 0000000000..383bd32d53 --- /dev/null +++ b/src/Core/Billing/Extensions/SubscriptionExtensions.cs @@ -0,0 +1,25 @@ +using Stripe; + +namespace Bit.Core.Billing.Extensions; + +public static class SubscriptionExtensions +{ + /* + * For the time being, this is the simplest migration approach from v45 to v48 as + * we do not support multi-cadence subscriptions. Each subscription item should be on the + * same billing cycle. If this changes, we'll need a significantly more robust approach. + * + * Because we can't guarantee a subscription will have items, this has to be nullable. + */ + public static (DateTime? Start, DateTime? End)? GetCurrentPeriod(this Subscription subscription) + { + var item = subscription.Items?.FirstOrDefault(); + return item is null ? null : (item.CurrentPeriodStart, item.CurrentPeriodEnd); + } + + public static DateTime? GetCurrentPeriodStart(this Subscription subscription) => + subscription.Items?.FirstOrDefault()?.CurrentPeriodStart; + + public static DateTime? GetCurrentPeriodEnd(this Subscription subscription) => + subscription.Items?.FirstOrDefault()?.CurrentPeriodEnd; +} diff --git a/src/Core/Billing/Extensions/UpcomingInvoiceOptionsExtensions.cs b/src/Core/Billing/Extensions/UpcomingInvoiceOptionsExtensions.cs deleted file mode 100644 index d00b5b46a4..0000000000 --- a/src/Core/Billing/Extensions/UpcomingInvoiceOptionsExtensions.cs +++ /dev/null @@ -1,35 +0,0 @@ -using Stripe; - -namespace Bit.Core.Billing.Extensions; - -public static class UpcomingInvoiceOptionsExtensions -{ - /// - /// Attempts to enable automatic tax for given upcoming invoice options. - /// - /// - /// The existing customer to which the upcoming invoice belongs. - /// The existing subscription to which the upcoming invoice belongs. - /// Returns true when successful, false when conditions are not met. - public static bool EnableAutomaticTax( - this UpcomingInvoiceOptions options, - Customer customer, - Subscription subscription) - { - if (subscription != null && subscription.AutomaticTax.Enabled) - { - return false; - } - - // We might only need to check the automatic tax status. - if (!customer.HasRecognizedTaxLocation() && string.IsNullOrWhiteSpace(customer.Address?.Country)) - { - return false; - } - - options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true }; - options.SubscriptionDefaultTaxRates = []; - - return true; - } -} diff --git a/src/Core/Billing/Licenses/LicenseConstants.cs b/src/Core/Billing/Licenses/LicenseConstants.cs index cdfac76614..727bcbc229 100644 --- a/src/Core/Billing/Licenses/LicenseConstants.cs +++ b/src/Core/Billing/Licenses/LicenseConstants.cs @@ -43,6 +43,8 @@ public static class OrganizationLicenseConstants public const string Trial = nameof(Trial); public const string UseAdminSponsoredFamilies = nameof(UseAdminSponsoredFamilies); public const string UseOrganizationDomains = nameof(UseOrganizationDomains); + public const string UseAutomaticUserConfirmation = nameof(UseAutomaticUserConfirmation); + public const string UsePhishingBlocker = nameof(UsePhishingBlocker); } public static class UserLicenseConstants diff --git a/src/Core/Billing/Licenses/Services/Implementations/OrganizationLicenseClaimsFactory.cs b/src/Core/Billing/Licenses/Services/Implementations/OrganizationLicenseClaimsFactory.cs index 1e049d7f03..4a4771857e 100644 --- a/src/Core/Billing/Licenses/Services/Implementations/OrganizationLicenseClaimsFactory.cs +++ b/src/Core/Billing/Licenses/Services/Implementations/OrganizationLicenseClaimsFactory.cs @@ -26,7 +26,7 @@ public class OrganizationLicenseClaimsFactory : ILicenseClaimsFactory All { get; set; } = + [ + new() + { + PlanSponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, + SponsoredProductTierType = ProductTierType.Families, + SponsoringProductTierType = ProductTierType.Enterprise, + StripePlanId = "2021-family-for-enterprise-annually", + UsersCanSponsor = org => + org.PlanType.GetProductTier() == ProductTierType.Enterprise, + } + ]; + + public static SponsoredPlan Get(PlanSponsorshipType planSponsorshipType) => + All.FirstOrDefault(p => p.PlanSponsorshipType == planSponsorshipType)!; +} diff --git a/src/Core/Billing/Models/StaticStore/Plan.cs b/src/Core/Billing/Models/StaticStore/Plan.cs index 540ea76582..bab64d9879 100644 --- a/src/Core/Billing/Models/StaticStore/Plan.cs +++ b/src/Core/Billing/Models/StaticStore/Plan.cs @@ -43,6 +43,8 @@ public abstract record Plan public SecretsManagerPlanFeatures SecretsManager { get; protected init; } public bool SupportsSecretsManager => SecretsManager != null; + public bool AutomaticUserConfirmation { get; init; } + public bool HasNonSeatBasedPasswordManagerPlan() => PasswordManager is { StripePlanId: not null and not "", StripeSeatPlanId: null or "" }; @@ -95,7 +97,7 @@ public abstract record Plan public decimal PremiumAccessOptionPrice { get; init; } public short? MaxSeats { get; init; } // Storage - public short? BaseStorageGb { get; init; } + public short BaseStorageGb { get; init; } public bool HasAdditionalStorageOption { get; init; } public decimal AdditionalStoragePricePerGb { get; init; } public string StripeStoragePlanId { get; init; } diff --git a/src/Core/Billing/Organizations/Commands/PreviewOrganizationTaxCommand.cs b/src/Core/Billing/Organizations/Commands/PreviewOrganizationTaxCommand.cs index 041e9bdbad..2a5e786c98 100644 --- a/src/Core/Billing/Organizations/Commands/PreviewOrganizationTaxCommand.cs +++ b/src/Core/Billing/Organizations/Commands/PreviewOrganizationTaxCommand.cs @@ -3,12 +3,12 @@ using Bit.Core.Billing.Commands; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Models; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Enums; -using Bit.Core.Services; -using Bit.Core.Utilities; using Microsoft.Extensions.Logging; using OneOf; using Stripe; @@ -54,7 +54,7 @@ public class PreviewOrganizationTaxCommand( switch (purchase) { case { PasswordManager.Sponsored: true }: - var sponsoredPlan = StaticStore.GetSponsoredPlan(PlanSponsorshipType.FamiliesForEnterprise); + var sponsoredPlan = SponsoredPlans.Get(PlanSponsorshipType.FamiliesForEnterprise); items.Add(new InvoiceSubscriptionDetailsItemOptions { Price = sponsoredPlan.StripePlanId, @@ -75,7 +75,13 @@ public class PreviewOrganizationTaxCommand( Quantity = purchase.SecretsManager.Seats } ]); - options.Coupon = CouponIDs.SecretsManagerStandalone; + options.Discounts = + [ + new InvoiceDiscountOptions + { + Coupon = CouponIDs.SecretsManagerStandalone + } + ]; break; default: @@ -119,7 +125,7 @@ public class PreviewOrganizationTaxCommand( options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items }; - var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + var invoice = await stripeAdapter.CreateInvoicePreviewAsync(options); return GetAmounts(invoice); }); @@ -135,6 +141,8 @@ public class PreviewOrganizationTaxCommand( var newPlan = await pricingClient.GetPlanOrThrow(planChange.PlanType); + var quantity = newPlan.HasNonSeatBasedPasswordManagerPlan() ? 1 : 2; + var items = new List { new () @@ -142,7 +150,7 @@ public class PreviewOrganizationTaxCommand( Price = newPlan.HasNonSeatBasedPasswordManagerPlan() ? newPlan.PasswordManager.StripePlanId : newPlan.PasswordManager.StripeSeatPlanId, - Quantity = 2 + Quantity = quantity } }; @@ -157,7 +165,7 @@ public class PreviewOrganizationTaxCommand( options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items }; - var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + var invoice = await stripeAdapter.CreateInvoicePreviewAsync(options); return GetAmounts(invoice); } else @@ -173,12 +181,15 @@ public class PreviewOrganizationTaxCommand( var options = GetBaseOptions(billingAddress, planChange.Tier != ProductTierType.Families); - var subscription = await stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId, + var subscription = await stripeAdapter.GetSubscriptionAsync(organization.GatewaySubscriptionId, new SubscriptionGetOptions { Expand = ["customer"] }); if (subscription.Customer.Discount != null) { - options.Coupon = subscription.Customer.Discount.Coupon.Id; + options.Discounts = + [ + new InvoiceDiscountOptions { Coupon = subscription.Customer.Discount.Coupon.Id } + ]; } var currentPlan = await pricingClient.GetPlanOrThrow(organization.PlanType); @@ -194,12 +205,17 @@ public class PreviewOrganizationTaxCommand( ? currentPlan.PasswordManager.StripePlanId : currentPlan.PasswordManager.StripeSeatPlanId]; + var quantity = currentPlan.HasNonSeatBasedPasswordManagerPlan() && + !newPlan.HasNonSeatBasedPasswordManagerPlan() + ? (long)organization.Seats! + : passwordManagerSeats.Quantity; + items.Add(new InvoiceSubscriptionDetailsItemOptions { Price = newPlan.HasNonSeatBasedPasswordManagerPlan() ? newPlan.PasswordManager.StripePlanId : newPlan.PasswordManager.StripeSeatPlanId, - Quantity = passwordManagerSeats.Quantity + Quantity = quantity }); var hasStorage = @@ -243,7 +259,7 @@ public class PreviewOrganizationTaxCommand( options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items }; - var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + var invoice = await stripeAdapter.CreateInvoicePreviewAsync(options); return GetAmounts(invoice); } }); @@ -262,7 +278,7 @@ public class PreviewOrganizationTaxCommand( return new BadRequest("Organization does not have a subscription."); } - var subscription = await stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId, + var subscription = await stripeAdapter.GetSubscriptionAsync(organization.GatewaySubscriptionId, new SubscriptionGetOptions { Expand = ["customer.tax_ids"] }); var options = GetBaseOptions(subscription.Customer, @@ -270,7 +286,10 @@ public class PreviewOrganizationTaxCommand( if (subscription.Customer.Discount != null) { - options.Coupon = subscription.Customer.Discount.Coupon.Id; + options.Discounts = + [ + new InvoiceDiscountOptions { Coupon = subscription.Customer.Discount.Coupon.Id } + ]; } var currentPlan = await pricingClient.GetPlanOrThrow(organization.PlanType); @@ -317,12 +336,12 @@ public class PreviewOrganizationTaxCommand( options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items }; - var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + var invoice = await stripeAdapter.CreateInvoicePreviewAsync(options); return GetAmounts(invoice); }); private static (decimal, decimal) GetAmounts(Invoice invoice) => ( - Convert.ToDecimal(invoice.Tax) / 100, + Convert.ToDecimal(invoice.TotalTaxes.Sum(invoiceTotalTax => invoiceTotalTax.Amount)) / 100, Convert.ToDecimal(invoice.Total) / 100); private static InvoiceCreatePreviewOptions GetBaseOptions( diff --git a/src/Core/Billing/Organizations/Commands/UpdateOrganizationLicenseCommand.cs b/src/Core/Billing/Organizations/Commands/UpdateOrganizationLicenseCommand.cs index fde95f2e70..1dfd786210 100644 --- a/src/Core/Billing/Organizations/Commands/UpdateOrganizationLicenseCommand.cs +++ b/src/Core/Billing/Organizations/Commands/UpdateOrganizationLicenseCommand.cs @@ -1,5 +1,7 @@ using System.Text.Json; using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Licenses; +using Bit.Core.Billing.Licenses.Extensions; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Services; using Bit.Core.Exceptions; @@ -52,6 +54,12 @@ public class UpdateOrganizationLicenseCommand : IUpdateOrganizationLicenseComman throw new BadRequestException(exception); } + var useAutomaticUserConfirmation = claimsPrincipal? + .GetValue(OrganizationLicenseConstants.UseAutomaticUserConfirmation) ?? false; + + selfHostedOrganization.UseAutomaticUserConfirmation = useAutomaticUserConfirmation; + license.UseAutomaticUserConfirmation = useAutomaticUserConfirmation; + await WriteLicenseFileAsync(selfHostedOrganization, license); await UpdateOrganizationAsync(selfHostedOrganization, license); } diff --git a/src/Core/Billing/Organizations/Models/OrganizationLicense.cs b/src/Core/Billing/Organizations/Models/OrganizationLicense.cs index 83789be2f3..584021f22f 100644 --- a/src/Core/Billing/Organizations/Models/OrganizationLicense.cs +++ b/src/Core/Billing/Organizations/Models/OrganizationLicense.cs @@ -143,6 +143,7 @@ public class OrganizationLicense : ILicense public int? SmSeats { get; set; } public int? SmServiceAccounts { get; set; } public bool UseRiskInsights { get; set; } + public bool UsePhishingBlocker { get; set; } // Deprecated. Left for backwards compatibility with old license versions. public bool LimitCollectionCreationDeletion { get; set; } = true; @@ -153,6 +154,7 @@ public class OrganizationLicense : ILicense public LicenseType? LicenseType { get; set; } public bool UseOrganizationDomains { get; set; } public bool UseAdminSponsoredFamilies { get; set; } + public bool UseAutomaticUserConfirmation { get; set; } public string Hash { get; set; } public string Signature { get; set; } public string Token { get; set; } @@ -226,7 +228,9 @@ public class OrganizationLicense : ILicense // any new fields added need to be added here so that they're ignored !p.Name.Equals(nameof(UseRiskInsights)) && !p.Name.Equals(nameof(UseAdminSponsoredFamilies)) && - !p.Name.Equals(nameof(UseOrganizationDomains))) + !p.Name.Equals(nameof(UseOrganizationDomains)) && + !p.Name.Equals(nameof(UseAutomaticUserConfirmation)) && + !p.Name.Equals(nameof(UsePhishingBlocker))) .OrderBy(p => p.Name) .Select(p => $"{p.Name}:{Core.Utilities.CoreHelpers.FormatLicenseSignatureValue(p.GetValue(this, null))}") .Aggregate((c, n) => $"{c}|{n}"); @@ -397,7 +401,6 @@ public class OrganizationLicense : ILicense var installationId = claimsPrincipal.GetValue(nameof(InstallationId)); var licenseKey = claimsPrincipal.GetValue(nameof(LicenseKey)); var enabled = claimsPrincipal.GetValue(nameof(Enabled)); - var planType = claimsPrincipal.GetValue(nameof(PlanType)); var seats = claimsPrincipal.GetValue(nameof(Seats)); var maxCollections = claimsPrincipal.GetValue(nameof(MaxCollections)); var useGroups = claimsPrincipal.GetValue(nameof(UseGroups)); @@ -421,13 +424,20 @@ public class OrganizationLicense : ILicense var smServiceAccounts = claimsPrincipal.GetValue(nameof(SmServiceAccounts)); var useAdminSponsoredFamilies = claimsPrincipal.GetValue(nameof(UseAdminSponsoredFamilies)); var useOrganizationDomains = claimsPrincipal.GetValue(nameof(UseOrganizationDomains)); + var useAutomaticUserConfirmation = claimsPrincipal.GetValue(nameof(UseAutomaticUserConfirmation)); + + var claimedPlanType = claimsPrincipal.GetValue(nameof(PlanType)); + + var planTypesMatch = claimedPlanType == PlanType.FamiliesAnnually + ? organization.PlanType is PlanType.FamiliesAnnually or PlanType.FamiliesAnnually2025 + : organization.PlanType == claimedPlanType; return issued <= DateTime.UtcNow && expires >= DateTime.UtcNow && installationId == globalSettings.Installation.Id && licenseKey == organization.LicenseKey && enabled == organization.Enabled && - planType == organization.PlanType && + planTypesMatch && seats == organization.Seats && maxCollections == organization.MaxCollections && useGroups == organization.UseGroups && @@ -450,7 +460,8 @@ public class OrganizationLicense : ILicense smSeats == organization.SmSeats && smServiceAccounts == organization.SmServiceAccounts && useAdminSponsoredFamilies == organization.UseAdminSponsoredFamilies && - useOrganizationDomains == organization.UseOrganizationDomains; + useOrganizationDomains == organization.UseOrganizationDomains && + useAutomaticUserConfirmation == organization.UseAutomaticUserConfirmation; } diff --git a/src/Core/Billing/Organizations/Models/OrganizationMetadata.cs b/src/Core/Billing/Organizations/Models/OrganizationMetadata.cs index 2bcd213dbf..fedd0ad78c 100644 --- a/src/Core/Billing/Organizations/Models/OrganizationMetadata.cs +++ b/src/Core/Billing/Organizations/Models/OrganizationMetadata.cs @@ -1,28 +1,10 @@ namespace Bit.Core.Billing.Organizations.Models; public record OrganizationMetadata( - bool IsEligibleForSelfHost, - bool IsManaged, bool IsOnSecretsManagerStandalone, - bool IsSubscriptionUnpaid, - bool HasSubscription, - bool HasOpenInvoice, - bool IsSubscriptionCanceled, - DateTime? InvoiceDueDate, - DateTime? InvoiceCreatedDate, - DateTime? SubPeriodEndDate, int OrganizationOccupiedSeats) { public static OrganizationMetadata Default => new OrganizationMetadata( false, - false, - false, - false, - false, - false, - false, - null, - null, - null, 0); } diff --git a/src/Core/Billing/Organizations/Models/OrganizationSale.cs b/src/Core/Billing/Organizations/Models/OrganizationSale.cs index f1f3a636b7..a984d5fe71 100644 --- a/src/Core/Billing/Organizations/Models/OrganizationSale.cs +++ b/src/Core/Billing/Organizations/Models/OrganizationSale.cs @@ -9,7 +9,7 @@ namespace Bit.Core.Billing.Organizations.Models; public class OrganizationSale { - private OrganizationSale() { } + internal OrganizationSale() { } public void Deconstruct( out Organization organization, diff --git a/src/Core/Billing/Organizations/Models/SponsorOrganizationSubscriptionUpdate.cs b/src/Core/Billing/Organizations/Models/SponsorOrganizationSubscriptionUpdate.cs index ee603c67e0..6c1362d1c5 100644 --- a/src/Core/Billing/Organizations/Models/SponsorOrganizationSubscriptionUpdate.cs +++ b/src/Core/Billing/Organizations/Models/SponsorOrganizationSubscriptionUpdate.cs @@ -1,6 +1,7 @@ // FIXME: Update this file to be null safe and then delete the line below #nullable disable +using Bit.Core.Billing.Models; using Bit.Core.Models.Business; using Stripe; @@ -17,7 +18,7 @@ public class SponsorOrganizationSubscriptionUpdate : SubscriptionUpdate { _existingPlanStripeId = existingPlan.PasswordManager.StripePlanId; _sponsoredPlanStripeId = sponsoredPlan?.StripePlanId - ?? Core.Utilities.StaticStore.SponsoredPlans.FirstOrDefault()?.StripePlanId; + ?? SponsoredPlans.All.FirstOrDefault()?.StripePlanId; _applySponsorship = applySponsorship; } diff --git a/src/Core/Billing/Organizations/Queries/GetCloudOrganizationLicenseQuery.cs b/src/Core/Billing/Organizations/Queries/GetCloudOrganizationLicenseQuery.cs index f00bc00356..a8a236decc 100644 --- a/src/Core/Billing/Organizations/Queries/GetCloudOrganizationLicenseQuery.cs +++ b/src/Core/Billing/Organizations/Queries/GetCloudOrganizationLicenseQuery.cs @@ -22,14 +22,14 @@ public interface IGetCloudOrganizationLicenseQuery public class GetCloudOrganizationLicenseQuery : IGetCloudOrganizationLicenseQuery { private readonly IInstallationRepository _installationRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly ILicensingService _licensingService; private readonly IProviderRepository _providerRepository; private readonly IFeatureService _featureService; public GetCloudOrganizationLicenseQuery( IInstallationRepository installationRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, ILicensingService licensingService, IProviderRepository providerRepository, IFeatureService featureService) diff --git a/src/Core/Billing/Organizations/Queries/GetOrganizationMetadataQuery.cs b/src/Core/Billing/Organizations/Queries/GetOrganizationMetadataQuery.cs new file mode 100644 index 0000000000..493bae2872 --- /dev/null +++ b/src/Core/Billing/Organizations/Queries/GetOrganizationMetadataQuery.cs @@ -0,0 +1,93 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Organizations.Models; +using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; +using Bit.Core.Repositories; +using Bit.Core.Settings; +using Stripe; + +namespace Bit.Core.Billing.Organizations.Queries; + +public interface IGetOrganizationMetadataQuery +{ + Task Run(Organization organization); +} + +public class GetOrganizationMetadataQuery( + IGlobalSettings globalSettings, + IOrganizationRepository organizationRepository, + IPricingClient pricingClient, + ISubscriberService subscriberService) : IGetOrganizationMetadataQuery +{ + public async Task Run(Organization organization) + { + if (globalSettings.SelfHosted) + { + return OrganizationMetadata.Default; + } + + var orgOccupiedSeats = await organizationRepository.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id); + + if (string.IsNullOrWhiteSpace(organization.GatewaySubscriptionId)) + { + return OrganizationMetadata.Default with + { + OrganizationOccupiedSeats = orgOccupiedSeats.Total + }; + } + + var customer = await subscriberService.GetCustomer(organization); + + var subscription = await subscriberService.GetSubscription(organization, new SubscriptionGetOptions + { + Expand = ["discounts.coupon.applies_to"] + }); + + if (customer == null || subscription == null) + { + return OrganizationMetadata.Default with + { + OrganizationOccupiedSeats = orgOccupiedSeats.Total + }; + } + + var isOnSecretsManagerStandalone = await IsOnSecretsManagerStandalone(organization, customer, subscription); + + return new OrganizationMetadata( + isOnSecretsManagerStandalone, + orgOccupiedSeats.Total); + } + + private async Task IsOnSecretsManagerStandalone( + Organization organization, + Customer? customer, + Subscription? subscription) + { + if (customer == null || subscription == null) + { + return false; + } + + var plan = await pricingClient.GetPlanOrThrow(organization.PlanType); + + if (!plan.SupportsSecretsManager) + { + return false; + } + + var coupon = subscription.Discounts?.FirstOrDefault(discount => + discount.Coupon?.Id == StripeConstants.CouponIDs.SecretsManagerStandalone)?.Coupon; + + if (coupon == null) + { + return false; + } + + var subscriptionProductIds = subscription.Items.Data.Select(item => item.Plan.ProductId); + + var couponAppliesTo = coupon.AppliesTo?.Products; + + return subscriptionProductIds.Intersect(couponAppliesTo ?? []).Any(); + } +} diff --git a/src/Core/Billing/Organizations/Queries/GetOrganizationWarningsQuery.cs b/src/Core/Billing/Organizations/Queries/GetOrganizationWarningsQuery.cs index f33814f1cf..af8dfa7aec 100644 --- a/src/Core/Billing/Organizations/Queries/GetOrganizationWarningsQuery.cs +++ b/src/Core/Billing/Organizations/Queries/GetOrganizationWarningsQuery.cs @@ -2,14 +2,13 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Organizations.Models; +using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Services; using Bit.Core.Context; -using Bit.Core.Services; using Stripe; using Stripe.Tax; @@ -30,8 +29,8 @@ public interface IGetOrganizationWarningsQuery public class GetOrganizationWarningsQuery( ICurrentContext currentContext, + IHasPaymentMethodQuery hasPaymentMethodQuery, IProviderRepository providerRepository, - ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, ISubscriberService subscriberService) : IGetOrganizationWarningsQuery { @@ -81,15 +80,7 @@ public class GetOrganizationWarningsQuery( return null; } - var customer = subscription.Customer; - - var hasUnverifiedBankAccount = await HasUnverifiedBankAccountAsync(organization); - - var hasPaymentMethod = - !string.IsNullOrEmpty(customer.InvoiceSettings.DefaultPaymentMethodId) || - !string.IsNullOrEmpty(customer.DefaultSourceId) || - hasUnverifiedBankAccount || - customer.Metadata.ContainsKey(MetadataKeys.BraintreeCustomerId); + var hasPaymentMethod = await hasPaymentMethodQuery.Run(organization); if (hasPaymentMethod) { @@ -170,17 +161,23 @@ public class GetOrganizationWarningsQuery( if (subscription is { Status: SubscriptionStatus.Trialing or SubscriptionStatus.Active, - LatestInvoice: null or { Status: InvoiceStatus.Paid } - } && (subscription.CurrentPeriodEnd - now).TotalDays <= 14) + LatestInvoice: null or { Status: InvoiceStatus.Paid }, + Items.Data.Count: > 0 + }) { - return new ResellerRenewalWarning + var currentPeriodEnd = subscription.GetCurrentPeriodEnd(); + + if (currentPeriodEnd != null && (currentPeriodEnd.Value - now).TotalDays <= 14) { - Type = "upcoming", - Upcoming = new ResellerRenewalWarning.UpcomingRenewal + return new ResellerRenewalWarning { - RenewalDate = subscription.CurrentPeriodEnd - } - }; + Type = "upcoming", + Upcoming = new ResellerRenewalWarning.UpcomingRenewal + { + RenewalDate = currentPeriodEnd.Value + } + }; + } } if (subscription is @@ -203,7 +200,7 @@ public class GetOrganizationWarningsQuery( // ReSharper disable once InvertIf if (subscription.Status == SubscriptionStatus.PastDue) { - var openInvoices = await stripeAdapter.InvoiceSearchAsync(new InvoiceSearchOptions + var openInvoices = await stripeAdapter.SearchInvoiceAsync(new InvoiceSearchOptions { Query = $"subscription:'{subscription.Id}' status:'open'" }); @@ -259,8 +256,8 @@ public class GetOrganizationWarningsQuery( // Get active and scheduled registrations var registrations = (await Task.WhenAll( - stripeAdapter.TaxRegistrationsListAsync(new RegistrationListOptions { Status = TaxRegistrationStatus.Active }), - stripeAdapter.TaxRegistrationsListAsync(new RegistrationListOptions { Status = TaxRegistrationStatus.Scheduled }))) + stripeAdapter.ListTaxRegistrationsAsync(new RegistrationListOptions { Status = TaxRegistrationStatus.Active }), + stripeAdapter.ListTaxRegistrationsAsync(new RegistrationListOptions { Status = TaxRegistrationStatus.Scheduled }))) .SelectMany(registrations => registrations.Data); // Find the matching registration for the customer @@ -287,22 +284,4 @@ public class GetOrganizationWarningsQuery( _ => null }; } - - private async Task HasUnverifiedBankAccountAsync( - Organization organization) - { - var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(organization.Id); - - if (string.IsNullOrEmpty(setupIntentId)) - { - return false; - } - - var setupIntent = await stripeAdapter.SetupIntentGet(setupIntentId, new SetupIntentGetOptions - { - Expand = ["payment_method"] - }); - - return setupIntent.IsUnverifiedBankAccount(); - } } diff --git a/src/Core/Billing/Organizations/Services/IOrganizationBillingService.cs b/src/Core/Billing/Organizations/Services/IOrganizationBillingService.cs index d34bd86e7b..39d2a789e6 100644 --- a/src/Core/Billing/Organizations/Services/IOrganizationBillingService.cs +++ b/src/Core/Billing/Organizations/Services/IOrganizationBillingService.cs @@ -56,4 +56,11 @@ public interface IOrganizationBillingService /// Thrown when the is . /// Thrown when no payment method is found for the customer, no plan IDs are provided, or subscription update fails. Task UpdateSubscriptionPlanFrequency(Organization organization, PlanType newPlanType); + + /// + /// Updates the organization name and email on the Stripe customer entry. + /// This only updates Stripe, not the Bitwarden database. + /// + /// The organization to update in Stripe. + Task UpdateOrganizationNameAndEmail(Organization organization); } diff --git a/src/Core/Billing/Organizations/Services/OrganizationBillingService.cs b/src/Core/Billing/Organizations/Services/OrganizationBillingService.cs index ce8a9a877b..a1b57c2415 100644 --- a/src/Core/Billing/Organizations/Services/OrganizationBillingService.cs +++ b/src/Core/Billing/Organizations/Services/OrganizationBillingService.cs @@ -6,6 +6,7 @@ using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Sales; using Bit.Core.Billing.Organizations.Models; +using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Billing.Tax.Models; @@ -13,7 +14,6 @@ using Bit.Core.Billing.Tax.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Core.Settings; using Braintree; using Microsoft.Extensions.Logging; @@ -27,6 +27,7 @@ namespace Bit.Core.Billing.Organizations.Services; public class OrganizationBillingService( IBraintreeGateway braintreeGateway, IGlobalSettings globalSettings, + IHasPaymentMethodQuery hasPaymentMethodQuery, ILogger logger, IOrganizationRepository organizationRepository, IPricingClient pricingClient, @@ -43,19 +44,14 @@ public class OrganizationBillingService( ? await CreateCustomerAsync(organization, customerSetup, subscriptionSetup.PlanType) : await GetCustomerWhileEnsuringCorrectTaxExemptionAsync(organization, subscriptionSetup); - var subscription = await CreateSubscriptionAsync(organization.Id, customer, subscriptionSetup); + var subscription = await CreateSubscriptionAsync(organization, customer, subscriptionSetup, customerSetup?.Coupon); if (subscription.Status is StripeConstants.SubscriptionStatus.Trialing or StripeConstants.SubscriptionStatus.Active) { organization.Enabled = true; - organization.ExpirationDate = subscription.CurrentPeriodEnd; + organization.ExpirationDate = subscription.GetCurrentPeriodEnd(); + await organizationRepository.ReplaceAsync(organization); } - - organization.Gateway = GatewayType.Stripe; - organization.GatewayCustomerId = customer.Id; - organization.GatewaySubscriptionId = subscription.Id; - - await organizationRepository.ReplaceAsync(organization); } public async Task GetMetadata(Guid organizationId) @@ -72,56 +68,39 @@ public class OrganizationBillingService( return OrganizationMetadata.Default; } - var isEligibleForSelfHost = await IsEligibleForSelfHostAsync(organization); - - var isManaged = organization.Status == OrganizationStatusType.Managed; var orgOccupiedSeats = await organizationRepository.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id); + if (string.IsNullOrWhiteSpace(organization.GatewaySubscriptionId)) { return OrganizationMetadata.Default with { - IsEligibleForSelfHost = isEligibleForSelfHost, - IsManaged = isManaged, OrganizationOccupiedSeats = orgOccupiedSeats.Total }; } - var customer = await subscriberService.GetCustomer(organization, - new CustomerGetOptions { Expand = ["discount.coupon.applies_to"] }); + var customer = await subscriberService.GetCustomer(organization); - var subscription = await subscriberService.GetSubscription(organization); + var subscription = await subscriberService.GetSubscription(organization, new SubscriptionGetOptions + { + Expand = ["discounts.coupon.applies_to"] + }); if (customer == null || subscription == null) { return OrganizationMetadata.Default with { - IsEligibleForSelfHost = isEligibleForSelfHost, - IsManaged = isManaged + OrganizationOccupiedSeats = orgOccupiedSeats.Total }; } var isOnSecretsManagerStandalone = await IsOnSecretsManagerStandalone(organization, customer, subscription); - var invoice = !string.IsNullOrEmpty(subscription.LatestInvoiceId) - ? await stripeAdapter.InvoiceGetAsync(subscription.LatestInvoiceId, new InvoiceGetOptions()) - : null; - return new OrganizationMetadata( - isEligibleForSelfHost, - isManaged, isOnSecretsManagerStandalone, - subscription.Status == StripeConstants.SubscriptionStatus.Unpaid, - true, - invoice?.Status == StripeConstants.InvoiceStatus.Open, - subscription.Status == StripeConstants.SubscriptionStatus.Canceled, - invoice?.DueDate, - invoice?.Created, - subscription.CurrentPeriodEnd, orgOccupiedSeats.Total); } - public async Task - UpdatePaymentMethod( + public async Task UpdatePaymentMethod( Organization organization, TokenizedPaymentSource tokenizedPaymentSource, TaxInformation taxInformation) @@ -181,7 +160,7 @@ public class OrganizationBillingService( try { // Update the subscription in Stripe - await stripeAdapter.SubscriptionUpdateAsync(subscription.Id, updateOptions); + await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, updateOptions); organization.PlanType = newPlan.Type; await organizationRepository.ReplaceAsync(organization); } @@ -196,6 +175,45 @@ public class OrganizationBillingService( } } + public async Task UpdateOrganizationNameAndEmail(Organization organization) + { + if (string.IsNullOrWhiteSpace(organization.GatewayCustomerId)) + { + logger.LogWarning( + "Organization ({OrganizationId}) has no Stripe customer to update", + organization.Id); + return; + } + + var newDisplayName = organization.DisplayName(); + + // Organization.DisplayName() can return null - handle gracefully + if (string.IsNullOrWhiteSpace(newDisplayName)) + { + logger.LogWarning( + "Organization ({OrganizationId}) has no name to update in Stripe", + organization.Id); + return; + } + + await stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, + new CustomerUpdateOptions + { + Email = organization.BillingEmail, + Description = newDisplayName, + InvoiceSettings = new CustomerInvoiceSettingsOptions + { + // This overwrites the existing custom fields for this organization + CustomFields = [ + new CustomerInvoiceSettingsCustomFieldOptions + { + Name = organization.SubscriberType(), + Value = newDisplayName + }] + }, + }); + } + #region Utilities private async Task CreateCustomerAsync( @@ -209,7 +227,6 @@ public class OrganizationBillingService( var customerCreateOptions = new CustomerCreateOptions { - Coupon = customerSetup.Coupon, Description = organization.DisplayBusinessName(), Email = organization.BillingEmail, Expand = ["tax", "tax_ids"], @@ -272,8 +289,6 @@ public class OrganizationBillingService( ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately }; - - if (planType.GetProductTier() is not ProductTierType.Free and not ProductTierType.Families && customerSetup.TaxInformation.Country != Core.Constants.CountryAbbreviations.UnitedStates) { @@ -297,7 +312,7 @@ public class OrganizationBillingService( customerCreateOptions.TaxIdData = [ - new() { Type = taxIdType, Value = customerSetup.TaxInformation.TaxId } + new CustomerTaxIdDataOptions { Type = taxIdType, Value = customerSetup.TaxInformation.TaxId } ]; if (taxIdType == StripeConstants.TaxIdType.SpanishNIF) @@ -318,7 +333,7 @@ public class OrganizationBillingService( case PaymentMethodType.BankAccount: { var setupIntent = - (await stripeAdapter.SetupIntentList(new SetupIntentListOptions { PaymentMethod = paymentMethodToken })) + (await stripeAdapter.ListSetupIntentsAsync(new SetupIntentListOptions { PaymentMethod = paymentMethodToken })) .FirstOrDefault(); if (setupIntent == null) @@ -352,7 +367,13 @@ public class OrganizationBillingService( try { - return await stripeAdapter.CustomerCreateAsync(customerCreateOptions); + var customer = await stripeAdapter.CreateCustomerAsync(customerCreateOptions); + + organization.Gateway = GatewayType.Stripe; + organization.GatewayCustomerId = customer.Id; + await organizationRepository.ReplaceAsync(organization); + + return customer; } catch (StripeException stripeException) when (stripeException.StripeError?.Code == StripeConstants.ErrorCodes.CustomerTaxLocationInvalid) @@ -397,9 +418,10 @@ public class OrganizationBillingService( } private async Task CreateSubscriptionAsync( - Guid organizationId, + Organization organization, Customer customer, - SubscriptionSetup subscriptionSetup) + SubscriptionSetup subscriptionSetup, + string? coupon) { var plan = await pricingClient.GetPlanOrThrow(subscriptionSetup.PlanType); @@ -462,10 +484,11 @@ public class OrganizationBillingService( { CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically, Customer = customer.Id, + Discounts = !string.IsNullOrEmpty(coupon) ? [new SubscriptionDiscountOptions { Coupon = coupon }] : null, Items = subscriptionItemOptionsList, Metadata = new Dictionary { - ["organizationId"] = organizationId.ToString(), + ["organizationId"] = organization.Id.ToString(), ["trialInitiationPath"] = !string.IsNullOrEmpty(subscriptionSetup.InitiationPath) && subscriptionSetup.InitiationPath.Contains("trial from marketing website") ? "marketing-initiated" @@ -475,9 +498,11 @@ public class OrganizationBillingService( TrialPeriodDays = subscriptionSetup.SkipTrial ? 0 : plan.TrialPeriodDays }; - // Only set trial_settings.end_behavior.missing_payment_method to "cancel" if there is no payment method - if (string.IsNullOrEmpty(customer.InvoiceSettings?.DefaultPaymentMethodId) && - !customer.Metadata.ContainsKey(BraintreeCustomerIdKey)) + var hasPaymentMethod = await hasPaymentMethodQuery.Run(organization); + + // Only set trial_settings.end_behavior.missing_payment_method to "cancel" + // if there is no payment method AND there's an actual trial period + if (!hasPaymentMethod && subscriptionCreateOptions.TrialPeriodDays > 0) { subscriptionCreateOptions.TrialSettings = new SubscriptionTrialSettingsOptions { @@ -492,7 +517,13 @@ public class OrganizationBillingService( { subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; } - return await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); + + var subscription = await stripeAdapter.CreateSubscriptionAsync(subscriptionCreateOptions); + + organization.GatewaySubscriptionId = subscription.Id; + await organizationRepository.ReplaceAsync(organization); + + return subscription; } private async Task GetCustomerWhileEnsuringCorrectTaxExemptionAsync( @@ -515,14 +546,14 @@ public class OrganizationBillingService( customer = customer switch { { Address.Country: not Core.Constants.CountryAbbreviations.UnitedStates, TaxExempt: not StripeConstants.TaxExempt.Reverse } => await - stripeAdapter.CustomerUpdateAsync(customer.Id, + stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Expand = expansions, TaxExempt = StripeConstants.TaxExempt.Reverse }), { Address.Country: Core.Constants.CountryAbbreviations.UnitedStates, TaxExempt: StripeConstants.TaxExempt.Reverse } => await - stripeAdapter.CustomerUpdateAsync(customer.Id, + stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Expand = expansions, @@ -534,16 +565,6 @@ public class OrganizationBillingService( return customer; } - private async Task IsEligibleForSelfHostAsync( - Organization organization) - { - var plans = await pricingClient.ListPlans(); - - var eligibleSelfHostPlans = plans.Where(plan => plan.HasSelfHost).Select(plan => plan.Type); - - return eligibleSelfHostPlans.Contains(organization.PlanType); - } - private async Task IsOnSecretsManagerStandalone( Organization organization, Customer? customer, @@ -561,16 +582,17 @@ public class OrganizationBillingService( return false; } - var hasCoupon = customer.Discount?.Coupon?.Id == StripeConstants.CouponIDs.SecretsManagerStandalone; + var coupon = subscription.Discounts?.FirstOrDefault(discount => + discount.Coupon?.Id == StripeConstants.CouponIDs.SecretsManagerStandalone)?.Coupon; - if (!hasCoupon) + if (coupon == null) { return false; } var subscriptionProductIds = subscription.Items.Data.Select(item => item.Plan.ProductId); - var couponAppliesTo = customer.Discount?.Coupon?.AppliesTo?.Products; + var couponAppliesTo = coupon.AppliesTo?.Products; return subscriptionProductIds.Intersect(couponAppliesTo ?? []).Any(); } @@ -590,7 +612,7 @@ public class OrganizationBillingService( } } }; - await stripeAdapter.SubscriptionUpdateAsync(organization.GatewaySubscriptionId, options); + await stripeAdapter.UpdateSubscriptionAsync(organization.GatewaySubscriptionId, options); } } diff --git a/src/Core/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommand.cs b/src/Core/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommand.cs index a86f0e3ada..cc07f1b5db 100644 --- a/src/Core/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommand.cs +++ b/src/Core/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommand.cs @@ -1,6 +1,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Commands; +using Bit.Core.Billing.Constants; using Bit.Core.Billing.Payment.Clients; using Bit.Core.Entities; using Bit.Core.Settings; @@ -9,6 +10,8 @@ using Microsoft.Extensions.Logging; namespace Bit.Core.Billing.Payment.Commands; +using static BitPayConstants; + public interface ICreateBitPayInvoiceForCreditCommand { Task> Run( @@ -31,6 +34,8 @@ public class CreateBitPayInvoiceForCreditCommand( { var (name, email, posData) = GetSubscriberInformation(subscriber); + var notificationUrl = $"{globalSettings.BitPay.NotificationUrl}?key={globalSettings.BitPay.WebhookKey}"; + var invoice = new Invoice { Buyer = new Buyer { Email = email, Name = name }, @@ -38,7 +43,7 @@ public class CreateBitPayInvoiceForCreditCommand( ExtendedNotifications = true, FullNotifications = true, ItemDesc = "Bitwarden", - NotificationUrl = globalSettings.BitPay.NotificationUrl, + NotificationUrl = notificationUrl, PosData = posData, Price = Convert.ToDouble(amount), RedirectUrl = redirectUrl @@ -51,10 +56,10 @@ public class CreateBitPayInvoiceForCreditCommand( private static (string? Name, string? Email, string POSData) GetSubscriberInformation( ISubscriber subscriber) => subscriber switch { - User user => (user.Email, user.Email, $"userId:{user.Id},accountCredit:1"), + User user => (user.Email, user.Email, $"userId:{user.Id},{PosDataKeys.AccountCredit}"), Organization organization => (organization.Name, organization.BillingEmail, - $"organizationId:{organization.Id},accountCredit:1"), - Provider provider => (provider.Name, provider.BillingEmail, $"providerId:{provider.Id},accountCredit:1"), + $"organizationId:{organization.Id},{PosDataKeys.AccountCredit}"), + Provider provider => (provider.Name, provider.BillingEmail, $"providerId:{provider.Id},{PosDataKeys.AccountCredit}"), _ => throw new ArgumentOutOfRangeException(nameof(subscriber)) }; } diff --git a/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs b/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs index f4eca40cae..daf39fb981 100644 --- a/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs +++ b/src/Core/Billing/Payment/Commands/UpdateBillingAddressCommand.cs @@ -4,7 +4,6 @@ using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Services; using Bit.Core.Entities; -using Bit.Core.Services; using Microsoft.Extensions.Logging; using Stripe; @@ -46,7 +45,7 @@ public class UpdateBillingAddressCommand( BillingAddress billingAddress) { var customer = - await stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, + await stripeAdapter.UpdateCustomerAsync(subscriber.GatewayCustomerId, new CustomerUpdateOptions { Address = new AddressOptions @@ -71,7 +70,7 @@ public class UpdateBillingAddressCommand( BillingAddress billingAddress) { var customer = - await stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, + await stripeAdapter.UpdateCustomerAsync(subscriber.GatewayCustomerId, new CustomerUpdateOptions { Address = new AddressOptions @@ -92,7 +91,7 @@ public class UpdateBillingAddressCommand( await EnableAutomaticTaxAsync(subscriber, customer); var deleteExistingTaxIds = customer.TaxIds?.Any() ?? false - ? customer.TaxIds.Select(taxId => stripeAdapter.TaxIdDeleteAsync(customer.Id, taxId.Id)).ToList() + ? customer.TaxIds.Select(taxId => stripeAdapter.DeleteTaxIdAsync(customer.Id, taxId.Id)).ToList() : []; if (billingAddress.TaxId == null) @@ -101,12 +100,12 @@ public class UpdateBillingAddressCommand( return BillingAddress.From(customer.Address); } - var updatedTaxId = await stripeAdapter.TaxIdCreateAsync(customer.Id, + var updatedTaxId = await stripeAdapter.CreateTaxIdAsync(customer.Id, new TaxIdCreateOptions { Type = billingAddress.TaxId.Code, Value = billingAddress.TaxId.Value }); if (billingAddress.TaxId.Code == StripeConstants.TaxIdType.SpanishNIF) { - updatedTaxId = await stripeAdapter.TaxIdCreateAsync(customer.Id, + updatedTaxId = await stripeAdapter.CreateTaxIdAsync(customer.Id, new TaxIdCreateOptions { Type = StripeConstants.TaxIdType.EUVAT, @@ -130,7 +129,7 @@ public class UpdateBillingAddressCommand( if (subscription is { AutomaticTax.Enabled: false }) { - await stripeAdapter.SubscriptionUpdateAsync(subscriber.GatewaySubscriptionId, + await stripeAdapter.UpdateSubscriptionAsync(subscriber.GatewaySubscriptionId, new SubscriptionUpdateOptions { AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } diff --git a/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs b/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs index 81206b8032..a5a9e3e9c9 100644 --- a/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs +++ b/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs @@ -4,7 +4,6 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Services; using Bit.Core.Entities; -using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Utilities; using Braintree; @@ -56,7 +55,7 @@ public class UpdatePaymentMethodCommand( if (billingAddress != null && customer.Address is not { Country: not null, PostalCode: not null }) { - await stripeAdapter.CustomerUpdateAsync(customer.Id, + await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Address = new AddressOptions @@ -75,7 +74,7 @@ public class UpdatePaymentMethodCommand( Customer customer, string token) { - var setupIntents = await stripeAdapter.SetupIntentList(new SetupIntentListOptions + var setupIntents = await stripeAdapter.ListSetupIntentsAsync(new SetupIntentListOptions { Expand = ["data.payment_method"], PaymentMethod = token @@ -104,9 +103,9 @@ public class UpdatePaymentMethodCommand( Customer customer, string token) { - var paymentMethod = await stripeAdapter.PaymentMethodAttachAsync(token, new PaymentMethodAttachOptions { Customer = customer.Id }); + var paymentMethod = await stripeAdapter.AttachPaymentMethodAsync(token, new PaymentMethodAttachOptions { Customer = customer.Id }); - await stripeAdapter.CustomerUpdateAsync(customer.Id, + await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { InvoiceSettings = new CustomerInvoiceSettingsOptions { DefaultPaymentMethod = token } @@ -139,7 +138,7 @@ public class UpdatePaymentMethodCommand( [StripeConstants.MetadataKeys.BraintreeCustomerId] = braintreeCustomer.Id }; - await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions { Metadata = metadata }); + await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Metadata = metadata }); } var payPalAccount = braintreeCustomer.DefaultPaymentMethod as PayPalAccount; @@ -204,7 +203,7 @@ public class UpdatePaymentMethodCommand( [StripeConstants.MetadataKeys.BraintreeCustomerId] = string.Empty }; - await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions { Metadata = metadata }); + await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Metadata = metadata }); } } } diff --git a/src/Core/Billing/Payment/Models/NonTokenizedPaymentMethod.cs b/src/Core/Billing/Payment/Models/NonTokenizedPaymentMethod.cs new file mode 100644 index 0000000000..5e8ec0484c --- /dev/null +++ b/src/Core/Billing/Payment/Models/NonTokenizedPaymentMethod.cs @@ -0,0 +1,11 @@ +namespace Bit.Core.Billing.Payment.Models; + +public record NonTokenizedPaymentMethod +{ + public NonTokenizablePaymentMethodType Type { get; set; } +} + +public enum NonTokenizablePaymentMethodType +{ + AccountCredit, +} diff --git a/src/Core/Billing/Payment/Models/PaymentMethod.cs b/src/Core/Billing/Payment/Models/PaymentMethod.cs new file mode 100644 index 0000000000..b0733da414 --- /dev/null +++ b/src/Core/Billing/Payment/Models/PaymentMethod.cs @@ -0,0 +1,71 @@ +using System.Text.Json; +using System.Text.Json.Serialization; +using OneOf; + +namespace Bit.Core.Billing.Payment.Models; + +[JsonConverter(typeof(PaymentMethodJsonConverter))] +public class PaymentMethod(OneOf input) + : OneOfBase(input) +{ + public static implicit operator PaymentMethod(TokenizedPaymentMethod tokenized) => new(tokenized); + public static implicit operator PaymentMethod(NonTokenizedPaymentMethod nonTokenized) => new(nonTokenized); + public bool IsTokenized => IsT0; + public TokenizedPaymentMethod AsTokenized => AsT0; + public bool IsNonTokenized => IsT1; + public NonTokenizedPaymentMethod AsNonTokenized => AsT1; +} + +internal class PaymentMethodJsonConverter : JsonConverter +{ + public override PaymentMethod Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + var element = JsonElement.ParseValue(ref reader); + + if (!element.TryGetProperty("type", out var typeProperty)) + { + throw new JsonException("PaymentMethod requires a 'type' property"); + } + + var type = typeProperty.GetString(); + + + if (Enum.TryParse(type, true, out var tokenizedType) && + Enum.IsDefined(typeof(TokenizablePaymentMethodType), tokenizedType)) + { + var token = element.TryGetProperty("token", out var tokenProperty) ? tokenProperty.GetString() : null; + if (string.IsNullOrEmpty(token)) + { + throw new JsonException("TokenizedPaymentMethod requires a 'token' property"); + } + + return new TokenizedPaymentMethod { Type = tokenizedType, Token = token }; + } + + if (Enum.TryParse(type, true, out var nonTokenizedType) && + Enum.IsDefined(typeof(NonTokenizablePaymentMethodType), nonTokenizedType)) + { + return new NonTokenizedPaymentMethod { Type = nonTokenizedType }; + } + + throw new JsonException($"Unknown payment method type: {type}"); + } + + public override void Write(Utf8JsonWriter writer, PaymentMethod value, JsonSerializerOptions options) + { + writer.WriteStartObject(); + + value.Switch( + tokenized => + { + writer.WriteString("type", + tokenized.Type.ToString().ToLowerInvariant() + ); + writer.WriteString("token", tokenized.Token); + }, + nonTokenized => { writer.WriteString("type", nonTokenized.Type.ToString().ToLowerInvariant()); } + ); + + writer.WriteEndObject(); + } +} diff --git a/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs b/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs index 9f9618571e..e03a785278 100644 --- a/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs +++ b/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs @@ -4,7 +4,6 @@ using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Services; using Bit.Core.Entities; -using Bit.Core.Services; using Braintree; using Microsoft.Extensions.Logging; using Stripe; @@ -53,7 +52,7 @@ public class GetPaymentMethodQuery( if (!string.IsNullOrEmpty(setupIntentId)) { - var setupIntent = await stripeAdapter.SetupIntentGet(setupIntentId, new SetupIntentGetOptions + var setupIntent = await stripeAdapter.GetSetupIntentAsync(setupIntentId, new SetupIntentGetOptions { Expand = ["payment_method"] }); diff --git a/src/Core/Billing/Payment/Queries/HasPaymentMethodQuery.cs b/src/Core/Billing/Payment/Queries/HasPaymentMethodQuery.cs new file mode 100644 index 0000000000..c972c3fe5f --- /dev/null +++ b/src/Core/Billing/Payment/Queries/HasPaymentMethodQuery.cs @@ -0,0 +1,57 @@ +using Bit.Core.Billing.Caches; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Services; +using Bit.Core.Entities; +using Stripe; + +namespace Bit.Core.Billing.Payment.Queries; + +using static StripeConstants; + +public interface IHasPaymentMethodQuery +{ + Task Run(ISubscriber subscriber); +} + +public class HasPaymentMethodQuery( + ISetupIntentCache setupIntentCache, + IStripeAdapter stripeAdapter, + ISubscriberService subscriberService) : IHasPaymentMethodQuery +{ + public async Task Run(ISubscriber subscriber) + { + var hasUnverifiedBankAccount = await HasUnverifiedBankAccountAsync(subscriber); + + var customer = await subscriberService.GetCustomer(subscriber); + + if (customer == null) + { + return hasUnverifiedBankAccount; + } + + return + !string.IsNullOrEmpty(customer.InvoiceSettings.DefaultPaymentMethodId) || + !string.IsNullOrEmpty(customer.DefaultSourceId) || + hasUnverifiedBankAccount || + customer.Metadata.ContainsKey(MetadataKeys.BraintreeCustomerId); + } + + private async Task HasUnverifiedBankAccountAsync( + ISubscriber subscriber) + { + var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(subscriber.Id); + + if (string.IsNullOrEmpty(setupIntentId)) + { + return false; + } + + var setupIntent = await stripeAdapter.GetSetupIntentAsync(setupIntentId, new SetupIntentGetOptions + { + Expand = ["payment_method"] + }); + + return setupIntent.IsUnverifiedBankAccount(); + } +} diff --git a/src/Core/Billing/Payment/Registrations.cs b/src/Core/Billing/Payment/Registrations.cs index 478673d2fc..89d3778ccd 100644 --- a/src/Core/Billing/Payment/Registrations.cs +++ b/src/Core/Billing/Payment/Registrations.cs @@ -19,5 +19,6 @@ public static class Registrations services.AddTransient(); services.AddTransient(); services.AddTransient(); + services.AddTransient(); } } diff --git a/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs b/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs index 1227cdc034..ed60e2f11c 100644 --- a/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs +++ b/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs @@ -1,7 +1,11 @@ using Bit.Core.Billing.Caches; using Bit.Core.Billing.Commands; using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Payment.Commands; using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Payment.Queries; +using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Enums; @@ -14,10 +18,12 @@ using Microsoft.Extensions.Logging; using OneOf.Types; using Stripe; using Customer = Stripe.Customer; +using PaymentMethod = Bit.Core.Billing.Payment.Models.PaymentMethod; using Subscription = Stripe.Subscription; namespace Bit.Core.Billing.Premium.Commands; +using static StripeConstants; using static Utilities; /// @@ -29,14 +35,14 @@ public interface ICreatePremiumCloudHostedSubscriptionCommand /// /// Creates a premium cloud-hosted subscription for the specified user. /// - /// The user to create the premium subscription for. Must not already be a premium user. + /// The user to create the premium subscription for. Must not yet be a premium user. /// The tokenized payment method containing the payment type and token for billing. /// The billing address information required for tax calculation and customer creation. /// Additional storage in GB beyond the base 1GB included with premium (must be >= 0). /// A billing command result indicating success or failure with appropriate error details. Task> Run( User user, - TokenizedPaymentMethod paymentMethod, + PaymentMethod paymentMethod, BillingAddress billingAddress, short additionalStorageGb); } @@ -49,7 +55,10 @@ public class CreatePremiumCloudHostedSubscriptionCommand( ISubscriberService subscriberService, IUserService userService, IPushNotificationService pushNotificationService, - ILogger logger) + ILogger logger, + IPricingClient pricingClient, + IHasPaymentMethodQuery hasPaymentMethodQuery, + IUpdatePaymentMethodCommand updatePaymentMethodCommand) : BaseBillingCommand(logger), ICreatePremiumCloudHostedSubscriptionCommand { private static readonly List _expand = ["tax"]; @@ -57,7 +66,7 @@ public class CreatePremiumCloudHostedSubscriptionCommand( public Task> Run( User user, - TokenizedPaymentMethod paymentMethod, + PaymentMethod paymentMethod, BillingAddress billingAddress, short additionalStorageGb) => HandleAsync(async () => { @@ -71,31 +80,69 @@ public class CreatePremiumCloudHostedSubscriptionCommand( return new BadRequest("Additional storage must be greater than 0."); } - var customer = string.IsNullOrEmpty(user.GatewayCustomerId) - ? await CreateCustomerAsync(user, paymentMethod, billingAddress) - : await subscriberService.GetCustomerOrThrow(user, new CustomerGetOptions { Expand = _expand }); + var premiumPlan = await pricingClient.GetAvailablePremiumPlan(); + + Customer? customer; + + /* + * For a new customer purchasing a new subscription, we attach the payment method while creating the customer. + */ + if (string.IsNullOrEmpty(user.GatewayCustomerId)) + { + customer = await CreateCustomerAsync(user, paymentMethod, billingAddress); + } + /* + * An existing customer without a payment method starting a new subscription indicates a user who previously + * purchased account credit but chose to use a tokenizable payment method to pay for the subscription. In this case, + * we need to add the payment method to their customer first. If the incoming payment method is account credit, + * we can just go straight to fetching the customer since there's no payment method to apply. + */ + else if (paymentMethod.IsTokenized && !await hasPaymentMethodQuery.Run(user)) + { + await updatePaymentMethodCommand.Run(user, paymentMethod.AsTokenized, billingAddress); + customer = await subscriberService.GetCustomerOrThrow(user, new CustomerGetOptions { Expand = _expand }); + } + else + { + customer = await subscriberService.GetCustomerOrThrow(user, new CustomerGetOptions { Expand = _expand }); + } customer = await ReconcileBillingLocationAsync(customer, billingAddress); - var subscription = await CreateSubscriptionAsync(user.Id, customer, additionalStorageGb > 0 ? additionalStorageGb : null); + var subscription = await CreateSubscriptionAsync(user.Id, customer, premiumPlan, additionalStorageGb > 0 ? additionalStorageGb : null); - switch (paymentMethod) - { - case { Type: TokenizablePaymentMethodType.PayPal } - when subscription.Status == StripeConstants.SubscriptionStatus.Incomplete: - case { Type: not TokenizablePaymentMethodType.PayPal } - when subscription.Status == StripeConstants.SubscriptionStatus.Active: + paymentMethod.Switch( + tokenized => + { + // ReSharper disable once SwitchStatementHandlesSomeKnownEnumValuesWithDefault + switch (tokenized) { - user.Premium = true; - user.PremiumExpirationDate = subscription.CurrentPeriodEnd; - break; + case { Type: TokenizablePaymentMethodType.PayPal } + when subscription.Status == SubscriptionStatus.Incomplete: + case { Type: not TokenizablePaymentMethodType.PayPal } + when subscription.Status == SubscriptionStatus.Active: + { + user.Premium = true; + user.PremiumExpirationDate = subscription.GetCurrentPeriodEnd(); + break; + } } - } + }, + _ => + { + if (subscription.Status != SubscriptionStatus.Active) + { + return; + } + + user.Premium = true; + user.PremiumExpirationDate = subscription.GetCurrentPeriodEnd(); + }); user.Gateway = GatewayType.Stripe; user.GatewayCustomerId = customer.Id; user.GatewaySubscriptionId = subscription.Id; - user.MaxStorageGb = (short)(1 + additionalStorageGb); + user.MaxStorageGb = (short)(premiumPlan.Storage.Provided + additionalStorageGb); user.LicenseKey = CoreHelpers.SecureRandomString(20); user.RevisionDate = DateTime.UtcNow; @@ -106,9 +153,15 @@ public class CreatePremiumCloudHostedSubscriptionCommand( }); private async Task CreateCustomerAsync(User user, - TokenizedPaymentMethod paymentMethod, + PaymentMethod paymentMethod, BillingAddress billingAddress) { + if (paymentMethod.IsNonTokenized) + { + _logger.LogError("Cannot create customer for user ({UserID}) using non-tokenized payment method. The customer should already exist", user.Id); + throw new BillingException(); + } + var subscriberName = user.SubscriberName(); var customerCreateOptions = new CustomerCreateOptions { @@ -139,24 +192,25 @@ public class CreatePremiumCloudHostedSubscriptionCommand( }, Metadata = new Dictionary { - [StripeConstants.MetadataKeys.Region] = globalSettings.BaseServiceUri.CloudRegion, - [StripeConstants.MetadataKeys.UserId] = user.Id.ToString() + [MetadataKeys.Region] = globalSettings.BaseServiceUri.CloudRegion, + [MetadataKeys.UserId] = user.Id.ToString() }, Tax = new CustomerTaxOptions { - ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately + ValidateLocation = ValidateTaxLocationTiming.Immediately } }; var braintreeCustomerId = ""; - // ReSharper disable once SwitchStatementHandlesSomeKnownEnumValuesWithDefault - switch (paymentMethod.Type) + // We have checked that the payment method is tokenized, so we can safely cast it. + var tokenizedPaymentMethod = paymentMethod.AsTokenized; + switch (tokenizedPaymentMethod.Type) { case TokenizablePaymentMethodType.BankAccount: { var setupIntent = - (await stripeAdapter.SetupIntentList(new SetupIntentListOptions { PaymentMethod = paymentMethod.Token })) + (await stripeAdapter.ListSetupIntentsAsync(new SetupIntentListOptions { PaymentMethod = tokenizedPaymentMethod.Token })) .FirstOrDefault(); if (setupIntent == null) @@ -170,26 +224,26 @@ public class CreatePremiumCloudHostedSubscriptionCommand( } case TokenizablePaymentMethodType.Card: { - customerCreateOptions.PaymentMethod = paymentMethod.Token; - customerCreateOptions.InvoiceSettings.DefaultPaymentMethod = paymentMethod.Token; + customerCreateOptions.PaymentMethod = tokenizedPaymentMethod.Token; + customerCreateOptions.InvoiceSettings.DefaultPaymentMethod = tokenizedPaymentMethod.Token; break; } case TokenizablePaymentMethodType.PayPal: { - braintreeCustomerId = await subscriberService.CreateBraintreeCustomer(user, paymentMethod.Token); + braintreeCustomerId = await subscriberService.CreateBraintreeCustomer(user, tokenizedPaymentMethod.Token); customerCreateOptions.Metadata[BraintreeCustomerIdKey] = braintreeCustomerId; break; } default: { - _logger.LogError("Cannot create customer for user ({UserID}) using payment method type ({PaymentMethodType}) as it is not supported", user.Id, paymentMethod.Type.ToString()); + _logger.LogError("Cannot create customer for user ({UserID}) using payment method type ({PaymentMethodType}) as it is not supported", user.Id, tokenizedPaymentMethod.Type.ToString()); throw new BillingException(); } } try { - return await stripeAdapter.CustomerCreateAsync(customerCreateOptions); + return await stripeAdapter.CreateCustomerAsync(customerCreateOptions); } catch { @@ -200,7 +254,7 @@ public class CreatePremiumCloudHostedSubscriptionCommand( async Task Revert() { // ReSharper disable once SwitchStatementMissingSomeEnumCasesNoDefault - switch (paymentMethod.Type) + switch (tokenizedPaymentMethod.Type) { case TokenizablePaymentMethodType.BankAccount: { @@ -243,22 +297,24 @@ public class CreatePremiumCloudHostedSubscriptionCommand( Expand = _expand, Tax = new CustomerTaxOptions { - ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately + ValidateLocation = ValidateTaxLocationTiming.Immediately } }; - return await stripeAdapter.CustomerUpdateAsync(customer.Id, options); + return await stripeAdapter.UpdateCustomerAsync(customer.Id, options); } private async Task CreateSubscriptionAsync( Guid userId, Customer customer, + Pricing.Premium.Plan premiumPlan, int? storage) { + var subscriptionItemOptionsList = new List { new () { - Price = StripeConstants.Prices.PremiumAnnually, + Price = premiumPlan.Seat.StripePriceId, Quantity = 1 } }; @@ -267,7 +323,7 @@ public class CreatePremiumCloudHostedSubscriptionCommand( { subscriptionItemOptionsList.Add(new SubscriptionItemOptions { - Price = StripeConstants.Prices.StoragePlanPersonal, + Price = premiumPlan.Storage.StripePriceId, Quantity = storage }); } @@ -280,24 +336,24 @@ public class CreatePremiumCloudHostedSubscriptionCommand( { Enabled = true }, - CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically, + CollectionMethod = CollectionMethod.ChargeAutomatically, Customer = customer.Id, Items = subscriptionItemOptionsList, Metadata = new Dictionary { - [StripeConstants.MetadataKeys.UserId] = userId.ToString() + [MetadataKeys.UserId] = userId.ToString() }, PaymentBehavior = usingPayPal - ? StripeConstants.PaymentBehavior.DefaultIncomplete + ? PaymentBehavior.DefaultIncomplete : null, OffSession = true }; - var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); + var subscription = await stripeAdapter.CreateSubscriptionAsync(subscriptionCreateOptions); if (usingPayPal) { - await stripeAdapter.InvoiceUpdateAsync(subscription.LatestInvoiceId, new InvoiceUpdateOptions + await stripeAdapter.UpdateInvoiceAsync(subscription.LatestInvoiceId, new InvoiceUpdateOptions { AutoAdvance = false }); diff --git a/src/Core/Billing/Premium/Commands/PreviewPremiumTaxCommand.cs b/src/Core/Billing/Premium/Commands/PreviewPremiumTaxCommand.cs index a0b4fcabc2..07247c83cb 100644 --- a/src/Core/Billing/Premium/Commands/PreviewPremiumTaxCommand.cs +++ b/src/Core/Billing/Premium/Commands/PreviewPremiumTaxCommand.cs @@ -1,14 +1,12 @@ using Bit.Core.Billing.Commands; -using Bit.Core.Billing.Constants; using Bit.Core.Billing.Payment.Models; -using Bit.Core.Services; +using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Microsoft.Extensions.Logging; using Stripe; namespace Bit.Core.Billing.Premium.Commands; -using static StripeConstants; - public interface IPreviewPremiumTaxCommand { Task> Run( @@ -18,6 +16,7 @@ public interface IPreviewPremiumTaxCommand public class PreviewPremiumTaxCommand( ILogger logger, + IPricingClient pricingClient, IStripeAdapter stripeAdapter) : BaseBillingCommand(logger), IPreviewPremiumTaxCommand { public Task> Run( @@ -25,6 +24,8 @@ public class PreviewPremiumTaxCommand( BillingAddress billingAddress) => HandleAsync<(decimal, decimal)>(async () => { + var premiumPlan = await pricingClient.GetAvailablePremiumPlan(); + var options = new InvoiceCreatePreviewOptions { AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true }, @@ -41,7 +42,7 @@ public class PreviewPremiumTaxCommand( { Items = [ - new InvoiceSubscriptionDetailsItemOptions { Price = Prices.PremiumAnnually, Quantity = 1 } + new InvoiceSubscriptionDetailsItemOptions { Price = premiumPlan.Seat.StripePriceId, Quantity = 1 } ] } }; @@ -50,16 +51,16 @@ public class PreviewPremiumTaxCommand( { options.SubscriptionDetails.Items.Add(new InvoiceSubscriptionDetailsItemOptions { - Price = Prices.StoragePlanPersonal, + Price = premiumPlan.Storage.StripePriceId, Quantity = additionalStorage }); } - var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + var invoice = await stripeAdapter.CreateInvoicePreviewAsync(options); return GetAmounts(invoice); }); private static (decimal, decimal) GetAmounts(Invoice invoice) => ( - Convert.ToDecimal(invoice.Tax) / 100, + Convert.ToDecimal(invoice.TotalTaxes.Sum(invoiceTotalTax => invoiceTotalTax.Amount)) / 100, Convert.ToDecimal(invoice.Total) / 100); } diff --git a/src/Core/Billing/Premium/Models/UserPremiumAccess.cs b/src/Core/Billing/Premium/Models/UserPremiumAccess.cs new file mode 100644 index 0000000000..639d175d25 --- /dev/null +++ b/src/Core/Billing/Premium/Models/UserPremiumAccess.cs @@ -0,0 +1,29 @@ +namespace Bit.Core.Billing.Premium.Models; + +/// +/// Represents user premium access status from personal subscriptions and organization memberships. +/// +public class UserPremiumAccess +{ + /// + /// The unique identifier for the user. + /// + public Guid Id { get; set; } + + /// + /// Indicates whether the user has a personal premium subscription. + /// This does NOT include premium access from organizations. + /// + public bool PersonalPremium { get; set; } + + /// + /// Indicates whether the user has premium access through any organization membership. + /// This is true if the user is a member of at least one enabled organization that grants premium access to users. + /// + public bool OrganizationPremium { get; set; } + + /// + /// Indicates whether the user has premium access from any source (personal subscription or organization). + /// + public bool HasPremiumAccess => PersonalPremium || OrganizationPremium; +} diff --git a/src/Core/Billing/Premium/Queries/HasPremiumAccessQuery.cs b/src/Core/Billing/Premium/Queries/HasPremiumAccessQuery.cs new file mode 100644 index 0000000000..e90710a9b3 --- /dev/null +++ b/src/Core/Billing/Premium/Queries/HasPremiumAccessQuery.cs @@ -0,0 +1,49 @@ +using Bit.Core.Exceptions; +using Bit.Core.Repositories; + +namespace Bit.Core.Billing.Premium.Queries; + +public class HasPremiumAccessQuery : IHasPremiumAccessQuery +{ + private readonly IUserRepository _userRepository; + + public HasPremiumAccessQuery(IUserRepository userRepository) + { + _userRepository = userRepository; + } + + public async Task HasPremiumAccessAsync(Guid userId) + { + var user = await _userRepository.GetPremiumAccessAsync(userId); + if (user == null) + { + throw new NotFoundException(); + } + + return user.HasPremiumAccess; + } + + public async Task> HasPremiumAccessAsync(IEnumerable userIds) + { + var distinctUserIds = userIds.Distinct().ToList(); + var usersWithPremium = await _userRepository.GetPremiumAccessByIdsAsync(distinctUserIds); + + if (usersWithPremium.Count() != distinctUserIds.Count) + { + throw new NotFoundException(); + } + + return usersWithPremium.ToDictionary(u => u.Id, u => u.HasPremiumAccess); + } + + public async Task HasPremiumFromOrganizationAsync(Guid userId) + { + var user = await _userRepository.GetPremiumAccessAsync(userId); + if (user == null) + { + throw new NotFoundException(); + } + + return user.OrganizationPremium; + } +} diff --git a/src/Core/Billing/Premium/Queries/IHasPremiumAccessQuery.cs b/src/Core/Billing/Premium/Queries/IHasPremiumAccessQuery.cs new file mode 100644 index 0000000000..e5545b1ade --- /dev/null +++ b/src/Core/Billing/Premium/Queries/IHasPremiumAccessQuery.cs @@ -0,0 +1,30 @@ +namespace Bit.Core.Billing.Premium.Queries; + +/// +/// Centralized query for checking if users have premium access through personal subscriptions or organizations. +/// Note: Different from User.Premium which only checks personal subscriptions. +/// +public interface IHasPremiumAccessQuery +{ + /// + /// Checks if a user has premium access (personal or organization). + /// + /// The user ID to check + /// True if user can access premium features + Task HasPremiumAccessAsync(Guid userId); + + /// + /// Checks premium access for multiple users. + /// + /// The user IDs to check + /// Dictionary mapping user IDs to their premium access status + Task> HasPremiumAccessAsync(IEnumerable userIds); + + /// + /// Checks if a user belongs to any organization that grants premium (enabled org with UsersGetPremium). + /// Returns true regardless of personal subscription. Useful for UI decisions like showing subscription options. + /// + /// The user ID to check + /// True if user is in any organization that grants premium + Task HasPremiumFromOrganizationAsync(Guid userId); +} diff --git a/src/Core/Billing/Pricing/IPricingClient.cs b/src/Core/Billing/Pricing/IPricingClient.cs index bc3f142dda..18588ae432 100644 --- a/src/Core/Billing/Pricing/IPricingClient.cs +++ b/src/Core/Billing/Pricing/IPricingClient.cs @@ -3,12 +3,14 @@ using Bit.Core.Exceptions; using Bit.Core.Models.StaticStore; using Bit.Core.Utilities; -#nullable enable - namespace Bit.Core.Billing.Pricing; +using OrganizationPlan = Plan; +using PremiumPlan = Premium.Plan; + public interface IPricingClient { + // TODO: Rename with Organization focus. /// /// Retrieve a Bitwarden plan by its . If the feature flag 'use-pricing-service' is enabled, /// this will trigger a request to the Bitwarden Pricing Service. Otherwise, it will use the existing . @@ -16,8 +18,9 @@ public interface IPricingClient /// The type of plan to retrieve. /// A Bitwarden record or null in the case the plan could not be found or the method was executed from a self-hosted instance. /// Thrown when the request to the Pricing Service fails unexpectedly. - Task GetPlan(PlanType planType); + Task GetPlan(PlanType planType); + // TODO: Rename with Organization focus. /// /// Retrieve a Bitwarden plan by its . If the feature flag 'use-pricing-service' is enabled, /// this will trigger a request to the Bitwarden Pricing Service. Otherwise, it will use the existing . @@ -26,13 +29,17 @@ public interface IPricingClient /// A Bitwarden record. /// Thrown when the for the provided could not be found or the method was executed from a self-hosted instance. /// Thrown when the request to the Pricing Service fails unexpectedly. - Task GetPlanOrThrow(PlanType planType); + Task GetPlanOrThrow(PlanType planType); + // TODO: Rename with Organization focus. /// /// Retrieve all the Bitwarden plans. If the feature flag 'use-pricing-service' is enabled, /// this will trigger a request to the Bitwarden Pricing Service. Otherwise, it will use the existing . /// /// A list of Bitwarden records or an empty list in the case the method is executed from a self-hosted instance. /// Thrown when the request to the Pricing Service fails unexpectedly. - Task> ListPlans(); + Task> ListPlans(); + + Task GetAvailablePremiumPlan(); + Task> ListPremiumPlans(); } diff --git a/src/Core/Billing/Pricing/Models/Feature.cs b/src/Core/Billing/Pricing/Organizations/Feature.cs similarity index 69% rename from src/Core/Billing/Pricing/Models/Feature.cs rename to src/Core/Billing/Pricing/Organizations/Feature.cs index ea9da5217d..df10d2bcf8 100644 --- a/src/Core/Billing/Pricing/Models/Feature.cs +++ b/src/Core/Billing/Pricing/Organizations/Feature.cs @@ -1,4 +1,4 @@ -namespace Bit.Core.Billing.Pricing.Models; +namespace Bit.Core.Billing.Pricing.Organizations; public class Feature { diff --git a/src/Core/Billing/Pricing/Models/Plan.cs b/src/Core/Billing/Pricing/Organizations/Plan.cs similarity index 94% rename from src/Core/Billing/Pricing/Models/Plan.cs rename to src/Core/Billing/Pricing/Organizations/Plan.cs index 5b4296474b..c533c271cb 100644 --- a/src/Core/Billing/Pricing/Models/Plan.cs +++ b/src/Core/Billing/Pricing/Organizations/Plan.cs @@ -1,4 +1,4 @@ -namespace Bit.Core.Billing.Pricing.Models; +namespace Bit.Core.Billing.Pricing.Organizations; public class Plan { diff --git a/src/Core/Billing/Pricing/PlanAdapter.cs b/src/Core/Billing/Pricing/Organizations/PlanAdapter.cs similarity index 91% rename from src/Core/Billing/Pricing/PlanAdapter.cs rename to src/Core/Billing/Pricing/Organizations/PlanAdapter.cs index 560987b891..42090a56ca 100644 --- a/src/Core/Billing/Pricing/PlanAdapter.cs +++ b/src/Core/Billing/Pricing/Organizations/PlanAdapter.cs @@ -1,8 +1,6 @@ using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Pricing.Models; -using Plan = Bit.Core.Billing.Pricing.Models.Plan; -namespace Bit.Core.Billing.Pricing; +namespace Bit.Core.Billing.Pricing.Organizations; public record PlanAdapter : Core.Models.StaticStore.Plan { @@ -60,6 +58,7 @@ public record PlanAdapter : Core.Models.StaticStore.Plan "enterprise-monthly-2020" => PlanType.EnterpriseMonthly2020, "enterprise-monthly-2023" => PlanType.EnterpriseMonthly2023, "families" => PlanType.FamiliesAnnually, + "families-2025" => PlanType.FamiliesAnnually2025, "families-2019" => PlanType.FamiliesAnnually2019, "free" => PlanType.Free, "teams-annually" => PlanType.TeamsAnnually, @@ -79,7 +78,7 @@ public record PlanAdapter : Core.Models.StaticStore.Plan => planType switch { PlanType.Free => ProductTierType.Free, - PlanType.FamiliesAnnually or PlanType.FamiliesAnnually2019 => ProductTierType.Families, + PlanType.FamiliesAnnually or PlanType.FamiliesAnnually2025 or PlanType.FamiliesAnnually2019 => ProductTierType.Families, PlanType.TeamsStarter or PlanType.TeamsStarter2023 => ProductTierType.TeamsStarter, _ when planType.ToString().Contains("Teams") => ProductTierType.Teams, _ when planType.ToString().Contains("Enterprise") => ProductTierType.Enterprise, @@ -100,11 +99,19 @@ public record PlanAdapter : Core.Models.StaticStore.Plan _ => true); var baseSeats = GetBaseSeats(plan.Seats); var maxSeats = GetMaxSeats(plan.Seats); - var baseStorageGb = (short?)plan.Storage?.Provided; + var baseStorageGb = (short)(plan.Storage?.Provided ?? 0); var hasAdditionalStorageOption = plan.Storage != null; var additionalStoragePricePerGb = plan.Storage?.Price ?? 0; var stripeStoragePlanId = plan.Storage?.StripePriceId; short? maxCollections = plan.AdditionalData.TryGetValue("passwordManager.maxCollections", out var value) ? short.Parse(value) : null; + var stripePremiumAccessPlanId = + plan.AdditionalData.TryGetValue("premiumAccessAddOnPriceId", out var premiumAccessAddOnPriceIdValue) + ? premiumAccessAddOnPriceIdValue + : null; + var premiumAccessOptionPrice = + plan.AdditionalData.TryGetValue("premiumAccessAddOnPriceAmount", out var premiumAccessAddOnPriceAmountValue) + ? decimal.Parse(premiumAccessAddOnPriceAmountValue) + : 0; return new PasswordManagerPlanFeatures { @@ -122,7 +129,9 @@ public record PlanAdapter : Core.Models.StaticStore.Plan HasAdditionalStorageOption = hasAdditionalStorageOption, AdditionalStoragePricePerGb = additionalStoragePricePerGb, StripeStoragePlanId = stripeStoragePlanId, - MaxCollections = maxCollections + MaxCollections = maxCollections, + StripePremiumAccessPlanId = stripePremiumAccessPlanId, + PremiumAccessOptionPrice = premiumAccessOptionPrice }; } diff --git a/src/Core/Billing/Pricing/Models/Purchasable.cs b/src/Core/Billing/Pricing/Organizations/Purchasable.cs similarity index 99% rename from src/Core/Billing/Pricing/Models/Purchasable.cs rename to src/Core/Billing/Pricing/Organizations/Purchasable.cs index 7cb4ee00c1..f6704394f7 100644 --- a/src/Core/Billing/Pricing/Models/Purchasable.cs +++ b/src/Core/Billing/Pricing/Organizations/Purchasable.cs @@ -2,7 +2,7 @@ using System.Text.Json.Serialization; using OneOf; -namespace Bit.Core.Billing.Pricing.Models; +namespace Bit.Core.Billing.Pricing.Organizations; [JsonConverter(typeof(PurchasableJsonConverter))] public class Purchasable(OneOf input) : OneOfBase(input) diff --git a/src/Core/Billing/Pricing/Premium/Plan.cs b/src/Core/Billing/Pricing/Premium/Plan.cs new file mode 100644 index 0000000000..f377157363 --- /dev/null +++ b/src/Core/Billing/Pricing/Premium/Plan.cs @@ -0,0 +1,10 @@ +namespace Bit.Core.Billing.Pricing.Premium; + +public class Plan +{ + public string Name { get; init; } = null!; + public int? LegacyYear { get; init; } + public bool Available { get; init; } + public Purchasable Seat { get; init; } = null!; + public Purchasable Storage { get; init; } = null!; +} diff --git a/src/Core/Billing/Pricing/Premium/Purchasable.cs b/src/Core/Billing/Pricing/Premium/Purchasable.cs new file mode 100644 index 0000000000..6bf69d9593 --- /dev/null +++ b/src/Core/Billing/Pricing/Premium/Purchasable.cs @@ -0,0 +1,8 @@ +namespace Bit.Core.Billing.Pricing.Premium; + +public class Purchasable +{ + public string StripePriceId { get; init; } = null!; + public decimal Price { get; init; } + public int Provided { get; init; } +} diff --git a/src/Core/Billing/Pricing/PricingClient.cs b/src/Core/Billing/Pricing/PricingClient.cs index a3db8ce07f..ecb85ed7e8 100644 --- a/src/Core/Billing/Pricing/PricingClient.cs +++ b/src/Core/Billing/Pricing/PricingClient.cs @@ -1,37 +1,32 @@ using System.Net; using System.Net.Http.Json; +using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Pricing.Organizations; using Bit.Core.Exceptions; using Bit.Core.Services; using Bit.Core.Settings; -using Bit.Core.Utilities; using Microsoft.Extensions.Logging; -using Plan = Bit.Core.Models.StaticStore.Plan; - -#nullable enable namespace Bit.Core.Billing.Pricing; +using OrganizationPlan = Bit.Core.Models.StaticStore.Plan; +using PremiumPlan = Premium.Plan; +using Purchasable = Premium.Purchasable; + public class PricingClient( IFeatureService featureService, GlobalSettings globalSettings, HttpClient httpClient, ILogger logger) : IPricingClient { - public async Task GetPlan(PlanType planType) + public async Task GetPlan(PlanType planType) { if (globalSettings.SelfHosted) { return null; } - var usePricingService = featureService.IsEnabled(FeatureFlagKeys.UsePricingService); - - if (!usePricingService) - { - return StaticStore.GetPlan(planType); - } - var lookupKey = GetLookupKey(planType); if (lookupKey == null) @@ -40,16 +35,14 @@ public class PricingClient( return null; } - var response = await httpClient.GetAsync($"plans/lookup/{lookupKey}"); + var response = await httpClient.GetAsync($"plans/organization/{lookupKey}"); if (response.IsSuccessStatusCode) { - var plan = await response.Content.ReadFromJsonAsync(); - if (plan == null) - { - throw new BillingException(message: "Deserialization of Pricing Service response resulted in null"); - } - return new PlanAdapter(plan); + var plan = await response.Content.ReadFromJsonAsync(); + return plan == null + ? throw new BillingException(message: "Deserialization of Pricing Service response resulted in null") + : new PlanAdapter(PreProcessFamiliesPreMigrationPlan(plan)); } if (response.StatusCode == HttpStatusCode.NotFound) @@ -62,49 +55,71 @@ public class PricingClient( message: $"Request to the Pricing Service failed with status code {response.StatusCode}"); } - public async Task GetPlanOrThrow(PlanType planType) + public async Task GetPlanOrThrow(PlanType planType) { var plan = await GetPlan(planType); - if (plan == null) - { - throw new NotFoundException(); - } - - return plan; + return plan ?? throw new NotFoundException($"Could not find plan for type {planType}"); } - public async Task> ListPlans() + public async Task> ListPlans() { if (globalSettings.SelfHosted) { return []; } - var usePricingService = featureService.IsEnabled(FeatureFlagKeys.UsePricingService); - - if (!usePricingService) - { - return StaticStore.Plans.ToList(); - } - - var response = await httpClient.GetAsync("plans"); + var response = await httpClient.GetAsync("plans/organization"); if (response.IsSuccessStatusCode) { - var plans = await response.Content.ReadFromJsonAsync>(); - if (plans == null) - { - throw new BillingException(message: "Deserialization of Pricing Service response resulted in null"); - } - return plans.Select(Plan (plan) => new PlanAdapter(plan)).ToList(); + var plans = await response.Content.ReadFromJsonAsync>(); + return plans == null + ? throw new BillingException(message: "Deserialization of Pricing Service response resulted in null") + : plans.Select(OrganizationPlan (plan) => new PlanAdapter(PreProcessFamiliesPreMigrationPlan(plan))).ToList(); } throw new BillingException( message: $"Request to the Pricing Service failed with status {response.StatusCode}"); } - private static string? GetLookupKey(PlanType planType) + public async Task GetAvailablePremiumPlan() + { + var premiumPlans = await ListPremiumPlans(); + + var availablePlan = premiumPlans.FirstOrDefault(premiumPlan => premiumPlan.Available); + + return availablePlan ?? throw new NotFoundException("Could not find available premium plan"); + } + + public async Task> ListPremiumPlans() + { + if (globalSettings.SelfHosted) + { + return []; + } + + var fetchPremiumPriceFromPricingService = + featureService.IsEnabled(FeatureFlagKeys.PM26793_FetchPremiumPriceFromPricingService); + + if (!fetchPremiumPriceFromPricingService) + { + return [CurrentPremiumPlan]; + } + + var response = await httpClient.GetAsync("plans/premium"); + + if (response.IsSuccessStatusCode) + { + var plans = await response.Content.ReadFromJsonAsync>(); + return plans ?? throw new BillingException(message: "Deserialization of Pricing Service response resulted in null"); + } + + throw new BillingException( + message: $"Request to the Pricing Service failed with status {response.StatusCode}"); + } + + private string? GetLookupKey(PlanType planType) => planType switch { PlanType.EnterpriseAnnually => "enterprise-annually", @@ -116,6 +131,10 @@ public class PricingClient( PlanType.EnterpriseMonthly2020 => "enterprise-monthly-2020", PlanType.EnterpriseMonthly2023 => "enterprise-monthly-2023", PlanType.FamiliesAnnually => "families", + PlanType.FamiliesAnnually2025 => + featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3) + ? "families-2025" + : "families", PlanType.FamiliesAnnually2019 => "families-2019", PlanType.Free => "free", PlanType.TeamsAnnually => "teams-annually", @@ -130,4 +149,27 @@ public class PricingClient( PlanType.TeamsStarter2023 => "teams-starter-2023", _ => null }; + + /// + /// Safeguard used until the feature flag is enabled. Pricing service will return the + /// 2025PreMigration plan with "families" lookup key. When that is detected and the FF + /// is still disabled, set the lookup key to families-2025 so PlanAdapter will assign + /// the correct plan. + /// + /// The plan to preprocess + private Plan PreProcessFamiliesPreMigrationPlan(Plan plan) + { + if (plan.LookupKey == "families" && !featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3)) + plan.LookupKey = "families-2025"; + return plan; + } + + private static PremiumPlan CurrentPremiumPlan => new() + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new Purchasable { Price = 10M, StripePriceId = StripeConstants.Prices.PremiumAnnually }, + Storage = new Purchasable { Price = 4M, StripePriceId = StripeConstants.Prices.StoragePlanPersonal, Provided = 1 } + }; } diff --git a/src/Core/Billing/Providers/Migration/Models/ClientMigrationTracker.cs b/src/Core/Billing/Providers/Migration/Models/ClientMigrationTracker.cs deleted file mode 100644 index 65fd7726f8..0000000000 --- a/src/Core/Billing/Providers/Migration/Models/ClientMigrationTracker.cs +++ /dev/null @@ -1,26 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -namespace Bit.Core.Billing.Providers.Migration.Models; - -public enum ClientMigrationProgress -{ - Started = 1, - MigrationRecordCreated = 2, - SubscriptionEnded = 3, - Completed = 4, - - Reversing = 5, - ResetOrganization = 6, - RecreatedSubscription = 7, - RemovedMigrationRecord = 8, - Reversed = 9 -} - -public class ClientMigrationTracker -{ - public Guid ProviderId { get; set; } - public Guid OrganizationId { get; set; } - public string OrganizationName { get; set; } - public ClientMigrationProgress Progress { get; set; } = ClientMigrationProgress.Started; -} diff --git a/src/Core/Billing/Providers/Migration/Models/ProviderMigrationResult.cs b/src/Core/Billing/Providers/Migration/Models/ProviderMigrationResult.cs deleted file mode 100644 index 78a2631999..0000000000 --- a/src/Core/Billing/Providers/Migration/Models/ProviderMigrationResult.cs +++ /dev/null @@ -1,48 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.Billing.Providers.Entities; - -namespace Bit.Core.Billing.Providers.Migration.Models; - -public class ProviderMigrationResult -{ - public Guid ProviderId { get; set; } - public string ProviderName { get; set; } - public string Result { get; set; } - public List Clients { get; set; } -} - -public class ClientMigrationResult -{ - public Guid OrganizationId { get; set; } - public string OrganizationName { get; set; } - public string Result { get; set; } - public ClientPreviousState PreviousState { get; set; } -} - -public class ClientPreviousState -{ - public ClientPreviousState() { } - - public ClientPreviousState(ClientOrganizationMigrationRecord migrationRecord) - { - PlanType = migrationRecord.PlanType.ToString(); - Seats = migrationRecord.Seats; - MaxStorageGb = migrationRecord.MaxStorageGb; - GatewayCustomerId = migrationRecord.GatewayCustomerId; - GatewaySubscriptionId = migrationRecord.GatewaySubscriptionId; - ExpirationDate = migrationRecord.ExpirationDate; - MaxAutoscaleSeats = migrationRecord.MaxAutoscaleSeats; - Status = migrationRecord.Status.ToString(); - } - - public string PlanType { get; set; } - public int Seats { get; set; } - public short? MaxStorageGb { get; set; } - public string GatewayCustomerId { get; set; } = null!; - public string GatewaySubscriptionId { get; set; } = null!; - public DateTime? ExpirationDate { get; set; } - public int? MaxAutoscaleSeats { get; set; } - public string Status { get; set; } -} diff --git a/src/Core/Billing/Providers/Migration/Models/ProviderMigrationTracker.cs b/src/Core/Billing/Providers/Migration/Models/ProviderMigrationTracker.cs deleted file mode 100644 index ba39feab2d..0000000000 --- a/src/Core/Billing/Providers/Migration/Models/ProviderMigrationTracker.cs +++ /dev/null @@ -1,25 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -namespace Bit.Core.Billing.Providers.Migration.Models; - -public enum ProviderMigrationProgress -{ - Started = 1, - NoClients = 2, - ClientsMigrated = 3, - TeamsPlanConfigured = 4, - EnterprisePlanConfigured = 5, - CustomerSetup = 6, - SubscriptionSetup = 7, - CreditApplied = 8, - Completed = 9, -} - -public class ProviderMigrationTracker -{ - public Guid ProviderId { get; set; } - public string ProviderName { get; set; } - public List OrganizationIds { get; set; } - public ProviderMigrationProgress Progress { get; set; } = ProviderMigrationProgress.Started; -} diff --git a/src/Core/Billing/Providers/Migration/ServiceCollectionExtensions.cs b/src/Core/Billing/Providers/Migration/ServiceCollectionExtensions.cs deleted file mode 100644 index 1061c82888..0000000000 --- a/src/Core/Billing/Providers/Migration/ServiceCollectionExtensions.cs +++ /dev/null @@ -1,15 +0,0 @@ -using Bit.Core.Billing.Providers.Migration.Services; -using Bit.Core.Billing.Providers.Migration.Services.Implementations; -using Microsoft.Extensions.DependencyInjection; - -namespace Bit.Core.Billing.Providers.Migration; - -public static class ServiceCollectionExtensions -{ - public static void AddProviderMigration(this IServiceCollection services) - { - services.AddTransient(); - services.AddTransient(); - services.AddTransient(); - } -} diff --git a/src/Core/Billing/Providers/Migration/Services/IMigrationTrackerCache.cs b/src/Core/Billing/Providers/Migration/Services/IMigrationTrackerCache.cs deleted file mode 100644 index 70649590df..0000000000 --- a/src/Core/Billing/Providers/Migration/Services/IMigrationTrackerCache.cs +++ /dev/null @@ -1,17 +0,0 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.Billing.Providers.Migration.Models; - -namespace Bit.Core.Billing.Providers.Migration.Services; - -public interface IMigrationTrackerCache -{ - Task StartTracker(Provider provider); - Task SetOrganizationIds(Guid providerId, IEnumerable organizationIds); - Task GetTracker(Guid providerId); - Task UpdateTrackingStatus(Guid providerId, ProviderMigrationProgress status); - - Task StartTracker(Guid providerId, Organization organization); - Task GetTracker(Guid providerId, Guid organizationId); - Task UpdateTrackingStatus(Guid providerId, Guid organizationId, ClientMigrationProgress status); -} diff --git a/src/Core/Billing/Providers/Migration/Services/IOrganizationMigrator.cs b/src/Core/Billing/Providers/Migration/Services/IOrganizationMigrator.cs deleted file mode 100644 index a0548277b4..0000000000 --- a/src/Core/Billing/Providers/Migration/Services/IOrganizationMigrator.cs +++ /dev/null @@ -1,8 +0,0 @@ -using Bit.Core.AdminConsole.Entities; - -namespace Bit.Core.Billing.Providers.Migration.Services; - -public interface IOrganizationMigrator -{ - Task Migrate(Guid providerId, Organization organization); -} diff --git a/src/Core/Billing/Providers/Migration/Services/IProviderMigrator.cs b/src/Core/Billing/Providers/Migration/Services/IProviderMigrator.cs deleted file mode 100644 index 328c2419f4..0000000000 --- a/src/Core/Billing/Providers/Migration/Services/IProviderMigrator.cs +++ /dev/null @@ -1,10 +0,0 @@ -using Bit.Core.Billing.Providers.Migration.Models; - -namespace Bit.Core.Billing.Providers.Migration.Services; - -public interface IProviderMigrator -{ - Task Migrate(Guid providerId); - - Task GetResult(Guid providerId); -} diff --git a/src/Core/Billing/Providers/Migration/Services/Implementations/MigrationTrackerDistributedCache.cs b/src/Core/Billing/Providers/Migration/Services/Implementations/MigrationTrackerDistributedCache.cs deleted file mode 100644 index 1f38b0d111..0000000000 --- a/src/Core/Billing/Providers/Migration/Services/Implementations/MigrationTrackerDistributedCache.cs +++ /dev/null @@ -1,110 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Text.Json; -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.Billing.Providers.Migration.Models; -using Microsoft.Extensions.Caching.Distributed; -using Microsoft.Extensions.DependencyInjection; - -namespace Bit.Core.Billing.Providers.Migration.Services.Implementations; - -public class MigrationTrackerDistributedCache( - [FromKeyedServices("persistent")] - IDistributedCache distributedCache) : IMigrationTrackerCache -{ - public async Task StartTracker(Provider provider) => - await SetAsync(new ProviderMigrationTracker - { - ProviderId = provider.Id, - ProviderName = provider.Name - }); - - public async Task SetOrganizationIds(Guid providerId, IEnumerable organizationIds) - { - var tracker = await GetAsync(providerId); - - tracker.OrganizationIds = organizationIds.ToList(); - - await SetAsync(tracker); - } - - public Task GetTracker(Guid providerId) => GetAsync(providerId); - - public async Task UpdateTrackingStatus(Guid providerId, ProviderMigrationProgress status) - { - var tracker = await GetAsync(providerId); - - tracker.Progress = status; - - await SetAsync(tracker); - } - - public async Task StartTracker(Guid providerId, Organization organization) => - await SetAsync(new ClientMigrationTracker - { - ProviderId = providerId, - OrganizationId = organization.Id, - OrganizationName = organization.Name - }); - - public Task GetTracker(Guid providerId, Guid organizationId) => - GetAsync(providerId, organizationId); - - public async Task UpdateTrackingStatus(Guid providerId, Guid organizationId, ClientMigrationProgress status) - { - var tracker = await GetAsync(providerId, organizationId); - - tracker.Progress = status; - - await SetAsync(tracker); - } - - private static string GetProviderCacheKey(Guid providerId) => $"provider_{providerId}_migration"; - - private static string GetClientCacheKey(Guid providerId, Guid clientId) => - $"provider_{providerId}_client_{clientId}_migration"; - - private async Task GetAsync(Guid providerId) - { - var cacheKey = GetProviderCacheKey(providerId); - - var json = await distributedCache.GetStringAsync(cacheKey); - - return string.IsNullOrEmpty(json) ? null : JsonSerializer.Deserialize(json); - } - - private async Task GetAsync(Guid providerId, Guid organizationId) - { - var cacheKey = GetClientCacheKey(providerId, organizationId); - - var json = await distributedCache.GetStringAsync(cacheKey); - - return string.IsNullOrEmpty(json) ? null : JsonSerializer.Deserialize(json); - } - - private async Task SetAsync(ProviderMigrationTracker tracker) - { - var cacheKey = GetProviderCacheKey(tracker.ProviderId); - - var json = JsonSerializer.Serialize(tracker); - - await distributedCache.SetStringAsync(cacheKey, json, new DistributedCacheEntryOptions - { - SlidingExpiration = TimeSpan.FromMinutes(30) - }); - } - - private async Task SetAsync(ClientMigrationTracker tracker) - { - var cacheKey = GetClientCacheKey(tracker.ProviderId, tracker.OrganizationId); - - var json = JsonSerializer.Serialize(tracker); - - await distributedCache.SetStringAsync(cacheKey, json, new DistributedCacheEntryOptions - { - SlidingExpiration = TimeSpan.FromMinutes(30) - }); - } -} diff --git a/src/Core/Billing/Providers/Migration/Services/Implementations/OrganizationMigrator.cs b/src/Core/Billing/Providers/Migration/Services/Implementations/OrganizationMigrator.cs deleted file mode 100644 index 3de49838af..0000000000 --- a/src/Core/Billing/Providers/Migration/Services/Implementations/OrganizationMigrator.cs +++ /dev/null @@ -1,331 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.AdminConsole.Entities; -using Bit.Core.Billing.Constants; -using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Pricing; -using Bit.Core.Billing.Providers.Entities; -using Bit.Core.Billing.Providers.Migration.Models; -using Bit.Core.Billing.Providers.Repositories; -using Bit.Core.Enums; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Microsoft.Extensions.Logging; -using Stripe; -using Plan = Bit.Core.Models.StaticStore.Plan; - -namespace Bit.Core.Billing.Providers.Migration.Services.Implementations; - -public class OrganizationMigrator( - IClientOrganizationMigrationRecordRepository clientOrganizationMigrationRecordRepository, - ILogger logger, - IMigrationTrackerCache migrationTrackerCache, - IOrganizationRepository organizationRepository, - IPricingClient pricingClient, - IStripeAdapter stripeAdapter) : IOrganizationMigrator -{ - private const string _cancellationComment = "Cancelled as part of provider migration to Consolidated Billing"; - - public async Task Migrate(Guid providerId, Organization organization) - { - logger.LogInformation("CB: Starting migration for organization ({OrganizationID})", organization.Id); - - await migrationTrackerCache.StartTracker(providerId, organization); - - await CreateMigrationRecordAsync(providerId, organization); - - await CancelSubscriptionAsync(providerId, organization); - - await UpdateOrganizationAsync(providerId, organization); - } - - #region Steps - - private async Task CreateMigrationRecordAsync(Guid providerId, Organization organization) - { - logger.LogInformation("CB: Creating ClientOrganizationMigrationRecord for organization ({OrganizationID})", organization.Id); - - var migrationRecord = await clientOrganizationMigrationRecordRepository.GetByOrganizationId(organization.Id); - - if (migrationRecord != null) - { - logger.LogInformation( - "CB: ClientOrganizationMigrationRecord already exists for organization ({OrganizationID}), deleting record", - organization.Id); - - await clientOrganizationMigrationRecordRepository.DeleteAsync(migrationRecord); - } - - await clientOrganizationMigrationRecordRepository.CreateAsync(new ClientOrganizationMigrationRecord - { - OrganizationId = organization.Id, - ProviderId = providerId, - PlanType = organization.PlanType, - Seats = organization.Seats ?? 0, - MaxStorageGb = organization.MaxStorageGb, - GatewayCustomerId = organization.GatewayCustomerId!, - GatewaySubscriptionId = organization.GatewaySubscriptionId!, - ExpirationDate = organization.ExpirationDate, - MaxAutoscaleSeats = organization.MaxAutoscaleSeats, - Status = organization.Status - }); - - logger.LogInformation("CB: Created migration record for organization ({OrganizationID})", organization.Id); - - await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id, - ClientMigrationProgress.MigrationRecordCreated); - } - - private async Task CancelSubscriptionAsync(Guid providerId, Organization organization) - { - logger.LogInformation("CB: Cancelling subscription for organization ({OrganizationID})", organization.Id); - - var subscription = await stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId); - - if (subscription is - { - Status: - StripeConstants.SubscriptionStatus.Active or - StripeConstants.SubscriptionStatus.PastDue or - StripeConstants.SubscriptionStatus.Trialing - }) - { - await stripeAdapter.SubscriptionUpdateAsync(organization.GatewaySubscriptionId, - new SubscriptionUpdateOptions { CancelAtPeriodEnd = false }); - - subscription = await stripeAdapter.SubscriptionCancelAsync(organization.GatewaySubscriptionId, - new SubscriptionCancelOptions - { - CancellationDetails = new SubscriptionCancellationDetailsOptions - { - Comment = _cancellationComment - }, - InvoiceNow = true, - Prorate = true, - Expand = ["latest_invoice", "test_clock"] - }); - - logger.LogInformation("CB: Cancelled subscription for organization ({OrganizationID})", organization.Id); - - var now = subscription.TestClock?.FrozenTime ?? DateTime.UtcNow; - - var trialing = subscription.TrialEnd.HasValue && subscription.TrialEnd.Value > now; - - if (!trialing && subscription is { Status: StripeConstants.SubscriptionStatus.Canceled, CancellationDetails.Comment: _cancellationComment }) - { - var latestInvoice = subscription.LatestInvoice; - - if (latestInvoice.Status == "draft") - { - await stripeAdapter.InvoiceFinalizeInvoiceAsync(latestInvoice.Id, - new InvoiceFinalizeOptions { AutoAdvance = true }); - - logger.LogInformation("CB: Finalized prorated invoice for organization ({OrganizationID})", organization.Id); - } - } - } - else - { - logger.LogInformation( - "CB: Did not need to cancel subscription for organization ({OrganizationID}) as it was inactive", - organization.Id); - } - - await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id, - ClientMigrationProgress.SubscriptionEnded); - } - - private async Task UpdateOrganizationAsync(Guid providerId, Organization organization) - { - logger.LogInformation("CB: Bringing organization ({OrganizationID}) under provider management", - organization.Id); - - var plan = await pricingClient.GetPlanOrThrow(organization.Plan.Contains("Teams") ? PlanType.TeamsMonthly : PlanType.EnterpriseMonthly); - - ResetOrganizationPlan(organization, plan); - organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb; - organization.GatewaySubscriptionId = null; - organization.ExpirationDate = null; - organization.MaxAutoscaleSeats = null; - organization.Status = OrganizationStatusType.Managed; - - await organizationRepository.ReplaceAsync(organization); - - logger.LogInformation("CB: Brought organization ({OrganizationID}) under provider management", - organization.Id); - - await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id, - ClientMigrationProgress.Completed); - } - - #endregion - - #region Reverse - - private async Task RemoveMigrationRecordAsync(Guid providerId, Organization organization) - { - logger.LogInformation("CB: Removing migration record for organization ({OrganizationID})", organization.Id); - - var migrationRecord = await clientOrganizationMigrationRecordRepository.GetByOrganizationId(organization.Id); - - if (migrationRecord != null) - { - await clientOrganizationMigrationRecordRepository.DeleteAsync(migrationRecord); - - logger.LogInformation( - "CB: Removed migration record for organization ({OrganizationID})", - organization.Id); - } - else - { - logger.LogInformation("CB: Did not remove migration record for organization ({OrganizationID}) as it does not exist", organization.Id); - } - - await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id, ClientMigrationProgress.Reversed); - } - - private async Task RecreateSubscriptionAsync(Guid providerId, Organization organization) - { - logger.LogInformation("CB: Recreating subscription for organization ({OrganizationID})", organization.Id); - - if (!string.IsNullOrEmpty(organization.GatewaySubscriptionId)) - { - if (string.IsNullOrEmpty(organization.GatewayCustomerId)) - { - logger.LogError( - "CB: Cannot recreate subscription for organization ({OrganizationID}) as it does not have a Stripe customer", - organization.Id); - - throw new Exception(); - } - - var customer = await stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId, - new CustomerGetOptions { Expand = ["default_source", "invoice_settings.default_payment_method"] }); - - var collectionMethod = - customer.DefaultSource != null || - customer.InvoiceSettings?.DefaultPaymentMethod != null || - customer.Metadata.ContainsKey(Utilities.BraintreeCustomerIdKey) - ? StripeConstants.CollectionMethod.ChargeAutomatically - : StripeConstants.CollectionMethod.SendInvoice; - - var plan = await pricingClient.GetPlanOrThrow(organization.PlanType); - - var items = new List - { - new () - { - Price = plan.PasswordManager.StripeSeatPlanId, - Quantity = organization.Seats - } - }; - - if (organization.MaxStorageGb.HasValue && plan.PasswordManager.BaseStorageGb.HasValue && organization.MaxStorageGb.Value > plan.PasswordManager.BaseStorageGb.Value) - { - var additionalStorage = organization.MaxStorageGb.Value - plan.PasswordManager.BaseStorageGb.Value; - - items.Add(new SubscriptionItemOptions - { - Price = plan.PasswordManager.StripeStoragePlanId, - Quantity = additionalStorage - }); - } - - var subscriptionCreateOptions = new SubscriptionCreateOptions - { - AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = true - }, - Customer = customer.Id, - CollectionMethod = collectionMethod, - DaysUntilDue = collectionMethod == StripeConstants.CollectionMethod.SendInvoice ? 30 : null, - Items = items, - Metadata = new Dictionary - { - [organization.GatewayIdField()] = organization.Id.ToString() - }, - OffSession = true, - ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations, - TrialPeriodDays = plan.TrialPeriodDays - }; - - var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); - - organization.GatewaySubscriptionId = subscription.Id; - - await organizationRepository.ReplaceAsync(organization); - - logger.LogInformation("CB: Recreated subscription for organization ({OrganizationID})", organization.Id); - } - else - { - logger.LogInformation( - "CB: Did not recreate subscription for organization ({OrganizationID}) as it already exists", - organization.Id); - } - - await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id, - ClientMigrationProgress.RecreatedSubscription); - } - - private async Task ReverseOrganizationUpdateAsync(Guid providerId, Organization organization) - { - var migrationRecord = await clientOrganizationMigrationRecordRepository.GetByOrganizationId(organization.Id); - - if (migrationRecord == null) - { - logger.LogError( - "CB: Cannot reverse migration for organization ({OrganizationID}) as it does not have a migration record", - organization.Id); - - throw new Exception(); - } - - var plan = await pricingClient.GetPlanOrThrow(migrationRecord.PlanType); - - ResetOrganizationPlan(organization, plan); - organization.MaxStorageGb = migrationRecord.MaxStorageGb; - organization.ExpirationDate = migrationRecord.ExpirationDate; - organization.MaxAutoscaleSeats = migrationRecord.MaxAutoscaleSeats; - organization.Status = migrationRecord.Status; - - await organizationRepository.ReplaceAsync(organization); - - logger.LogInformation("CB: Reversed organization ({OrganizationID}) updates", - organization.Id); - - await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id, - ClientMigrationProgress.ResetOrganization); - } - - #endregion - - #region Shared - - private static void ResetOrganizationPlan(Organization organization, Plan plan) - { - organization.Plan = plan.Name; - organization.PlanType = plan.Type; - organization.MaxCollections = plan.PasswordManager.MaxCollections; - organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb; - organization.UsePolicies = plan.HasPolicies; - organization.UseSso = plan.HasSso; - organization.UseOrganizationDomains = plan.HasOrganizationDomains; - organization.UseGroups = plan.HasGroups; - organization.UseEvents = plan.HasEvents; - organization.UseDirectory = plan.HasDirectory; - organization.UseTotp = plan.HasTotp; - organization.Use2fa = plan.Has2fa; - organization.UseApi = plan.HasApi; - organization.UseResetPassword = plan.HasResetPassword; - organization.SelfHost = plan.HasSelfHost; - organization.UsersGetPremium = plan.UsersGetPremium; - organization.UseCustomPermissions = plan.HasCustomPermissions; - organization.UseScim = plan.HasScim; - organization.UseKeyConnector = plan.HasKeyConnector; - } - - #endregion -} diff --git a/src/Core/Billing/Providers/Migration/Services/Implementations/ProviderMigrator.cs b/src/Core/Billing/Providers/Migration/Services/Implementations/ProviderMigrator.cs deleted file mode 100644 index e155b427f1..0000000000 --- a/src/Core/Billing/Providers/Migration/Services/Implementations/ProviderMigrator.cs +++ /dev/null @@ -1,436 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.AdminConsole.Enums.Provider; -using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Billing.Constants; -using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models; -using Bit.Core.Billing.Providers.Entities; -using Bit.Core.Billing.Providers.Migration.Models; -using Bit.Core.Billing.Providers.Models; -using Bit.Core.Billing.Providers.Repositories; -using Bit.Core.Billing.Providers.Services; -using Bit.Core.Enums; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Microsoft.Extensions.Logging; -using Stripe; - -namespace Bit.Core.Billing.Providers.Migration.Services.Implementations; - -public class ProviderMigrator( - IClientOrganizationMigrationRecordRepository clientOrganizationMigrationRecordRepository, - IOrganizationMigrator organizationMigrator, - ILogger logger, - IMigrationTrackerCache migrationTrackerCache, - IOrganizationRepository organizationRepository, - IPaymentService paymentService, - IProviderBillingService providerBillingService, - IProviderOrganizationRepository providerOrganizationRepository, - IProviderRepository providerRepository, - IProviderPlanRepository providerPlanRepository, - IStripeAdapter stripeAdapter) : IProviderMigrator -{ - public async Task Migrate(Guid providerId) - { - var provider = await GetProviderAsync(providerId); - - if (provider == null) - { - return; - } - - logger.LogInformation("CB: Starting migration for provider ({ProviderID})", providerId); - - await migrationTrackerCache.StartTracker(provider); - - var organizations = await GetClientsAsync(provider.Id); - - if (organizations.Count == 0) - { - logger.LogInformation("CB: Skipping migration for provider ({ProviderID}) with no clients", providerId); - - await migrationTrackerCache.UpdateTrackingStatus(providerId, ProviderMigrationProgress.NoClients); - - return; - } - - await MigrateClientsAsync(providerId, organizations); - - await ConfigureTeamsPlanAsync(providerId); - - await ConfigureEnterprisePlanAsync(providerId); - - await SetupCustomerAsync(provider); - - await SetupSubscriptionAsync(provider); - - await ApplyCreditAsync(provider); - - await UpdateProviderAsync(provider); - } - - public async Task GetResult(Guid providerId) - { - var providerTracker = await migrationTrackerCache.GetTracker(providerId); - - if (providerTracker == null) - { - return null; - } - - if (providerTracker.Progress == ProviderMigrationProgress.NoClients) - { - return new ProviderMigrationResult - { - ProviderId = providerTracker.ProviderId, - ProviderName = providerTracker.ProviderName, - Result = providerTracker.Progress.ToString() - }; - } - - var clientTrackers = await Task.WhenAll(providerTracker.OrganizationIds.Select(organizationId => - migrationTrackerCache.GetTracker(providerId, organizationId))); - - var migrationRecordLookup = new Dictionary(); - - foreach (var clientTracker in clientTrackers) - { - var migrationRecord = - await clientOrganizationMigrationRecordRepository.GetByOrganizationId(clientTracker.OrganizationId); - - migrationRecordLookup.Add(clientTracker.OrganizationId, migrationRecord); - } - - return new ProviderMigrationResult - { - ProviderId = providerTracker.ProviderId, - ProviderName = providerTracker.ProviderName, - Result = providerTracker.Progress.ToString(), - Clients = clientTrackers.Select(tracker => - { - var foundMigrationRecord = migrationRecordLookup.TryGetValue(tracker.OrganizationId, out var migrationRecord); - return new ClientMigrationResult - { - OrganizationId = tracker.OrganizationId, - OrganizationName = tracker.OrganizationName, - Result = tracker.Progress.ToString(), - PreviousState = foundMigrationRecord ? new ClientPreviousState(migrationRecord) : null - }; - }).ToList(), - }; - } - - #region Steps - - private async Task MigrateClientsAsync(Guid providerId, List organizations) - { - logger.LogInformation("CB: Migrating clients for provider ({ProviderID})", providerId); - - var organizationIds = organizations.Select(organization => organization.Id); - - await migrationTrackerCache.SetOrganizationIds(providerId, organizationIds); - - foreach (var organization in organizations) - { - var tracker = await migrationTrackerCache.GetTracker(providerId, organization.Id); - - if (tracker is not { Progress: ClientMigrationProgress.Completed }) - { - await organizationMigrator.Migrate(providerId, organization); - } - } - - logger.LogInformation("CB: Migrated clients for provider ({ProviderID})", providerId); - - await migrationTrackerCache.UpdateTrackingStatus(providerId, - ProviderMigrationProgress.ClientsMigrated); - } - - private async Task ConfigureTeamsPlanAsync(Guid providerId) - { - logger.LogInformation("CB: Configuring Teams plan for provider ({ProviderID})", providerId); - - var organizations = await GetClientsAsync(providerId); - - var teamsSeats = organizations - .Where(IsTeams) - .Sum(client => client.Seats) ?? 0; - - var teamsProviderPlan = (await providerPlanRepository.GetByProviderId(providerId)) - .FirstOrDefault(providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly); - - if (teamsProviderPlan == null) - { - await providerPlanRepository.CreateAsync(new ProviderPlan - { - ProviderId = providerId, - PlanType = PlanType.TeamsMonthly, - SeatMinimum = teamsSeats, - PurchasedSeats = 0, - AllocatedSeats = teamsSeats - }); - - logger.LogInformation("CB: Created Teams plan for provider ({ProviderID}) with a seat minimum of {Seats}", - providerId, teamsSeats); - } - else - { - logger.LogInformation("CB: Teams plan already exists for provider ({ProviderID}), updating seat minimum", providerId); - - teamsProviderPlan.SeatMinimum = teamsSeats; - teamsProviderPlan.AllocatedSeats = teamsSeats; - - await providerPlanRepository.ReplaceAsync(teamsProviderPlan); - - logger.LogInformation("CB: Updated Teams plan for provider ({ProviderID}) to seat minimum of {Seats}", - providerId, teamsProviderPlan.SeatMinimum); - } - - await migrationTrackerCache.UpdateTrackingStatus(providerId, ProviderMigrationProgress.TeamsPlanConfigured); - } - - private async Task ConfigureEnterprisePlanAsync(Guid providerId) - { - logger.LogInformation("CB: Configuring Enterprise plan for provider ({ProviderID})", providerId); - - var organizations = await GetClientsAsync(providerId); - - var enterpriseSeats = organizations - .Where(IsEnterprise) - .Sum(client => client.Seats) ?? 0; - - var enterpriseProviderPlan = (await providerPlanRepository.GetByProviderId(providerId)) - .FirstOrDefault(providerPlan => providerPlan.PlanType == PlanType.EnterpriseMonthly); - - if (enterpriseProviderPlan == null) - { - await providerPlanRepository.CreateAsync(new ProviderPlan - { - ProviderId = providerId, - PlanType = PlanType.EnterpriseMonthly, - SeatMinimum = enterpriseSeats, - PurchasedSeats = 0, - AllocatedSeats = enterpriseSeats - }); - - logger.LogInformation("CB: Created Enterprise plan for provider ({ProviderID}) with a seat minimum of {Seats}", - providerId, enterpriseSeats); - } - else - { - logger.LogInformation("CB: Enterprise plan already exists for provider ({ProviderID}), updating seat minimum", providerId); - - enterpriseProviderPlan.SeatMinimum = enterpriseSeats; - enterpriseProviderPlan.AllocatedSeats = enterpriseSeats; - - await providerPlanRepository.ReplaceAsync(enterpriseProviderPlan); - - logger.LogInformation("CB: Updated Enterprise plan for provider ({ProviderID}) to seat minimum of {Seats}", - providerId, enterpriseProviderPlan.SeatMinimum); - } - - await migrationTrackerCache.UpdateTrackingStatus(providerId, ProviderMigrationProgress.EnterprisePlanConfigured); - } - - private async Task SetupCustomerAsync(Provider provider) - { - if (string.IsNullOrEmpty(provider.GatewayCustomerId)) - { - var organizations = await GetClientsAsync(provider.Id); - - var sampleOrganization = organizations.FirstOrDefault(organization => !string.IsNullOrEmpty(organization.GatewayCustomerId)); - - if (sampleOrganization == null) - { - logger.LogInformation( - "CB: Could not find sample organization for provider ({ProviderID}) that has a Stripe customer", - provider.Id); - - return; - } - - var taxInfo = await paymentService.GetTaxInfoAsync(sampleOrganization); - - // Create dummy payment source for legacy migration - this migrator is deprecated and will be removed - var dummyPaymentSource = new TokenizedPaymentSource(PaymentMethodType.Card, "migration_dummy_token"); - - var customer = await providerBillingService.SetupCustomer(provider, null, null); - - await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions - { - Coupon = StripeConstants.CouponIDs.LegacyMSPDiscount - }); - - provider.GatewayCustomerId = customer.Id; - - await providerRepository.ReplaceAsync(provider); - - logger.LogInformation("CB: Setup Stripe customer for provider ({ProviderID})", provider.Id); - } - else - { - logger.LogInformation("CB: Stripe customer already exists for provider ({ProviderID})", provider.Id); - } - - await migrationTrackerCache.UpdateTrackingStatus(provider.Id, ProviderMigrationProgress.CustomerSetup); - } - - private async Task SetupSubscriptionAsync(Provider provider) - { - if (string.IsNullOrEmpty(provider.GatewaySubscriptionId)) - { - if (!string.IsNullOrEmpty(provider.GatewayCustomerId)) - { - var subscription = await providerBillingService.SetupSubscription(provider); - - provider.GatewaySubscriptionId = subscription.Id; - - await providerRepository.ReplaceAsync(provider); - - logger.LogInformation("CB: Setup Stripe subscription for provider ({ProviderID})", provider.Id); - } - else - { - logger.LogInformation( - "CB: Could not set up Stripe subscription for provider ({ProviderID}) with no Stripe customer", - provider.Id); - - return; - } - } - else - { - logger.LogInformation("CB: Stripe subscription already exists for provider ({ProviderID})", provider.Id); - - var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); - - var enterpriseSeatMinimum = providerPlans - .FirstOrDefault(providerPlan => providerPlan.PlanType == PlanType.EnterpriseMonthly)? - .SeatMinimum ?? 0; - - var teamsSeatMinimum = providerPlans - .FirstOrDefault(providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly)? - .SeatMinimum ?? 0; - - var updateSeatMinimumsCommand = new UpdateProviderSeatMinimumsCommand( - provider, - [ - (Plan: PlanType.EnterpriseMonthly, SeatsMinimum: enterpriseSeatMinimum), - (Plan: PlanType.TeamsMonthly, SeatsMinimum: teamsSeatMinimum) - ]); - await providerBillingService.UpdateSeatMinimums(updateSeatMinimumsCommand); - - logger.LogInformation( - "CB: Updated Stripe subscription for provider ({ProviderID}) with current seat minimums", provider.Id); - } - - await migrationTrackerCache.UpdateTrackingStatus(provider.Id, ProviderMigrationProgress.SubscriptionSetup); - } - - private async Task ApplyCreditAsync(Provider provider) - { - var organizations = await GetClientsAsync(provider.Id); - - var organizationCustomers = - await Task.WhenAll(organizations.Select(organization => stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId))); - - var organizationCancellationCredit = organizationCustomers.Sum(customer => customer.Balance); - - if (organizationCancellationCredit != 0) - { - await stripeAdapter.CustomerBalanceTransactionCreate(provider.GatewayCustomerId, - new CustomerBalanceTransactionCreateOptions - { - Amount = organizationCancellationCredit, - Currency = "USD", - Description = "Unused, prorated time for client organization subscriptions." - }); - } - - var migrationRecords = await Task.WhenAll(organizations.Select(organization => - clientOrganizationMigrationRecordRepository.GetByOrganizationId(organization.Id))); - - var legacyOrganizationMigrationRecords = migrationRecords.Where(migrationRecord => - migrationRecord.PlanType is - PlanType.EnterpriseAnnually2020 or - PlanType.TeamsAnnually2020); - - var legacyOrganizationCredit = legacyOrganizationMigrationRecords.Sum(migrationRecord => migrationRecord.Seats) * 12 * -100; - - if (legacyOrganizationCredit < 0) - { - await stripeAdapter.CustomerBalanceTransactionCreate(provider.GatewayCustomerId, - new CustomerBalanceTransactionCreateOptions - { - Amount = legacyOrganizationCredit, - Currency = "USD", - Description = "1 year rebate for legacy client organizations." - }); - } - - logger.LogInformation("CB: Applied {Credit} credit to provider ({ProviderID})", organizationCancellationCredit + legacyOrganizationCredit, provider.Id); - - await migrationTrackerCache.UpdateTrackingStatus(provider.Id, ProviderMigrationProgress.CreditApplied); - } - - private async Task UpdateProviderAsync(Provider provider) - { - provider.Status = ProviderStatusType.Billable; - - await providerRepository.ReplaceAsync(provider); - - logger.LogInformation("CB: Completed migration for provider ({ProviderID})", provider.Id); - - await migrationTrackerCache.UpdateTrackingStatus(provider.Id, ProviderMigrationProgress.Completed); - } - - #endregion - - #region Utilities - - private async Task> GetClientsAsync(Guid providerId) - { - var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId); - - return (await Task.WhenAll(providerOrganizations.Select(providerOrganization => - organizationRepository.GetByIdAsync(providerOrganization.OrganizationId)))) - .ToList(); - } - - private async Task GetProviderAsync(Guid providerId) - { - var provider = await providerRepository.GetByIdAsync(providerId); - - if (provider == null) - { - logger.LogWarning("CB: Cannot migrate provider ({ProviderID}) as it does not exist", providerId); - - return null; - } - - if (provider.Type != ProviderType.Msp) - { - logger.LogWarning("CB: Cannot migrate provider ({ProviderID}) as it is not an MSP", providerId); - - return null; - } - - if (provider.Status == ProviderStatusType.Created) - { - return provider; - } - - logger.LogWarning("CB: Cannot migrate provider ({ProviderID}) as it is not in the 'Created' state", providerId); - - return null; - } - - private static bool IsEnterprise(Organization organization) => organization.Plan.Contains("Enterprise"); - private static bool IsTeams(Organization organization) => organization.Plan.Contains("Teams"); - - #endregion -} diff --git a/src/Core/Billing/Providers/Services/IProviderBillingService.cs b/src/Core/Billing/Providers/Services/IProviderBillingService.cs index 57d68db038..3f5a48e817 100644 --- a/src/Core/Billing/Providers/Services/IProviderBillingService.cs +++ b/src/Core/Billing/Providers/Services/IProviderBillingService.cs @@ -113,4 +113,11 @@ public interface IProviderBillingService TaxInformation taxInformation); Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command); + + /// + /// Updates the provider name and email on the Stripe customer entry. + /// This only updates Stripe, not the Bitwarden database. + /// + /// The provider to update in Stripe. + Task UpdateProviderNameAndEmail(Provider provider); } diff --git a/src/Core/Billing/Services/IStripeAdapter.cs b/src/Core/Billing/Services/IStripeAdapter.cs new file mode 100644 index 0000000000..5ec732920e --- /dev/null +++ b/src/Core/Billing/Services/IStripeAdapter.cs @@ -0,0 +1,50 @@ +// FIXME: Update this file to be null safe and then delete the line below +#nullable disable + +using Bit.Core.Models.BitStripe; +using Stripe; +using Stripe.Tax; + +namespace Bit.Core.Billing.Services; + +public interface IStripeAdapter +{ + Task CreateCustomerAsync(CustomerCreateOptions customerCreateOptions); + Task GetCustomerAsync(string id, CustomerGetOptions options = null); + Task UpdateCustomerAsync(string id, CustomerUpdateOptions options = null); + Task DeleteCustomerAsync(string id); + Task> ListCustomerPaymentMethodsAsync(string id, CustomerPaymentMethodListOptions options = null); + Task CreateCustomerBalanceTransactionAsync(string customerId, + CustomerBalanceTransactionCreateOptions options); + Task CreateSubscriptionAsync(SubscriptionCreateOptions subscriptionCreateOptions); + Task GetSubscriptionAsync(string id, SubscriptionGetOptions options = null); + Task> ListTaxRegistrationsAsync(RegistrationListOptions options = null); + Task DeleteCustomerDiscountAsync(string customerId, CustomerDeleteDiscountOptions options = null); + Task UpdateSubscriptionAsync(string id, SubscriptionUpdateOptions options = null); + Task CancelSubscriptionAsync(string id, SubscriptionCancelOptions options = null); + Task GetInvoiceAsync(string id, InvoiceGetOptions options); + Task> ListInvoicesAsync(StripeInvoiceListOptions options); + Task CreateInvoicePreviewAsync(InvoiceCreatePreviewOptions options); + Task> SearchInvoiceAsync(InvoiceSearchOptions options); + Task UpdateInvoiceAsync(string id, InvoiceUpdateOptions options); + Task FinalizeInvoiceAsync(string id, InvoiceFinalizeOptions options); + Task SendInvoiceAsync(string id, InvoiceSendOptions options); + Task PayInvoiceAsync(string id, InvoicePayOptions options = null); + Task DeleteInvoiceAsync(string id, InvoiceDeleteOptions options = null); + Task VoidInvoiceAsync(string id, InvoiceVoidOptions options = null); + IEnumerable ListPaymentMethodsAutoPaging(PaymentMethodListOptions options); + IAsyncEnumerable ListPaymentMethodsAutoPagingAsync(PaymentMethodListOptions options); + Task AttachPaymentMethodAsync(string id, PaymentMethodAttachOptions options = null); + Task DetachPaymentMethodAsync(string id, PaymentMethodDetachOptions options = null); + Task CreateTaxIdAsync(string id, TaxIdCreateOptions options); + Task DeleteTaxIdAsync(string customerId, string taxIdId, TaxIdDeleteOptions options = null); + Task> ListChargesAsync(ChargeListOptions options); + Task CreateRefundAsync(RefundCreateOptions options); + Task DeleteCardAsync(string customerId, string cardId, CardDeleteOptions options = null); + Task DeleteBankAccountAsync(string customerId, string bankAccount, BankAccountDeleteOptions options = null); + Task CreateSetupIntentAsync(SetupIntentCreateOptions options); + Task> ListSetupIntentsAsync(SetupIntentListOptions options); + Task CancelSetupIntentAsync(string id, SetupIntentCancelOptions options = null); + Task GetSetupIntentAsync(string id, SetupIntentGetOptions options = null); + Task GetPriceAsync(string id, PriceGetOptions options = null); +} diff --git a/src/Core/Services/IPaymentService.cs b/src/Core/Billing/Services/IStripePaymentService.cs similarity index 85% rename from src/Core/Services/IPaymentService.cs rename to src/Core/Billing/Services/IStripePaymentService.cs index e7e848bcba..b948cf6921 100644 --- a/src/Core/Services/IPaymentService.cs +++ b/src/Core/Billing/Services/IStripePaymentService.cs @@ -4,15 +4,13 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Models.Business; using Bit.Core.Billing.Models; -using Bit.Core.Billing.Tax.Requests; -using Bit.Core.Billing.Tax.Responses; using Bit.Core.Entities; using Bit.Core.Models.Business; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Services; +namespace Bit.Core.Billing.Services; -public interface IPaymentService +public interface IStripePaymentService { Task CancelAndRecoverChargesAsync(ISubscriber subscriber); Task SponsorOrganizationAsync(Organization org, OrganizationSponsorship sponsorship); @@ -44,8 +42,6 @@ public interface IPaymentService Task GetBillingAsync(ISubscriber subscriber); Task GetBillingHistoryAsync(ISubscriber subscriber); Task GetSubscriptionAsync(ISubscriber subscriber); - Task GetTaxInfoAsync(ISubscriber subscriber); - Task SaveTaxInfoAsync(ISubscriber subscriber, TaxInfo taxInfo); Task AddSecretsManagerToSubscription(Organization org, Plan plan, int additionalSmSeats, int additionalServiceAccount); /// /// Secrets Manager Standalone is a discount in Stripe that is used to give an organization access to Secrets Manager. @@ -68,7 +64,4 @@ public interface IPaymentService /// Organization Representation used for Inviting Organization Users /// If the organization has Secrets Manager and has the Standalone Stripe Discount Task HasSecretsManagerStandalone(InviteOrganization organization); - Task PreviewInvoiceAsync(PreviewIndividualInvoiceRequestBody parameters, string gatewayCustomerId, string gatewaySubscriptionId); - Task PreviewInvoiceAsync(PreviewOrganizationInvoiceRequestBody parameters, string gatewayCustomerId, string gatewaySubscriptionId); - } diff --git a/src/Core/Billing/Services/IStripeSyncService.cs b/src/Core/Billing/Services/IStripeSyncService.cs new file mode 100644 index 0000000000..b56204cd47 --- /dev/null +++ b/src/Core/Billing/Services/IStripeSyncService.cs @@ -0,0 +1,6 @@ +namespace Bit.Core.Billing.Services; + +public interface IStripeSyncService +{ + Task UpdateCustomerEmailAddressAsync(string gatewayCustomerId, string emailAddress); +} diff --git a/src/Core/Billing/Services/ISubscriberService.cs b/src/Core/Billing/Services/ISubscriberService.cs index f88727f37b..343a0e4f38 100644 --- a/src/Core/Billing/Services/ISubscriberService.cs +++ b/src/Core/Billing/Services/ISubscriberService.cs @@ -6,7 +6,6 @@ using Bit.Core.Billing.Tax.Models; using Bit.Core.Entities; using Bit.Core.Enums; using Stripe; -using PaymentMethod = Bit.Core.Billing.Models.PaymentMethod; namespace Bit.Core.Billing.Services; @@ -64,16 +63,6 @@ public interface ISubscriberService ISubscriber subscriber, CustomerGetOptions customerGetOptions = null); - /// - /// Retrieves the account credit, a masked representation of the default payment source and the tax information for the - /// provided . This is essentially a consolidated invocation of the - /// and methods with a response that includes the customer's as account credit in order to cut down on Stripe API calls. - /// - /// The subscriber to retrieve payment method for. - /// A containing the subscriber's account credit, payment source and tax information. - Task GetPaymentMethod( - ISubscriber subscriber); - /// /// Retrieves a masked representation of the subscriber's payment source for presentation to a client. /// @@ -107,16 +96,6 @@ public interface ISubscriberService ISubscriber subscriber, SubscriptionGetOptions subscriptionGetOptions = null); - /// - /// Retrieves the 's tax information using their Stripe 's . - /// - /// The subscriber to retrieve the tax information for. - /// A representing the 's tax information. - /// Thrown when the is . - /// This method opts for returning rather than throwing exceptions, making it ideal for surfacing data from API endpoints. - Task GetTaxInformation( - ISubscriber subscriber); - /// /// Attempts to remove a subscriber's saved payment source. If the Stripe representing the /// contains a valid "btCustomerId" key in its property, @@ -147,17 +126,6 @@ public interface ISubscriberService ISubscriber subscriber, TaxInformation taxInformation); - /// - /// Verifies the subscriber's pending bank account using the provided . - /// - /// The subscriber to verify the bank account for. - /// The code attached to a deposit made to the subscriber's bank account in order to ensure they have access to it. - /// Learn more. - /// - Task VerifyBankAccount( - ISubscriber subscriber, - string descriptorCode); - /// /// Validates whether the 's exists in the gateway. /// If the 's is or empty, returns . diff --git a/src/Core/Billing/Services/Implementations/PaymentHistoryService.cs b/src/Core/Billing/Services/Implementations/PaymentHistoryService.cs index 5a8cf16f5a..16b3f7e0c3 100644 --- a/src/Core/Billing/Services/Implementations/PaymentHistoryService.cs +++ b/src/Core/Billing/Services/Implementations/PaymentHistoryService.cs @@ -4,7 +4,6 @@ using Bit.Core.Billing.Models; using Bit.Core.Entities; using Bit.Core.Models.BitStripe; using Bit.Core.Repositories; -using Bit.Core.Services; namespace Bit.Core.Billing.Services.Implementations; @@ -23,7 +22,7 @@ public class PaymentHistoryService( return Array.Empty(); } - var invoices = await stripeAdapter.InvoiceListAsync(new StripeInvoiceListOptions + var invoices = await stripeAdapter.ListInvoicesAsync(new StripeInvoiceListOptions { Customer = subscriber.GatewayCustomerId, Limit = pageSize, diff --git a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs index 9db18278b6..9c85971dff 100644 --- a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs +++ b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs @@ -3,14 +3,15 @@ using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Sales; +using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Tax.Models; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Core.Settings; using Braintree; using Microsoft.Extensions.Logging; @@ -29,7 +30,8 @@ public class PremiumUserBillingService( ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, ISubscriberService subscriberService, - IUserRepository userRepository) : IPremiumUserBillingService + IUserRepository userRepository, + IPricingClient pricingClient) : IPremiumUserBillingService { public async Task Credit(User user, decimal amount) { @@ -65,7 +67,7 @@ public class PremiumUserBillingService( } }; - customer = await stripeAdapter.CustomerCreateAsync(options); + customer = await stripeAdapter.CreateCustomerAsync(options); user.Gateway = GatewayType.Stripe; user.GatewayCustomerId = customer.Id; @@ -78,7 +80,7 @@ public class PremiumUserBillingService( Balance = customer.Balance + credit }; - await stripeAdapter.CustomerUpdateAsync(customer.Id, options); + await stripeAdapter.UpdateCustomerAsync(customer.Id, options); } } @@ -98,7 +100,9 @@ public class PremiumUserBillingService( */ customer = await ReconcileBillingLocationAsync(customer, customerSetup.TaxInformation); - var subscription = await CreateSubscriptionAsync(user.Id, customer, storage); + var premiumPlan = await pricingClient.GetAvailablePremiumPlan(); + + var subscription = await CreateSubscriptionAsync(user.Id, customer, premiumPlan, storage); switch (customerSetup.TokenizedPaymentSource) { @@ -108,7 +112,7 @@ public class PremiumUserBillingService( when subscription.Status == StripeConstants.SubscriptionStatus.Active: { user.Premium = true; - user.PremiumExpirationDate = subscription.CurrentPeriodEnd; + user.PremiumExpirationDate = subscription.GetCurrentPeriodEnd(); break; } } @@ -116,6 +120,7 @@ public class PremiumUserBillingService( user.Gateway = GatewayType.Stripe; user.GatewayCustomerId = customer.Id; user.GatewaySubscriptionId = subscription.Id; + user.MaxStorageGb = (short)(premiumPlan.Storage.Provided + (storage ?? 0)); await userRepository.ReplaceAsync(user); } @@ -221,7 +226,7 @@ public class PremiumUserBillingService( case PaymentMethodType.BankAccount: { var setupIntent = - (await stripeAdapter.SetupIntentList(new SetupIntentListOptions { PaymentMethod = paymentMethodToken })) + (await stripeAdapter.ListSetupIntentsAsync(new SetupIntentListOptions { PaymentMethod = paymentMethodToken })) .FirstOrDefault(); if (setupIntent == null) @@ -254,7 +259,7 @@ public class PremiumUserBillingService( try { - return await stripeAdapter.CustomerCreateAsync(customerCreateOptions); + return await stripeAdapter.CreateCustomerAsync(customerCreateOptions); } catch (StripeException stripeException) when (stripeException.StripeError?.Code == StripeConstants.ErrorCodes.CustomerTaxLocationInvalid) @@ -298,13 +303,15 @@ public class PremiumUserBillingService( private async Task CreateSubscriptionAsync( Guid userId, Customer customer, + Pricing.Premium.Plan premiumPlan, int? storage) { + var subscriptionItemOptionsList = new List { new () { - Price = StripeConstants.Prices.PremiumAnnually, + Price = premiumPlan.Seat.StripePriceId, Quantity = 1 } }; @@ -313,7 +320,7 @@ public class PremiumUserBillingService( { subscriptionItemOptionsList.Add(new SubscriptionItemOptions { - Price = StripeConstants.Prices.StoragePlanPersonal, + Price = premiumPlan.Storage.StripePriceId, Quantity = storage }); } @@ -339,11 +346,11 @@ public class PremiumUserBillingService( OffSession = true }; - var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); + var subscription = await stripeAdapter.CreateSubscriptionAsync(subscriptionCreateOptions); if (usingPayPal) { - await stripeAdapter.InvoiceUpdateAsync(subscription.LatestInvoiceId, new InvoiceUpdateOptions + await stripeAdapter.UpdateInvoiceAsync(subscription.LatestInvoiceId, new InvoiceUpdateOptions { AutoAdvance = false }); @@ -379,6 +386,6 @@ public class PremiumUserBillingService( } }; - return await stripeAdapter.CustomerUpdateAsync(customer.Id, options); + return await stripeAdapter.UpdateCustomerAsync(customer.Id, options); } } diff --git a/src/Core/Billing/Services/Implementations/StripeAdapter.cs b/src/Core/Billing/Services/Implementations/StripeAdapter.cs new file mode 100644 index 0000000000..cdc7645042 --- /dev/null +++ b/src/Core/Billing/Services/Implementations/StripeAdapter.cs @@ -0,0 +1,209 @@ +// FIXME: Update this file to be null safe and then delete the line below + +#nullable disable + +using Bit.Core.Models.BitStripe; +using Stripe; +using Stripe.Tax; +using Stripe.TestHelpers; +using CustomerService = Stripe.CustomerService; +using RefundService = Stripe.RefundService; + +namespace Bit.Core.Billing.Services.Implementations; + +public class StripeAdapter : IStripeAdapter +{ + private readonly CustomerService _customerService; + private readonly SubscriptionService _subscriptionService; + private readonly InvoiceService _invoiceService; + private readonly PaymentMethodService _paymentMethodService; + private readonly TaxIdService _taxIdService; + private readonly ChargeService _chargeService; + private readonly RefundService _refundService; + private readonly CardService _cardService; + private readonly BankAccountService _bankAccountService; + private readonly PriceService _priceService; + private readonly SetupIntentService _setupIntentService; + private readonly TestClockService _testClockService; + private readonly CustomerBalanceTransactionService _customerBalanceTransactionService; + private readonly RegistrationService _taxRegistrationService; + + public StripeAdapter() + { + _customerService = new CustomerService(); + _subscriptionService = new SubscriptionService(); + _invoiceService = new InvoiceService(); + _paymentMethodService = new PaymentMethodService(); + _taxIdService = new TaxIdService(); + _chargeService = new ChargeService(); + _refundService = new RefundService(); + _cardService = new CardService(); + _bankAccountService = new BankAccountService(); + _priceService = new PriceService(); + _setupIntentService = new SetupIntentService(); + _testClockService = new TestClockService(); + _customerBalanceTransactionService = new CustomerBalanceTransactionService(); + _taxRegistrationService = new RegistrationService(); + } + + /************** + ** CUSTOMER ** + **************/ + public Task CreateCustomerAsync(CustomerCreateOptions options) => + _customerService.CreateAsync(options); + + public Task DeleteCustomerDiscountAsync(string customerId, CustomerDeleteDiscountOptions options = null) => + _customerService.DeleteDiscountAsync(customerId, options); + + public Task GetCustomerAsync(string id, CustomerGetOptions options = null) => + _customerService.GetAsync(id, options); + + public Task UpdateCustomerAsync(string id, CustomerUpdateOptions options = null) => + _customerService.UpdateAsync(id, options); + + public Task DeleteCustomerAsync(string id) => + _customerService.DeleteAsync(id); + + public async Task> ListCustomerPaymentMethodsAsync(string id, + CustomerPaymentMethodListOptions options = null) + { + var paymentMethods = await _customerService.ListPaymentMethodsAsync(id, options); + return paymentMethods.Data; + } + + public Task CreateCustomerBalanceTransactionAsync(string customerId, + CustomerBalanceTransactionCreateOptions options) => + _customerBalanceTransactionService.CreateAsync(customerId, options); + + /****************** + ** SUBSCRIPTION ** + ******************/ + public Task CreateSubscriptionAsync(SubscriptionCreateOptions options) => + _subscriptionService.CreateAsync(options); + + public Task GetSubscriptionAsync(string id, SubscriptionGetOptions options = null) => + _subscriptionService.GetAsync(id, options); + + public Task UpdateSubscriptionAsync(string id, + SubscriptionUpdateOptions options = null) => + _subscriptionService.UpdateAsync(id, options); + + public Task CancelSubscriptionAsync(string id, SubscriptionCancelOptions options = null) => + _subscriptionService.CancelAsync(id, options); + + /************* + ** INVOICE ** + *************/ + public Task GetInvoiceAsync(string id, InvoiceGetOptions options) => + _invoiceService.GetAsync(id, options); + + public async Task> ListInvoicesAsync(StripeInvoiceListOptions options) + { + if (!options.SelectAll) + { + return (await _invoiceService.ListAsync(options.ToInvoiceListOptions())).Data; + } + + options.Limit = 100; + + var invoices = new List(); + + await foreach (var invoice in _invoiceService.ListAutoPagingAsync(options.ToInvoiceListOptions())) + { + invoices.Add(invoice); + } + + return invoices; + } + + public Task CreateInvoicePreviewAsync(InvoiceCreatePreviewOptions options) => + _invoiceService.CreatePreviewAsync(options); + + public async Task> SearchInvoiceAsync(InvoiceSearchOptions options) => + (await _invoiceService.SearchAsync(options)).Data; + + public Task UpdateInvoiceAsync(string id, InvoiceUpdateOptions options) => + _invoiceService.UpdateAsync(id, options); + + public Task FinalizeInvoiceAsync(string id, InvoiceFinalizeOptions options) => + _invoiceService.FinalizeInvoiceAsync(id, options); + + public Task SendInvoiceAsync(string id, InvoiceSendOptions options) => + _invoiceService.SendInvoiceAsync(id, options); + + public Task PayInvoiceAsync(string id, InvoicePayOptions options = null) => + _invoiceService.PayAsync(id, options); + + public Task DeleteInvoiceAsync(string id, InvoiceDeleteOptions options = null) => + _invoiceService.DeleteAsync(id, options); + + public Task VoidInvoiceAsync(string id, InvoiceVoidOptions options = null) => + _invoiceService.VoidInvoiceAsync(id, options); + + /******************** + ** PAYMENT METHOD ** + ********************/ + public IEnumerable ListPaymentMethodsAutoPaging(PaymentMethodListOptions options) => + _paymentMethodService.ListAutoPaging(options); + + public IAsyncEnumerable ListPaymentMethodsAutoPagingAsync(PaymentMethodListOptions options) + => _paymentMethodService.ListAutoPagingAsync(options); + + public Task AttachPaymentMethodAsync(string id, PaymentMethodAttachOptions options = null) => + _paymentMethodService.AttachAsync(id, options); + + public Task DetachPaymentMethodAsync(string id, PaymentMethodDetachOptions options = null) => + _paymentMethodService.DetachAsync(id, options); + + /************ + ** TAX ID ** + ************/ + public Task CreateTaxIdAsync(string id, TaxIdCreateOptions options) => + _taxIdService.CreateAsync(id, options); + + public Task DeleteTaxIdAsync(string customerId, string taxIdId, + TaxIdDeleteOptions options = null) => + _taxIdService.DeleteAsync(customerId, taxIdId, options); + + /****************** + ** BANK ACCOUNT ** + ******************/ + public Task DeleteBankAccountAsync(string customerId, string bankAccount, BankAccountDeleteOptions options = null) => + _bankAccountService.DeleteAsync(customerId, bankAccount, options); + + /*********** + ** PRICE ** + ***********/ + public Task GetPriceAsync(string id, PriceGetOptions options = null) => + _priceService.GetAsync(id, options); + + /****************** + ** SETUP INTENT ** + ******************/ + public Task CreateSetupIntentAsync(SetupIntentCreateOptions options) => + _setupIntentService.CreateAsync(options); + + public async Task> ListSetupIntentsAsync(SetupIntentListOptions options) => + (await _setupIntentService.ListAsync(options)).Data; + + public Task CancelSetupIntentAsync(string id, SetupIntentCancelOptions options = null) => + _setupIntentService.CancelAsync(id, options); + + public Task GetSetupIntentAsync(string id, SetupIntentGetOptions options = null) => + _setupIntentService.GetAsync(id, options); + + /******************* + ** MISCELLANEOUS ** + *******************/ + public Task> ListChargesAsync(ChargeListOptions options) => + _chargeService.ListAsync(options); + + public Task> ListTaxRegistrationsAsync(RegistrationListOptions options = null) => + _taxRegistrationService.ListAsync(options); + + public Task CreateRefundAsync(RefundCreateOptions options) => + _refundService.CreateAsync(options); + + public Task DeleteCardAsync(string customerId, string cardId, CardDeleteOptions options = null) => + _cardService.DeleteAsync(customerId, cardId, options); +} diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Billing/Services/Implementations/StripePaymentService.cs similarity index 55% rename from src/Core/Services/Implementations/StripePaymentService.cs rename to src/Core/Billing/Services/Implementations/StripePaymentService.cs index 5b68906d8a..ffc18aa748 100644 --- a/src/Core/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Billing/Services/Implementations/StripePaymentService.cs @@ -9,9 +9,6 @@ using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Pricing; -using Bit.Core.Billing.Tax.Requests; -using Bit.Core.Billing.Tax.Responses; -using Bit.Core.Billing.Tax.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -24,9 +21,9 @@ using Stripe; using PaymentMethod = Stripe.PaymentMethod; using StaticStore = Bit.Core.Models.StaticStore; -namespace Bit.Core.Services; +namespace Bit.Core.Billing.Services.Implementations; -public class StripePaymentService : IPaymentService +public class StripePaymentService : IStripePaymentService { private const string SecretsManagerStandaloneDiscountId = "sm-standalone"; @@ -35,8 +32,6 @@ public class StripePaymentService : IPaymentService private readonly Braintree.IBraintreeGateway _btGateway; private readonly IStripeAdapter _stripeAdapter; private readonly IGlobalSettings _globalSettings; - private readonly IFeatureService _featureService; - private readonly ITaxService _taxService; private readonly IPricingClient _pricingClient; public StripePaymentService( @@ -45,8 +40,6 @@ public class StripePaymentService : IPaymentService IStripeAdapter stripeAdapter, Braintree.IBraintreeGateway braintreeGateway, IGlobalSettings globalSettings, - IFeatureService featureService, - ITaxService taxService, IPricingClient pricingClient) { _transactionRepository = transactionRepository; @@ -54,8 +47,6 @@ public class StripePaymentService : IPaymentService _stripeAdapter = stripeAdapter; _btGateway = braintreeGateway; _globalSettings = globalSettings; - _featureService = featureService; - _taxService = taxService; _pricingClient = pricingClient; } @@ -65,19 +56,20 @@ public class StripePaymentService : IPaymentService bool applySponsorship) { var existingPlan = await _pricingClient.GetPlanOrThrow(org.PlanType); - var sponsoredPlan = sponsorship?.PlanSponsorshipType != null ? - Utilities.StaticStore.GetSponsoredPlan(sponsorship.PlanSponsorshipType.Value) : - null; - var subscriptionUpdate = new SponsorOrganizationSubscriptionUpdate(existingPlan, sponsoredPlan, applySponsorship); + var sponsoredPlan = sponsorship?.PlanSponsorshipType != null + ? SponsoredPlans.Get(sponsorship.PlanSponsorshipType.Value) + : null; + var subscriptionUpdate = + new SponsorOrganizationSubscriptionUpdate(existingPlan, sponsoredPlan, applySponsorship); await FinalizeSubscriptionChangeAsync(org, subscriptionUpdate, true); - var sub = await _stripeAdapter.SubscriptionGetAsync(org.GatewaySubscriptionId); - org.ExpirationDate = sub.CurrentPeriodEnd; + var sub = await _stripeAdapter.GetSubscriptionAsync(org.GatewaySubscriptionId); + org.ExpirationDate = sub.GetCurrentPeriodEnd(); if (sponsorship is not null) { - sponsorship.ValidUntil = sub.CurrentPeriodEnd; + sponsorship.ValidUntil = sub.GetCurrentPeriodEnd(); } } @@ -92,7 +84,7 @@ public class StripePaymentService : IPaymentService { // remember, when in doubt, throw var subGetOptions = new SubscriptionGetOptions { Expand = ["customer.tax", "customer.tax_ids"] }; - var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subGetOptions); + var sub = await _stripeAdapter.GetSubscriptionAsync(subscriber.GatewaySubscriptionId, subGetOptions); if (sub == null) { throw new GatewayException("Subscription not found."); @@ -100,7 +92,8 @@ public class StripePaymentService : IPaymentService if (sub.Status == SubscriptionStatuses.Canceled) { - throw new BadRequestException("You do not have an active subscription. Reinstate your subscription to make changes."); + throw new BadRequestException( + "You do not have an active subscription. Reinstate your subscription to make changes."); } var existingCoupon = sub.Customer.Discount?.Coupon?.Id; @@ -114,7 +107,7 @@ public class StripePaymentService : IPaymentService var subUpdateOptions = new SubscriptionUpdateOptions { Items = updatedItemOptions, - ProrationBehavior = invoiceNow ? Constants.AlwaysInvoice : Constants.CreateProrations, + ProrationBehavior = invoiceNow ? Core.Constants.AlwaysInvoice : Core.Constants.CreateProrations, DaysUntilDue = daysUntilDue ?? 1, CollectionMethod = "send_invoice" }; @@ -128,11 +121,11 @@ public class StripePaymentService : IPaymentService { if (sub.Customer is { - Address.Country: not Constants.CountryAbbreviations.UnitedStates, + Address.Country: not Core.Constants.CountryAbbreviations.UnitedStates, TaxExempt: not StripeConstants.TaxExempt.Reverse }) { - await _stripeAdapter.CustomerUpdateAsync(sub.CustomerId, + await _stripeAdapter.UpdateCustomerAsync(sub.CustomerId, new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse }); } @@ -148,9 +141,9 @@ public class StripePaymentService : IPaymentService string paymentIntentClientSecret = null; try { - var subResponse = await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, subUpdateOptions); + var subResponse = await _stripeAdapter.UpdateSubscriptionAsync(sub.Id, subUpdateOptions); - var invoice = await _stripeAdapter.InvoiceGetAsync(subResponse?.LatestInvoiceId, new InvoiceGetOptions()); + var invoice = await _stripeAdapter.GetInvoiceAsync(subResponse?.LatestInvoiceId, new InvoiceGetOptions()); if (invoice == null) { throw new BadRequestException("Unable to locate draft invoice for subscription update."); @@ -169,9 +162,9 @@ public class StripePaymentService : IPaymentService } else { - invoice = await _stripeAdapter.InvoiceFinalizeInvoiceAsync(subResponse.LatestInvoiceId, + invoice = await _stripeAdapter.FinalizeInvoiceAsync(subResponse.LatestInvoiceId, new InvoiceFinalizeOptions { AutoAdvance = false, }); - await _stripeAdapter.InvoiceSendInvoiceAsync(invoice.Id, new InvoiceSendOptions()); + await _stripeAdapter.SendInvoiceAsync(invoice.Id, new InvoiceSendOptions()); paymentIntentClientSecret = null; } } @@ -179,7 +172,7 @@ public class StripePaymentService : IPaymentService catch { // Need to revert the subscription - await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, new SubscriptionUpdateOptions + await _stripeAdapter.UpdateSubscriptionAsync(sub.Id, new SubscriptionUpdateOptions { Items = subscriptionUpdate.RevertItemsOptions(sub), // This proration behavior prevents a false "credit" from @@ -191,36 +184,42 @@ public class StripePaymentService : IPaymentService throw; } } - else if (!invoice.Paid) + else if (invoice.Status != StripeConstants.InvoiceStatus.Paid) { // Pay invoice with no charge to the customer this completes the invoice immediately without waiting the scheduled 1h - invoice = await _stripeAdapter.InvoicePayAsync(subResponse.LatestInvoiceId); + invoice = await _stripeAdapter.PayInvoiceAsync(subResponse.LatestInvoiceId); paymentIntentClientSecret = null; } - } finally { // Change back the subscription collection method and/or days until due if (collectionMethod != "send_invoice" || daysUntilDue == null) { - await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, new SubscriptionUpdateOptions - { - CollectionMethod = collectionMethod, - DaysUntilDue = daysUntilDue, - }); + await _stripeAdapter.UpdateSubscriptionAsync(sub.Id, + new SubscriptionUpdateOptions + { + CollectionMethod = collectionMethod, + DaysUntilDue = daysUntilDue, + }); } - var customer = await _stripeAdapter.CustomerGetAsync(sub.CustomerId); + var customer = await _stripeAdapter.GetCustomerAsync(sub.CustomerId); var newCoupon = customer.Discount?.Coupon?.Id; if (!string.IsNullOrEmpty(existingCoupon) && string.IsNullOrEmpty(newCoupon)) { // Re-add the lost coupon due to the update. - await _stripeAdapter.CustomerUpdateAsync(sub.CustomerId, new CustomerUpdateOptions + await _stripeAdapter.UpdateSubscriptionAsync(sub.Id, new SubscriptionUpdateOptions { - Coupon = existingCoupon + Discounts = + [ + new SubscriptionDiscountOptions + { + Coupon = existingCoupon + } + ] }); } } @@ -285,7 +284,7 @@ public class StripePaymentService : IPaymentService { if (!string.IsNullOrWhiteSpace(subscriber.GatewaySubscriptionId)) { - await _stripeAdapter.SubscriptionCancelAsync(subscriber.GatewaySubscriptionId, + await _stripeAdapter.CancelSubscriptionAsync(subscriber.GatewaySubscriptionId, new SubscriptionCancelOptions()); } @@ -294,7 +293,7 @@ public class StripePaymentService : IPaymentService return; } - var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); + var customer = await _stripeAdapter.GetCustomerAsync(subscriber.GatewayCustomerId); if (customer == null) { return; @@ -319,7 +318,7 @@ public class StripePaymentService : IPaymentService } else { - var charges = await _stripeAdapter.ChargeListAsync(new ChargeListOptions + var charges = await _stripeAdapter.ListChargesAsync(new ChargeListOptions { Customer = subscriber.GatewayCustomerId }); @@ -328,12 +327,12 @@ public class StripePaymentService : IPaymentService { foreach (var charge in charges.Data.Where(c => c.Captured && !c.Refunded)) { - await _stripeAdapter.RefundCreateAsync(new RefundCreateOptions { Charge = charge.Id }); + await _stripeAdapter.CreateRefundAsync(new RefundCreateOptions { Charge = charge.Id }); } } } - await _stripeAdapter.CustomerDeleteAsync(subscriber.GatewayCustomerId); + await _stripeAdapter.DeleteCustomerAsync(subscriber.GatewayCustomerId); } public async Task PayInvoiceAfterSubscriptionChangeAsync(ISubscriber subscriber, Invoice invoice) @@ -341,7 +340,7 @@ public class StripePaymentService : IPaymentService var customerOptions = new CustomerGetOptions(); customerOptions.AddExpand("default_source"); customerOptions.AddExpand("invoice_settings.default_payment_method"); - var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerOptions); + var customer = await _stripeAdapter.GetCustomerAsync(subscriber.GatewayCustomerId, customerOptions); string paymentIntentClientSecret = null; @@ -352,7 +351,7 @@ public class StripePaymentService : IPaymentService { var hasDefaultCardPaymentMethod = customer.InvoiceSettings?.DefaultPaymentMethod?.Type == "card"; var hasDefaultValidSource = customer.DefaultSource != null && - (customer.DefaultSource is Card || customer.DefaultSource is BankAccount); + (customer.DefaultSource is Card || customer.DefaultSource is BankAccount); if (!hasDefaultCardPaymentMethod && !hasDefaultValidSource) { cardPaymentMethodId = GetLatestCardPaymentMethod(customer.Id)?.Id; @@ -361,16 +360,15 @@ public class StripePaymentService : IPaymentService // We're going to delete this draft invoice, it can't be paid try { - await _stripeAdapter.InvoiceDeleteAsync(invoice.Id); + await _stripeAdapter.DeleteInvoiceAsync(invoice.Id); } catch { - await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id, new InvoiceFinalizeOptions - { - AutoAdvance = false - }); - await _stripeAdapter.InvoiceVoidInvoiceAsync(invoice.Id); + await _stripeAdapter.FinalizeInvoiceAsync(invoice.Id, + new InvoiceFinalizeOptions { AutoAdvance = false }); + await _stripeAdapter.VoidInvoiceAsync(invoice.Id); } + throw new BadRequestException("No payment method is available."); } } @@ -381,14 +379,9 @@ public class StripePaymentService : IPaymentService { // Finalize the invoice (from Draft) w/o auto-advance so we // can attempt payment manually. - invoice = await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id, new InvoiceFinalizeOptions - { - AutoAdvance = false, - }); - var invoicePayOptions = new InvoicePayOptions - { - PaymentMethod = cardPaymentMethodId, - }; + invoice = await _stripeAdapter.FinalizeInvoiceAsync(invoice.Id, + new InvoiceFinalizeOptions { AutoAdvance = false, }); + var invoicePayOptions = new InvoicePayOptions { PaymentMethod = cardPaymentMethodId, }; if (customer?.Metadata?.ContainsKey("btCustomerId") ?? false) { invoicePayOptions.PaidOutOfBand = true; @@ -403,13 +396,15 @@ public class StripePaymentService : IPaymentService SubmitForSettlement = true, PayPal = new Braintree.TransactionOptionsPayPalRequest { - CustomField = $"{subscriber.BraintreeIdField()}:{subscriber.Id},{subscriber.BraintreeCloudRegionField()}:{_globalSettings.BaseServiceUri.CloudRegion}" + CustomField = + $"{subscriber.BraintreeIdField()}:{subscriber.Id},{subscriber.BraintreeCloudRegionField()}:{_globalSettings.BaseServiceUri.CloudRegion}" } }, CustomFields = new Dictionary { [subscriber.BraintreeIdField()] = subscriber.Id.ToString(), - [subscriber.BraintreeCloudRegionField()] = _globalSettings.BaseServiceUri.CloudRegion + [subscriber.BraintreeCloudRegionField()] = + _globalSettings.BaseServiceUri.CloudRegion } }); @@ -419,7 +414,7 @@ public class StripePaymentService : IPaymentService } braintreeTransaction = transactionResult.Target; - invoice = await _stripeAdapter.InvoiceUpdateAsync(invoice.Id, new InvoiceUpdateOptions + invoice = await _stripeAdapter.UpdateInvoiceAsync(invoice.Id, new InvoiceUpdateOptions { Metadata = new Dictionary { @@ -433,7 +428,7 @@ public class StripePaymentService : IPaymentService try { - invoice = await _stripeAdapter.InvoicePayAsync(invoice.Id, invoicePayOptions); + invoice = await _stripeAdapter.PayInvoiceAsync(invoice.Id, invoicePayOptions); } catch (StripeException e) { @@ -442,9 +437,9 @@ public class StripePaymentService : IPaymentService { // SCA required, get intent client secret var invoiceGetOptions = new InvoiceGetOptions(); - invoiceGetOptions.AddExpand("payment_intent"); - invoice = await _stripeAdapter.InvoiceGetAsync(invoice.Id, invoiceGetOptions); - paymentIntentClientSecret = invoice?.PaymentIntent?.ClientSecret; + invoiceGetOptions.AddExpand("confirmation_secret"); + invoice = await _stripeAdapter.GetInvoiceAsync(invoice.Id, invoiceGetOptions); + paymentIntentClientSecret = invoice?.ConfirmationSecret?.ClientSecret; } else { @@ -458,6 +453,7 @@ public class StripePaymentService : IPaymentService { await _btGateway.Transaction.RefundAsync(braintreeTransaction.Id); } + if (invoice != null) { if (invoice.Status == "paid") @@ -466,7 +462,7 @@ public class StripePaymentService : IPaymentService return paymentIntentClientSecret; } - invoice = await _stripeAdapter.InvoiceVoidInvoiceAsync(invoice.Id, new InvoiceVoidOptions()); + invoice = await _stripeAdapter.VoidInvoiceAsync(invoice.Id, new InvoiceVoidOptions()); // HACK: Workaround for customer balance credit if (invoice.StartingBalance < 0) @@ -474,15 +470,13 @@ public class StripePaymentService : IPaymentService // Customer had a balance applied to this invoice. Since we can't fully trust Stripe to // credit it back to the customer (even though their docs claim they will), we need to // check that balance against the current customer balance and determine if it needs to be re-applied - customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerOptions); + customer = await _stripeAdapter.GetCustomerAsync(subscriber.GatewayCustomerId, customerOptions); // Assumption: Customer balance should now be $0, otherwise payment would not have failed. if (customer.Balance == 0) { - await _stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions - { - Balance = invoice.StartingBalance - }); + await _stripeAdapter.UpdateCustomerAsync(customer.Id, + new CustomerUpdateOptions { Balance = invoice.StartingBalance }); } } } @@ -496,6 +490,7 @@ public class StripePaymentService : IPaymentService // Let the caller perform any subscription change cleanup throw; } + return paymentIntentClientSecret; } @@ -511,7 +506,7 @@ public class StripePaymentService : IPaymentService throw new GatewayException("No subscription."); } - var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); + var sub = await _stripeAdapter.GetSubscriptionAsync(subscriber.GatewaySubscriptionId); if (sub == null) { throw new GatewayException("Subscription was not found."); @@ -526,10 +521,10 @@ public class StripePaymentService : IPaymentService try { - var canceledSub = endOfPeriod ? - await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, - new SubscriptionUpdateOptions { CancelAtPeriodEnd = true }) : - await _stripeAdapter.SubscriptionCancelAsync(sub.Id, new SubscriptionCancelOptions()); + var canceledSub = endOfPeriod + ? await _stripeAdapter.UpdateSubscriptionAsync(sub.Id, + new SubscriptionUpdateOptions { CancelAtPeriodEnd = true }) + : await _stripeAdapter.CancelSubscriptionAsync(sub.Id, new SubscriptionCancelOptions()); if (!canceledSub.CanceledAt.HasValue) { throw new GatewayException("Unable to cancel subscription."); @@ -556,7 +551,7 @@ public class StripePaymentService : IPaymentService throw new GatewayException("No subscription."); } - var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); + var sub = await _stripeAdapter.GetSubscriptionAsync(subscriber.GatewaySubscriptionId); if (sub == null) { throw new GatewayException("Subscription was not found."); @@ -568,7 +563,7 @@ public class StripePaymentService : IPaymentService throw new GatewayException("Subscription is not marked for cancellation."); } - var updatedSub = await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, + var updatedSub = await _stripeAdapter.UpdateSubscriptionAsync(sub.Id, new SubscriptionUpdateOptions { CancelAtPeriodEnd = false }); if (updatedSub.CanceledAt.HasValue) { @@ -580,14 +575,14 @@ public class StripePaymentService : IPaymentService { Customer customer = null; var customerExists = subscriber.Gateway == GatewayType.Stripe && - !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId); + !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId); if (customerExists) { - customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); + customer = await _stripeAdapter.GetCustomerAsync(subscriber.GatewayCustomerId); } else { - customer = await _stripeAdapter.CustomerCreateAsync(new CustomerCreateOptions + customer = await _stripeAdapter.CreateCustomerAsync(new CustomerCreateOptions { Email = subscriber.BillingEmailAddress(), Description = subscriber.BillingName(), @@ -595,10 +590,9 @@ public class StripePaymentService : IPaymentService subscriber.Gateway = GatewayType.Stripe; subscriber.GatewayCustomerId = customer.Id; } - await _stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions - { - Balance = customer.Balance - (long)(creditAmount * 100) - }); + + await _stripeAdapter.UpdateCustomerAsync(customer.Id, + new CustomerUpdateOptions { Balance = customer.Balance - (long)(creditAmount * 100) }); return !customerExists; } @@ -630,50 +624,57 @@ public class StripePaymentService : IPaymentService { var subscriptionInfo = new SubscriptionInfo(); - if (!string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) - { - var customerGetOptions = new CustomerGetOptions(); - customerGetOptions.AddExpand("discount.coupon.applies_to"); - var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerGetOptions); - - if (customer.Discount != null) - { - subscriptionInfo.CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount(customer.Discount); - } - } - - if (string.IsNullOrWhiteSpace(subscriber.GatewaySubscriptionId)) + if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) { return subscriptionInfo; } - var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, new SubscriptionGetOptions + var subscription = await _stripeAdapter.GetSubscriptionAsync(subscriber.GatewaySubscriptionId, + new SubscriptionGetOptions { Expand = ["customer.discount.coupon.applies_to", "discounts.coupon.applies_to", "test_clock"] }); + + if (subscription == null) { - Expand = ["test_clock"] - }); - - if (sub != null) - { - subscriptionInfo.Subscription = new SubscriptionInfo.BillingSubscription(sub); - - var (suspensionDate, unpaidPeriodEndDate) = await GetSuspensionDateAsync(sub); - - if (suspensionDate.HasValue && unpaidPeriodEndDate.HasValue) - { - subscriptionInfo.Subscription.SuspensionDate = suspensionDate; - subscriptionInfo.Subscription.UnpaidPeriodEndDate = unpaidPeriodEndDate; - } + return subscriptionInfo; } - if (sub is { CanceledAt: not null } || string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) + subscriptionInfo.Subscription = new SubscriptionInfo.BillingSubscription(subscription); + + // Discount selection priority: + // 1. Customer-level discount (applies to all subscriptions for the customer) + // 2. First subscription-level discount (if multiple exist, FirstOrDefault() selects the first one) + // Note: When multiple subscription-level discounts exist, only the first one is used. + // This matches Stripe's behavior where the first discount in the list is applied. + // Defensive null checks: Even though we expand "customer" and "discounts", external APIs + // may not always return the expected data structure, so we use null-safe operators. + var discount = subscription.Customer?.Discount ?? subscription.Discounts?.FirstOrDefault(); + + if (discount != null) + { + subscriptionInfo.CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount(discount); + } + + var (suspensionDate, unpaidPeriodEndDate) = await GetSuspensionDateAsync(subscription); + + if (suspensionDate.HasValue && unpaidPeriodEndDate.HasValue) + { + subscriptionInfo.Subscription.SuspensionDate = suspensionDate; + subscriptionInfo.Subscription.UnpaidPeriodEndDate = unpaidPeriodEndDate; + } + + if (subscription is { CanceledAt: not null } || string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) { return subscriptionInfo; } try { - var upcomingInvoiceOptions = new UpcomingInvoiceOptions { Customer = subscriber.GatewayCustomerId }; - var upcomingInvoice = await _stripeAdapter.InvoiceUpcomingAsync(upcomingInvoiceOptions); + var invoiceCreatePreviewOptions = new InvoiceCreatePreviewOptions + { + Customer = subscriber.GatewayCustomerId, + Subscription = subscriber.GatewaySubscriptionId + }; + + var upcomingInvoice = await _stripeAdapter.CreateInvoicePreviewAsync(invoiceCreatePreviewOptions); if (upcomingInvoice != null) { @@ -682,135 +683,17 @@ public class StripePaymentService : IPaymentService } catch (StripeException ex) { - _logger.LogWarning(ex, "Encountered an unexpected Stripe error"); + _logger.LogWarning( + ex, + "Failed to retrieve upcoming invoice for customer {CustomerId}, subscription {SubscriptionId}. Error Code: {ErrorCode}", + subscriber.GatewayCustomerId, + subscriber.GatewaySubscriptionId, + ex.StripeError?.Code); } return subscriptionInfo; } - public async Task GetTaxInfoAsync(ISubscriber subscriber) - { - if (subscriber == null || string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) - { - return null; - } - - var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, - new CustomerGetOptions { Expand = ["tax_ids"] }); - - if (customer == null) - { - return null; - } - - var address = customer.Address; - var taxId = customer.TaxIds?.FirstOrDefault(); - - // Line1 is required, so if missing we're using the subscriber name, - // see: https://stripe.com/docs/api/customers/create#create_customer-address-line1 - if (address != null && string.IsNullOrWhiteSpace(address.Line1)) - { - address.Line1 = null; - } - - return new TaxInfo - { - TaxIdNumber = taxId?.Value, - TaxIdType = taxId?.Type, - BillingAddressLine1 = address?.Line1, - BillingAddressLine2 = address?.Line2, - BillingAddressCity = address?.City, - BillingAddressState = address?.State, - BillingAddressPostalCode = address?.PostalCode, - BillingAddressCountry = address?.Country, - }; - } - - public async Task SaveTaxInfoAsync(ISubscriber subscriber, TaxInfo taxInfo) - { - if (string.IsNullOrWhiteSpace(subscriber?.GatewayCustomerId) || subscriber.IsUser()) - { - return; - } - - var customer = await _stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, - new CustomerUpdateOptions - { - Address = new AddressOptions - { - Line1 = taxInfo.BillingAddressLine1 ?? string.Empty, - Line2 = taxInfo.BillingAddressLine2, - City = taxInfo.BillingAddressCity, - State = taxInfo.BillingAddressState, - PostalCode = taxInfo.BillingAddressPostalCode, - Country = taxInfo.BillingAddressCountry, - }, - Expand = ["tax_ids"] - }); - - if (customer == null) - { - return; - } - - var taxId = customer.TaxIds?.FirstOrDefault(); - - if (taxId != null) - { - await _stripeAdapter.TaxIdDeleteAsync(customer.Id, taxId.Id); - } - - if (string.IsNullOrWhiteSpace(taxInfo.TaxIdNumber)) - { - return; - } - - var taxIdType = taxInfo.TaxIdType; - - if (string.IsNullOrWhiteSpace(taxIdType)) - { - taxIdType = _taxService.GetStripeTaxCode(taxInfo.BillingAddressCountry, taxInfo.TaxIdNumber); - - if (taxIdType == null) - { - _logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.", - taxInfo.BillingAddressCountry, - taxInfo.TaxIdNumber); - throw new BadRequestException("billingTaxIdTypeInferenceError"); - } - } - - try - { - await _stripeAdapter.TaxIdCreateAsync(customer.Id, - new TaxIdCreateOptions { Type = taxInfo.TaxIdType, Value = taxInfo.TaxIdNumber }); - - if (taxInfo.TaxIdType == StripeConstants.TaxIdType.SpanishNIF) - { - await _stripeAdapter.TaxIdCreateAsync(customer.Id, - new TaxIdCreateOptions { Type = StripeConstants.TaxIdType.EUVAT, Value = $"ES{taxInfo.TaxIdNumber}" }); - } - } - catch (StripeException e) - { - switch (e.StripeError.Code) - { - case StripeConstants.ErrorCodes.TaxIdInvalid: - _logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", - taxInfo.TaxIdNumber, - taxInfo.BillingAddressCountry); - throw new BadRequestException("billingInvalidTaxIdError"); - default: - _logger.LogError(e, - "Error creating tax ID '{TaxId}' in country '{Country}' for customer '{CustomerID}'.", - taxInfo.TaxIdNumber, - taxInfo.BillingAddressCountry, - customer.Id); - throw new BadRequestException("billingTaxIdCreationError"); - } - } - } - public async Task AddSecretsManagerToSubscription( Organization org, StaticStore.Plan plan, @@ -829,7 +712,8 @@ public class StripePaymentService : IPaymentService await HasSecretsManagerStandaloneAsync(gatewayCustomerId: organization.GatewayCustomerId, organizationHasSecretsManager: organization.UseSecretsManager); - private async Task HasSecretsManagerStandaloneAsync(string gatewayCustomerId, bool organizationHasSecretsManager) + private async Task HasSecretsManagerStandaloneAsync(string gatewayCustomerId, + bool organizationHasSecretsManager) { if (string.IsNullOrEmpty(gatewayCustomerId)) { @@ -841,7 +725,7 @@ public class StripePaymentService : IPaymentService return false; } - var customer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId); + var customer = await _stripeAdapter.GetCustomerAsync(gatewayCustomerId); return customer?.Discount?.Coupon?.Id == SecretsManagerStandaloneDiscountId; } @@ -853,7 +737,7 @@ public class StripePaymentService : IPaymentService return (null, null); } - var openInvoices = await _stripeAdapter.InvoiceSearchAsync(new InvoiceSearchOptions + var openInvoices = await _stripeAdapter.SearchInvoiceAsync(new InvoiceSearchOptions { Query = $"subscription:'{subscription.Id}' status:'open'" }); @@ -887,303 +771,9 @@ public class StripePaymentService : IPaymentService } } - public async Task PreviewInvoiceAsync( - PreviewIndividualInvoiceRequestBody parameters, - string gatewayCustomerId, - string gatewaySubscriptionId) - { - var options = new InvoiceCreatePreviewOptions - { - AutomaticTax = new InvoiceAutomaticTaxOptions - { - Enabled = true, - }, - Currency = "usd", - SubscriptionDetails = new InvoiceSubscriptionDetailsOptions - { - Items = - [ - new() - { - Quantity = 1, - Plan = StripeConstants.Prices.PremiumAnnually - }, - - new() - { - Quantity = parameters.PasswordManager.AdditionalStorage, - Plan = "storage-gb-annually" - } - ] - }, - CustomerDetails = new InvoiceCustomerDetailsOptions - { - Address = new AddressOptions - { - PostalCode = parameters.TaxInformation.PostalCode, - Country = parameters.TaxInformation.Country, - } - }, - }; - - if (!string.IsNullOrEmpty(parameters.TaxInformation.TaxId)) - { - var taxIdType = _taxService.GetStripeTaxCode( - options.CustomerDetails.Address.Country, - parameters.TaxInformation.TaxId); - - if (taxIdType == null) - { - _logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", - parameters.TaxInformation.TaxId, - parameters.TaxInformation.Country); - throw new BadRequestException("billingPreviewInvalidTaxIdError"); - } - - options.CustomerDetails.TaxIds = [ - new InvoiceCustomerDetailsTaxIdOptions - { - Type = taxIdType, - Value = parameters.TaxInformation.TaxId - } - ]; - - if (taxIdType == StripeConstants.TaxIdType.SpanishNIF) - { - options.CustomerDetails.TaxIds.Add(new InvoiceCustomerDetailsTaxIdOptions - { - Type = StripeConstants.TaxIdType.EUVAT, - Value = $"ES{parameters.TaxInformation.TaxId}" - }); - } - } - - if (!string.IsNullOrWhiteSpace(gatewayCustomerId)) - { - var gatewayCustomer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId); - - if (gatewayCustomer.Discount != null) - { - options.Coupon = gatewayCustomer.Discount.Coupon.Id; - } - } - - if (!string.IsNullOrWhiteSpace(gatewaySubscriptionId)) - { - var gatewaySubscription = await _stripeAdapter.SubscriptionGetAsync(gatewaySubscriptionId); - - if (gatewaySubscription?.Discount != null) - { - options.Coupon ??= gatewaySubscription.Discount.Coupon.Id; - } - } - - try - { - var invoice = await _stripeAdapter.InvoiceCreatePreviewAsync(options); - - var effectiveTaxRate = invoice.Tax != null && invoice.TotalExcludingTax != null && invoice.TotalExcludingTax.Value != 0 - ? invoice.Tax.Value.ToMajor() / invoice.TotalExcludingTax.Value.ToMajor() - : 0M; - - var result = new PreviewInvoiceResponseModel( - effectiveTaxRate, - invoice.TotalExcludingTax.ToMajor() ?? 0, - invoice.Tax.ToMajor() ?? 0, - invoice.Total.ToMajor()); - return result; - } - catch (StripeException e) - { - switch (e.StripeError.Code) - { - case StripeConstants.ErrorCodes.TaxIdInvalid: - _logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", - parameters.TaxInformation.TaxId, - parameters.TaxInformation.Country); - throw new BadRequestException("billingPreviewInvalidTaxIdError"); - default: - _logger.LogError(e, "Unexpected error previewing invoice with tax ID '{TaxId}' in country '{Country}'.", - parameters.TaxInformation.TaxId, - parameters.TaxInformation.Country); - throw new BadRequestException("billingPreviewInvoiceError"); - } - } - } - - public async Task PreviewInvoiceAsync( - PreviewOrganizationInvoiceRequestBody parameters, - string gatewayCustomerId, - string gatewaySubscriptionId) - { - var plan = await _pricingClient.GetPlanOrThrow(parameters.PasswordManager.Plan); - var isSponsored = parameters.PasswordManager.SponsoredPlan.HasValue; - - var options = new InvoiceCreatePreviewOptions - { - Currency = "usd", - SubscriptionDetails = new InvoiceSubscriptionDetailsOptions - { - Items = - [ - new() - { - Quantity = parameters.PasswordManager.AdditionalStorage, - Plan = plan.PasswordManager.StripeStoragePlanId - } - ] - }, - CustomerDetails = new InvoiceCustomerDetailsOptions - { - Address = new AddressOptions - { - PostalCode = parameters.TaxInformation.PostalCode, - Country = parameters.TaxInformation.Country, - } - }, - }; - - if (isSponsored) - { - var sponsoredPlan = Utilities.StaticStore.GetSponsoredPlan(parameters.PasswordManager.SponsoredPlan.Value); - options.SubscriptionDetails.Items.Add( - new() { Quantity = 1, Plan = sponsoredPlan.StripePlanId } - ); - } - else - { - if (plan.PasswordManager.HasAdditionalSeatsOption) - { - options.SubscriptionDetails.Items.Add( - new() { Quantity = parameters.PasswordManager.Seats, Plan = plan.PasswordManager.StripeSeatPlanId } - ); - } - else - { - options.SubscriptionDetails.Items.Add( - new() { Quantity = 1, Plan = plan.PasswordManager.StripePlanId } - ); - } - - if (plan.SupportsSecretsManager) - { - if (plan.SecretsManager.HasAdditionalSeatsOption) - { - options.SubscriptionDetails.Items.Add(new() - { - Quantity = parameters.SecretsManager?.Seats ?? 0, - Plan = plan.SecretsManager.StripeSeatPlanId - }); - } - - if (plan.SecretsManager.HasAdditionalServiceAccountOption) - { - options.SubscriptionDetails.Items.Add(new() - { - Quantity = parameters.SecretsManager?.AdditionalMachineAccounts ?? 0, - Plan = plan.SecretsManager.StripeServiceAccountPlanId - }); - } - } - } - - if (!string.IsNullOrWhiteSpace(parameters.TaxInformation.TaxId)) - { - var taxIdType = _taxService.GetStripeTaxCode( - options.CustomerDetails.Address.Country, - parameters.TaxInformation.TaxId); - - if (taxIdType == null) - { - _logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", - parameters.TaxInformation.TaxId, - parameters.TaxInformation.Country); - throw new BadRequestException("billingTaxIdTypeInferenceError"); - } - - options.CustomerDetails.TaxIds = [ - new InvoiceCustomerDetailsTaxIdOptions - { - Type = taxIdType, - Value = parameters.TaxInformation.TaxId - } - ]; - - if (taxIdType == StripeConstants.TaxIdType.SpanishNIF) - { - options.CustomerDetails.TaxIds.Add(new InvoiceCustomerDetailsTaxIdOptions - { - Type = StripeConstants.TaxIdType.EUVAT, - Value = $"ES{parameters.TaxInformation.TaxId}" - }); - } - } - - Customer gatewayCustomer = null; - - if (!string.IsNullOrWhiteSpace(gatewayCustomerId)) - { - gatewayCustomer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId); - - if (gatewayCustomer.Discount != null) - { - options.Coupon = gatewayCustomer.Discount.Coupon.Id; - } - } - - if (!string.IsNullOrWhiteSpace(gatewaySubscriptionId)) - { - var gatewaySubscription = await _stripeAdapter.SubscriptionGetAsync(gatewaySubscriptionId); - - if (gatewaySubscription?.Discount != null) - { - options.Coupon ??= gatewaySubscription.Discount.Coupon.Id; - } - } - - options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true }; - if (parameters.PasswordManager.Plan.IsBusinessProductTierType() && - parameters.TaxInformation.Country != Constants.CountryAbbreviations.UnitedStates) - { - options.CustomerDetails.TaxExempt = StripeConstants.TaxExempt.Reverse; - } - - try - { - var invoice = await _stripeAdapter.InvoiceCreatePreviewAsync(options); - - var effectiveTaxRate = invoice.Tax != null && invoice.TotalExcludingTax != null && invoice.TotalExcludingTax.Value != 0 - ? invoice.Tax.Value.ToMajor() / invoice.TotalExcludingTax.Value.ToMajor() - : 0M; - - var result = new PreviewInvoiceResponseModel( - effectiveTaxRate, - invoice.TotalExcludingTax.ToMajor() ?? 0, - invoice.Tax.ToMajor() ?? 0, - invoice.Total.ToMajor()); - return result; - } - catch (StripeException e) - { - switch (e.StripeError.Code) - { - case StripeConstants.ErrorCodes.TaxIdInvalid: - _logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", - parameters.TaxInformation.TaxId, - parameters.TaxInformation.Country); - throw new BadRequestException("billingPreviewInvalidTaxIdError"); - default: - _logger.LogError(e, "Unexpected error previewing invoice with tax ID '{TaxId}' in country '{Country}'.", - parameters.TaxInformation.TaxId, - parameters.TaxInformation.Country); - throw new BadRequestException("billingPreviewInvoiceError"); - } - } - } - private PaymentMethod GetLatestCardPaymentMethod(string customerId) { - var cardPaymentMethods = _stripeAdapter.PaymentMethodListAutoPaging( + var cardPaymentMethods = _stripeAdapter.ListPaymentMethodsAutoPaging( new PaymentMethodListOptions { Customer = customerId, Type = "card" }); return cardPaymentMethods.OrderByDescending(m => m.Created).FirstOrDefault(); } @@ -1207,7 +797,9 @@ public class StripePaymentService : IPaymentService braintreeCustomer.DefaultPaymentMethod); } } - catch (Braintree.Exceptions.NotFoundException) { } + catch (Braintree.Exceptions.NotFoundException) + { + } } if (customer.InvoiceSettings?.DefaultPaymentMethod?.Type == "card") @@ -1244,14 +836,17 @@ public class StripePaymentService : IPaymentService Customer customer = null; try { - customer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId, options); + customer = await _stripeAdapter.GetCustomerAsync(gatewayCustomerId, options); + } + catch (StripeException) + { } - catch (StripeException) { } return customer; } - private async Task> GetBillingTransactionsAsync(ISubscriber subscriber, int? limit = null) + private async Task> GetBillingTransactionsAsync( + ISubscriber subscriber, int? limit = null) { var transactions = subscriber switch { @@ -1274,21 +869,21 @@ public class StripePaymentService : IPaymentService try { - var paidInvoicesTask = _stripeAdapter.InvoiceListAsync(new StripeInvoiceListOptions + var paidInvoicesTask = _stripeAdapter.ListInvoicesAsync(new StripeInvoiceListOptions { Customer = customer.Id, SelectAll = !limit.HasValue, Limit = limit, Status = "paid" }); - var openInvoicesTask = _stripeAdapter.InvoiceListAsync(new StripeInvoiceListOptions + var openInvoicesTask = _stripeAdapter.ListInvoicesAsync(new StripeInvoiceListOptions { Customer = customer.Id, SelectAll = !limit.HasValue, Limit = limit, Status = "open" }); - var uncollectibleInvoicesTask = _stripeAdapter.InvoiceListAsync(new StripeInvoiceListOptions + var uncollectibleInvoicesTask = _stripeAdapter.ListInvoicesAsync(new StripeInvoiceListOptions { Customer = customer.Id, SelectAll = !limit.HasValue, diff --git a/src/Core/Services/Implementations/StripeSyncService.cs b/src/Core/Billing/Services/Implementations/StripeSyncService.cs similarity index 68% rename from src/Core/Services/Implementations/StripeSyncService.cs rename to src/Core/Billing/Services/Implementations/StripeSyncService.cs index b2700e65d1..31dd89d72d 100644 --- a/src/Core/Services/Implementations/StripeSyncService.cs +++ b/src/Core/Billing/Services/Implementations/StripeSyncService.cs @@ -1,6 +1,6 @@ using Bit.Core.Exceptions; -namespace Bit.Core.Services; +namespace Bit.Core.Billing.Services.Implementations; public class StripeSyncService : IStripeSyncService { @@ -11,7 +11,7 @@ public class StripeSyncService : IStripeSyncService _stripeAdapter = stripeAdapter; } - public async Task UpdateCustomerEmailAddress(string gatewayCustomerId, string emailAddress) + public async Task UpdateCustomerEmailAddressAsync(string gatewayCustomerId, string emailAddress) { if (string.IsNullOrWhiteSpace(gatewayCustomerId)) { @@ -23,9 +23,9 @@ public class StripeSyncService : IStripeSyncService throw new InvalidEmailException(); } - var customer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId); + var customer = await _stripeAdapter.GetCustomerAsync(gatewayCustomerId); - await _stripeAdapter.CustomerUpdateAsync(customer.Id, + await _stripeAdapter.UpdateCustomerAsync(customer.Id, new Stripe.CustomerUpdateOptions { Email = emailAddress }); } } diff --git a/src/Core/Billing/Services/Implementations/SubscriberService.cs b/src/Core/Billing/Services/Implementations/SubscriberService.cs index 8e75bf3dca..7acbe20014 100644 --- a/src/Core/Billing/Services/Implementations/SubscriberService.cs +++ b/src/Core/Billing/Services/Implementations/SubscriberService.cs @@ -15,7 +15,6 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Utilities; using Braintree; @@ -24,7 +23,6 @@ using Stripe; using static Bit.Core.Billing.Utilities; using Customer = Stripe.Customer; -using PaymentMethod = Bit.Core.Billing.Models.PaymentMethod; using Subscription = Stripe.Subscription; namespace Bit.Core.Billing.Services.Implementations; @@ -79,7 +77,7 @@ public class SubscriberService( { if (subscription.Metadata != null && subscription.Metadata.ContainsKey("organizationId")) { - await stripeAdapter.SubscriptionUpdateAsync(subscription.Id, new SubscriptionUpdateOptions + await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, new SubscriptionUpdateOptions { Metadata = metadata }); @@ -98,7 +96,7 @@ public class SubscriberService( options.CancellationDetails.Feedback = offboardingSurveyResponse.Reason; } - await stripeAdapter.SubscriptionCancelAsync(subscription.Id, options); + await stripeAdapter.CancelSubscriptionAsync(subscription.Id, options); } else { @@ -117,7 +115,7 @@ public class SubscriberService( options.CancellationDetails.Feedback = offboardingSurveyResponse.Reason; } - await stripeAdapter.SubscriptionUpdateAsync(subscription.Id, options); + await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, options); } } @@ -228,7 +226,7 @@ public class SubscriberService( _ => throw new ArgumentOutOfRangeException(nameof(subscriber)) }; - var customer = await stripeAdapter.CustomerCreateAsync(options); + var customer = await stripeAdapter.CreateCustomerAsync(options); switch (subscriber) { @@ -271,7 +269,7 @@ public class SubscriberService( try { - var customer = await stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerGetOptions); + var customer = await stripeAdapter.GetCustomerAsync(subscriber.GatewayCustomerId, customerGetOptions); if (customer != null) { @@ -307,7 +305,7 @@ public class SubscriberService( try { - var customer = await stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerGetOptions); + var customer = await stripeAdapter.GetCustomerAsync(subscriber.GatewayCustomerId, customerGetOptions); if (customer != null) { @@ -330,38 +328,6 @@ public class SubscriberService( } } - public async Task GetPaymentMethod( - ISubscriber subscriber) - { - ArgumentNullException.ThrowIfNull(subscriber); - - var customer = await GetCustomer(subscriber, new CustomerGetOptions - { - Expand = ["default_source", "invoice_settings.default_payment_method", "subscriptions", "tax_ids"] - }); - - if (customer == null) - { - return PaymentMethod.Empty; - } - - var accountCredit = customer.Balance * -1 / 100M; - - var paymentMethod = await GetPaymentSourceAsync(subscriber.Id, customer); - - var subscriptionStatus = customer.Subscriptions - .FirstOrDefault(subscription => subscription.Id == subscriber.GatewaySubscriptionId)? - .Status; - - var taxInformation = GetTaxInformation(customer); - - return new PaymentMethod( - accountCredit, - paymentMethod, - subscriptionStatus, - taxInformation); - } - public async Task GetPaymentSource( ISubscriber subscriber) { @@ -390,7 +356,7 @@ public class SubscriberService( try { - var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); + var subscription = await stripeAdapter.GetSubscriptionAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); if (subscription != null) { @@ -426,7 +392,7 @@ public class SubscriberService( try { - var subscription = await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); + var subscription = await stripeAdapter.GetSubscriptionAsync(subscriber.GatewaySubscriptionId, subscriptionGetOptions); if (subscription != null) { @@ -449,16 +415,6 @@ public class SubscriberService( } } - public async Task GetTaxInformation( - ISubscriber subscriber) - { - ArgumentNullException.ThrowIfNull(subscriber); - - var customer = await GetCustomerOrThrow(subscriber, new CustomerGetOptions { Expand = ["tax_ids"] }); - - return GetTaxInformation(customer); - } - public async Task RemovePaymentSource( ISubscriber subscriber) { @@ -530,23 +486,23 @@ public class SubscriberService( switch (source) { case BankAccount: - await stripeAdapter.BankAccountDeleteAsync(stripeCustomer.Id, source.Id); + await stripeAdapter.DeleteBankAccountAsync(stripeCustomer.Id, source.Id); break; case Card: - await stripeAdapter.CardDeleteAsync(stripeCustomer.Id, source.Id); + await stripeAdapter.DeleteCardAsync(stripeCustomer.Id, source.Id); break; } } } - var paymentMethods = stripeAdapter.PaymentMethodListAutoPagingAsync(new PaymentMethodListOptions + var paymentMethods = stripeAdapter.ListPaymentMethodsAutoPagingAsync(new PaymentMethodListOptions { Customer = stripeCustomer.Id }); await foreach (var paymentMethod in paymentMethods) { - await stripeAdapter.PaymentMethodDetachAsync(paymentMethod.Id); + await stripeAdapter.DetachPaymentMethodAsync(paymentMethod.Id); } } } @@ -575,7 +531,7 @@ public class SubscriberService( { case PaymentMethodType.BankAccount: { - var getSetupIntentsForUpdatedPaymentMethod = stripeAdapter.SetupIntentList(new SetupIntentListOptions + var getSetupIntentsForUpdatedPaymentMethod = stripeAdapter.ListSetupIntentsAsync(new SetupIntentListOptions { PaymentMethod = token }); @@ -612,7 +568,7 @@ public class SubscriberService( await RemoveStripePaymentMethodsAsync(customer); // Attach the incoming payment method. - await stripeAdapter.PaymentMethodAttachAsync(token, + await stripeAdapter.AttachPaymentMethodAsync(token, new PaymentMethodAttachOptions { Customer = subscriber.GatewayCustomerId }); var metadata = customer.Metadata; @@ -624,7 +580,7 @@ public class SubscriberService( } // Set the customer's default payment method in Stripe and remove their Braintree customer ID. - await stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, new CustomerUpdateOptions + await stripeAdapter.UpdateCustomerAsync(subscriber.GatewayCustomerId, new CustomerUpdateOptions { InvoiceSettings = new CustomerInvoiceSettingsOptions { @@ -687,7 +643,7 @@ public class SubscriberService( Expand = ["subscriptions", "tax", "tax_ids"] }); - customer = await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions + customer = await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Address = new AddressOptions { @@ -705,7 +661,7 @@ public class SubscriberService( if (taxId != null) { - await stripeAdapter.TaxIdDeleteAsync(customer.Id, taxId.Id); + await stripeAdapter.DeleteTaxIdAsync(customer.Id, taxId.Id); } if (!string.IsNullOrWhiteSpace(taxInformation.TaxId)) @@ -728,12 +684,12 @@ public class SubscriberService( try { - await stripeAdapter.TaxIdCreateAsync(customer.Id, + await stripeAdapter.CreateTaxIdAsync(customer.Id, new TaxIdCreateOptions { Type = taxIdType, Value = taxInformation.TaxId }); if (taxIdType == StripeConstants.TaxIdType.SpanishNIF) { - await stripeAdapter.TaxIdCreateAsync(customer.Id, + await stripeAdapter.CreateTaxIdAsync(customer.Id, new TaxIdCreateOptions { Type = StripeConstants.TaxIdType.EUVAT, Value = $"ES{taxInformation.TaxId}" }); } } @@ -779,7 +735,7 @@ public class SubscriberService( Address.Country: not Core.Constants.CountryAbbreviations.UnitedStates, TaxExempt: not TaxExempt.Reverse }: - await stripeAdapter.CustomerUpdateAsync(customer.Id, + await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { TaxExempt = TaxExempt.Reverse }); break; case @@ -787,14 +743,14 @@ public class SubscriberService( Address.Country: Core.Constants.CountryAbbreviations.UnitedStates, TaxExempt: TaxExempt.Reverse }: - await stripeAdapter.CustomerUpdateAsync(customer.Id, + await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { TaxExempt = TaxExempt.None }); break; } if (!subscription.AutomaticTax.Enabled) { - await stripeAdapter.SubscriptionUpdateAsync(subscription.Id, + await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, new SubscriptionUpdateOptions { AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } @@ -814,7 +770,7 @@ public class SubscriberService( if (automaticTaxShouldBeEnabled && !subscription.AutomaticTax.Enabled) { - await stripeAdapter.SubscriptionUpdateAsync(subscription.Id, + await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, new SubscriptionUpdateOptions { AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } @@ -823,57 +779,6 @@ public class SubscriberService( } } - public async Task VerifyBankAccount( - ISubscriber subscriber, - string descriptorCode) - { - var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(subscriber.Id); - - if (string.IsNullOrEmpty(setupIntentId)) - { - logger.LogError("No setup intent ID exists to verify for subscriber with ID ({SubscriberID})", subscriber.Id); - throw new BillingException(); - } - - try - { - await stripeAdapter.SetupIntentVerifyMicroDeposit(setupIntentId, - new SetupIntentVerifyMicrodepositsOptions { DescriptorCode = descriptorCode }); - - var setupIntent = await stripeAdapter.SetupIntentGet(setupIntentId); - - await stripeAdapter.PaymentMethodAttachAsync(setupIntent.PaymentMethodId, - new PaymentMethodAttachOptions { Customer = subscriber.GatewayCustomerId }); - - await stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, - new CustomerUpdateOptions - { - InvoiceSettings = new CustomerInvoiceSettingsOptions - { - DefaultPaymentMethod = setupIntent.PaymentMethodId - } - }); - } - catch (StripeException stripeException) - { - if (!string.IsNullOrEmpty(stripeException.StripeError?.Code)) - { - var message = stripeException.StripeError.Code switch - { - StripeConstants.ErrorCodes.PaymentMethodMicroDepositVerificationAttemptsExceeded => "You have exceeded the number of allowed verification attempts. Please contact support.", - StripeConstants.ErrorCodes.PaymentMethodMicroDepositVerificationDescriptorCodeMismatch => "The verification code you provided does not match the one sent to your bank account. Please try again.", - StripeConstants.ErrorCodes.PaymentMethodMicroDepositVerificationTimeout => "Your bank account was not verified within the required time period. Please contact support.", - _ => BillingException.DefaultMessage - }; - - throw new BadRequestException(message); - } - - logger.LogError(stripeException, "An unhandled Stripe exception was thrown while verifying subscriber's ({SubscriberID}) bank account", subscriber.Id); - throw new BillingException(); - } - } - public async Task IsValidGatewayCustomerIdAsync(ISubscriber subscriber) { ArgumentNullException.ThrowIfNull(subscriber); @@ -884,7 +789,7 @@ public class SubscriberService( } try { - await stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); + await stripeAdapter.GetCustomerAsync(subscriber.GatewayCustomerId); return true; } catch (StripeException e) when (e.StripeError.Code == "resource_missing") @@ -903,7 +808,7 @@ public class SubscriberService( } try { - await stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId); + await stripeAdapter.GetSubscriptionAsync(subscriber.GatewaySubscriptionId); return true; } catch (StripeException e) when (e.StripeError.Code == "resource_missing") @@ -922,7 +827,7 @@ public class SubscriberService( metadata[BraintreeCustomerIdKey] = braintreeCustomerId; - await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions + await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Metadata = metadata }); @@ -962,7 +867,7 @@ public class SubscriberService( return null; } - var setupIntent = await stripeAdapter.SetupIntentGet(setupIntentId, new SetupIntentGetOptions + var setupIntent = await stripeAdapter.GetSetupIntentAsync(setupIntentId, new SetupIntentGetOptions { Expand = ["payment_method"] }); @@ -970,25 +875,6 @@ public class SubscriberService( return PaymentSource.From(setupIntent); } - private static TaxInformation GetTaxInformation( - Customer customer) - { - if (customer.Address == null) - { - return null; - } - - return new TaxInformation( - customer.Address.Country, - customer.Address.PostalCode, - customer.TaxIds?.FirstOrDefault()?.Value, - customer.TaxIds?.FirstOrDefault()?.Type, - customer.Address.Line1, - customer.Address.Line2, - customer.Address.City, - customer.Address.State); - } - private async Task RemoveBraintreeCustomerIdAsync( Customer customer) { @@ -999,7 +885,7 @@ public class SubscriberService( metadata[BraintreeCustomerIdOldKey] = value; metadata[BraintreeCustomerIdKey] = null; - await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions + await stripeAdapter.UpdateCustomerAsync(customer.Id, new CustomerUpdateOptions { Metadata = metadata }); @@ -1016,18 +902,18 @@ public class SubscriberService( switch (source) { case BankAccount: - await stripeAdapter.BankAccountDeleteAsync(customer.Id, source.Id); + await stripeAdapter.DeleteBankAccountAsync(customer.Id, source.Id); break; case Card: - await stripeAdapter.CardDeleteAsync(customer.Id, source.Id); + await stripeAdapter.DeleteCardAsync(customer.Id, source.Id); break; } } } - var paymentMethods = await stripeAdapter.CustomerListPaymentMethods(customer.Id); + var paymentMethods = await stripeAdapter.ListCustomerPaymentMethodsAsync(customer.Id); - await Task.WhenAll(paymentMethods.Select(pm => stripeAdapter.PaymentMethodDetachAsync(pm.Id))); + await Task.WhenAll(paymentMethods.Select(pm => stripeAdapter.DetachPaymentMethodAsync(pm.Id))); } private async Task ReplaceBraintreePaymentMethodAsync( diff --git a/src/Core/Billing/Subscriptions/Commands/RestartSubscriptionCommand.cs b/src/Core/Billing/Subscriptions/Commands/RestartSubscriptionCommand.cs index 351c75ace0..165b8218a9 100644 --- a/src/Core/Billing/Subscriptions/Commands/RestartSubscriptionCommand.cs +++ b/src/Core/Billing/Subscriptions/Commands/RestartSubscriptionCommand.cs @@ -1,12 +1,13 @@ using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Commands; using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Entities; +using Bit.Core.Exceptions; using Bit.Core.Repositories; -using Bit.Core.Services; +using Microsoft.Extensions.Logging; using OneOf.Types; using Stripe; @@ -21,14 +22,14 @@ public interface IRestartSubscriptionCommand } public class RestartSubscriptionCommand( + ILogger logger, IOrganizationRepository organizationRepository, - IProviderRepository providerRepository, + IPricingClient pricingClient, IStripeAdapter stripeAdapter, - ISubscriberService subscriberService, - IUserRepository userRepository) : IRestartSubscriptionCommand + ISubscriberService subscriberService) : BaseBillingCommand(logger), IRestartSubscriptionCommand { - public async Task> Run( - ISubscriber subscriber) + public Task> Run( + ISubscriber subscriber) => HandleAsync(async () => { var existingSubscription = await subscriberService.GetSubscription(subscriber); @@ -37,56 +38,147 @@ public class RestartSubscriptionCommand( return new BadRequest("Cannot restart a subscription that is not canceled."); } + await RestartSubscriptionAsync(subscriber, existingSubscription); + + return new None(); + }); + + private Task RestartSubscriptionAsync( + ISubscriber subscriber, + Subscription canceledSubscription) => subscriber switch + { + Organization organization => RestartOrganizationSubscriptionAsync(organization, canceledSubscription), + _ => throw new NotSupportedException("Only organization subscriptions can be restarted") + }; + + private async Task RestartOrganizationSubscriptionAsync( + Organization organization, + Subscription canceledSubscription) + { + var plans = await pricingClient.ListPlans(); + + var oldPlan = plans.FirstOrDefault(plan => plan.Type == organization.PlanType); + + if (oldPlan == null) + { + throw new ConflictException("Could not find plan for organization's plan type"); + } + + var newPlan = oldPlan.Disabled + ? plans.FirstOrDefault(plan => + plan.ProductTier == oldPlan.ProductTier && + plan.IsAnnual == oldPlan.IsAnnual && + !plan.Disabled) + : oldPlan; + + if (newPlan == null) + { + throw new ConflictException("Could not find the current, enabled plan for organization's tier and cadence"); + } + + if (newPlan.Type != oldPlan.Type) + { + organization.PlanType = newPlan.Type; + organization.Plan = newPlan.Name; + organization.SelfHost = newPlan.HasSelfHost; + organization.UsePolicies = newPlan.HasPolicies; + organization.UseGroups = newPlan.HasGroups; + organization.UseDirectory = newPlan.HasDirectory; + organization.UseEvents = newPlan.HasEvents; + organization.UseTotp = newPlan.HasTotp; + organization.Use2fa = newPlan.Has2fa; + organization.UseApi = newPlan.HasApi; + organization.UseSso = newPlan.HasSso; + organization.UseOrganizationDomains = newPlan.HasOrganizationDomains; + organization.UseKeyConnector = newPlan.HasKeyConnector; + organization.UseScim = newPlan.HasScim; + organization.UseResetPassword = newPlan.HasResetPassword; + organization.UsersGetPremium = newPlan.UsersGetPremium; + organization.UseCustomPermissions = newPlan.HasCustomPermissions; + } + + var items = new List(); + + // Password Manager + var passwordManagerItem = canceledSubscription.Items.FirstOrDefault(item => + item.Price.Id == (oldPlan.HasNonSeatBasedPasswordManagerPlan() + ? oldPlan.PasswordManager.StripePlanId + : oldPlan.PasswordManager.StripeSeatPlanId)); + + if (passwordManagerItem == null) + { + throw new ConflictException("Organization's subscription does not have a Password Manager subscription item."); + } + + items.Add(new SubscriptionItemOptions + { + Price = newPlan.HasNonSeatBasedPasswordManagerPlan() ? newPlan.PasswordManager.StripePlanId : newPlan.PasswordManager.StripeSeatPlanId, + Quantity = passwordManagerItem.Quantity + }); + + // Storage + var storageItem = canceledSubscription.Items.FirstOrDefault( + item => item.Price.Id == oldPlan.PasswordManager.StripeStoragePlanId); + + if (storageItem != null) + { + items.Add(new SubscriptionItemOptions + { + Price = newPlan.PasswordManager.StripeStoragePlanId, + Quantity = storageItem.Quantity + }); + } + + // Secrets Manager & Service Accounts + var secretsManagerItem = oldPlan.SecretsManager != null + ? canceledSubscription.Items.FirstOrDefault(item => + item.Price.Id == oldPlan.SecretsManager.StripeSeatPlanId) + : null; + + var serviceAccountsItem = oldPlan.SecretsManager != null + ? canceledSubscription.Items.FirstOrDefault(item => + item.Price.Id == oldPlan.SecretsManager.StripeServiceAccountPlanId) + : null; + + if (newPlan.SecretsManager != null) + { + if (secretsManagerItem != null) + { + items.Add(new SubscriptionItemOptions + { + Price = newPlan.SecretsManager.StripeSeatPlanId, + Quantity = secretsManagerItem.Quantity + }); + } + + if (serviceAccountsItem != null) + { + items.Add(new SubscriptionItemOptions + { + Price = newPlan.SecretsManager.StripeServiceAccountPlanId, + Quantity = serviceAccountsItem.Quantity + }); + } + } + var options = new SubscriptionCreateOptions { AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }, CollectionMethod = CollectionMethod.ChargeAutomatically, - Customer = existingSubscription.CustomerId, - Items = existingSubscription.Items.Select(subscriptionItem => new SubscriptionItemOptions - { - Price = subscriptionItem.Price.Id, - Quantity = subscriptionItem.Quantity - }).ToList(), - Metadata = existingSubscription.Metadata, + Customer = canceledSubscription.CustomerId, + Items = items, + Metadata = canceledSubscription.Metadata, OffSession = true, TrialPeriodDays = 0 }; - var subscription = await stripeAdapter.SubscriptionCreateAsync(options); - await EnableAsync(subscriber, subscription); - return new None(); - } + var subscription = await stripeAdapter.CreateSubscriptionAsync(options); - private async Task EnableAsync(ISubscriber subscriber, Subscription subscription) - { - switch (subscriber) - { - case Organization organization: - { - organization.GatewaySubscriptionId = subscription.Id; - organization.Enabled = true; - organization.ExpirationDate = subscription.CurrentPeriodEnd; - organization.RevisionDate = DateTime.UtcNow; - await organizationRepository.ReplaceAsync(organization); - break; - } - case Provider provider: - { - provider.GatewaySubscriptionId = subscription.Id; - provider.Enabled = true; - provider.RevisionDate = DateTime.UtcNow; - await providerRepository.ReplaceAsync(provider); - break; - } - case User user: - { - user.GatewaySubscriptionId = subscription.Id; - user.Premium = true; - user.PremiumExpirationDate = subscription.CurrentPeriodEnd; - user.RevisionDate = DateTime.UtcNow; - await userRepository.ReplaceAsync(user); - break; - } - } + organization.GatewaySubscriptionId = subscription.Id; + organization.Enabled = true; + organization.ExpirationDate = subscription.GetCurrentPeriodEnd(); + organization.RevisionDate = DateTime.UtcNow; + + await organizationRepository.ReplaceAsync(organization); } } diff --git a/src/Core/Billing/Utilities.cs b/src/Core/Billing/Utilities.cs index 2ee6b75664..ec5978988c 100644 --- a/src/Core/Billing/Utilities.cs +++ b/src/Core/Billing/Utilities.cs @@ -2,8 +2,8 @@ #nullable disable using Bit.Core.Billing.Models; +using Bit.Core.Billing.Services; using Bit.Core.Billing.Tax.Models; -using Bit.Core.Services; using Stripe; namespace Bit.Core.Billing; @@ -22,7 +22,7 @@ public static class Utilities return null; } - var openInvoices = await stripeAdapter.InvoiceSearchAsync(new InvoiceSearchOptions + var openInvoices = await stripeAdapter.SearchInvoiceAsync(new InvoiceSearchOptions { Query = $"subscription:'{subscription.Id}' status:'open'" }); diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 80b74877c5..c3c009a2d5 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -137,10 +137,18 @@ public static class FeatureFlagKeys /* Admin Console Team */ public const string PolicyRequirements = "pm-14439-policy-requirements"; public const string ScimInviteUserOptimization = "pm-16811-optimize-invite-user-flow-to-fail-fast"; - public const string EventBasedOrganizationIntegrations = "event-based-organization-integrations"; - public const string SeparateCustomRolePermissions = "pm-19917-separate-custom-role-permissions"; public const string CreateDefaultLocation = "pm-19467-create-default-location"; + public const string AutomaticConfirmUsers = "pm-19934-auto-confirm-organization-users"; public const string PM23845_VNextApplicationCache = "pm-24957-refactor-memory-application-cache"; + public const string BlockClaimedDomainAccountCreation = "pm-28297-block-uninvited-claimed-domain-registration"; + public const string IncreaseBulkReinviteLimitForCloud = "pm-28251-increase-bulk-reinvite-limit-for-cloud"; + public const string BulkRevokeUsersV2 = "pm-28456-bulk-revoke-users-v2"; + public const string PremiumAccessQuery = "pm-21411-premium-access-query"; + + /* Architecture */ + public const string DesktopMigrationMilestone1 = "desktop-ui-migration-milestone-1"; + public const string DesktopMigrationMilestone2 = "desktop-ui-migration-milestone-2"; + public const string DesktopMigrationMilestone3 = "desktop-ui-migration-milestone-3"; /* Auth Team */ public const string TwoFactorExtensionDataPersistence = "pm-9115-two-factor-extension-data-persistence"; @@ -149,10 +157,14 @@ public static class FeatureFlagKeys public const string SetInitialPasswordRefactor = "pm-16117-set-initial-password-refactor"; public const string ChangeExistingPasswordRefactor = "pm-16117-change-existing-password-refactor"; public const string Otp6Digits = "pm-18612-otp-6-digits"; - public const string FailedTwoFactorEmail = "pm-24425-send-2fa-failed-email"; + public const string PM24579_PreventSsoOnExistingNonCompliantUsers = "pm-24579-prevent-sso-on-existing-non-compliant-users"; public const string DisableAlternateLoginMethods = "pm-22110-disable-alternate-login-methods"; public const string PM23174ManageAccountRecoveryPermissionDrivesTheNeedToSetMasterPassword = "pm-23174-manage-account-recovery-permission-drives-the-need-to-set-master-password"; + public const string MJMLBasedEmailTemplates = "mjml-based-email-templates"; + public const string MjmlWelcomeEmailTemplates = "pm-21741-mjml-welcome-email"; + public const string MarketingInitiatedPremiumFlow = "pm-26140-marketing-initiated-premium-flow"; + public const string RedirectOnSsoRequired = "pm-1632-redirect-on-sso-required"; /* Autofill Team */ public const string IdpAutoSubmitLogin = "idp-auto-submit-login"; @@ -160,6 +172,7 @@ public static class FeatureFlagKeys public const string InlineMenuFieldQualification = "inline-menu-field-qualification"; public const string InlineMenuPositioningImprovements = "inline-menu-positioning-improvements"; public const string SSHAgent = "ssh-agent"; + public const string SSHAgentV2 = "ssh-agent-v2"; public const string SSHVersionCheckQAOverride = "ssh-version-check-qa-override"; public const string GenerateIdentityFillScriptRefactor = "generate-identity-fill-script-refactor"; public const string DelayFido2PageScriptInitWithinMv2 = "delay-fido2-page-script-init-within-mv2"; @@ -170,55 +183,52 @@ public static class FeatureFlagKeys public const string MacOsNativeCredentialSync = "macos-native-credential-sync"; public const string InlineMenuTotp = "inline-menu-totp"; public const string WindowsDesktopAutotype = "windows-desktop-autotype"; + public const string WindowsDesktopAutotypeGA = "windows-desktop-autotype-ga"; /* Billing Team */ - public const string AC2101UpdateTrialInitiationEmail = "AC-2101-update-trial-initiation-email"; public const string TrialPayment = "PM-8163-trial-payment"; - public const string PM17772_AdminInitiatedSponsorships = "pm-17772-admin-initiated-sponsorships"; - public const string UsePricingService = "use-pricing-service"; - public const string PM19422_AllowAutomaticTaxUpdates = "pm-19422-allow-automatic-tax-updates"; - public const string PM21821_ProviderPortalTakeover = "pm-21821-provider-portal-takeover"; - public const string PM22415_TaxIDWarnings = "pm-22415-tax-id-warnings"; + public const string PM25379_UseNewOrganizationMetadataStructure = "pm-25379-use-new-organization-metadata-structure"; public const string PM24996ImplementUpgradeFromFreeDialog = "pm-24996-implement-upgrade-from-free-dialog"; public const string PM24032_NewNavigationPremiumUpgradeButton = "pm-24032-new-navigation-premium-upgrade-button"; public const string PM23713_PremiumBadgeOpensNewPremiumUpgradeDialog = "pm-23713-premium-badge-opens-new-premium-upgrade-dialog"; + public const string PM26793_FetchPremiumPriceFromPricingService = "pm-26793-fetch-premium-price-from-pricing-service"; + public const string PM23341_Milestone_2 = "pm-23341-milestone-2"; + public const string PM26462_Milestone_3 = "pm-26462-milestone-3"; + public const string PM28265_EnableReconcileAdditionalStorageJob = "pm-28265-enable-reconcile-additional-storage-job"; + public const string PM28265_ReconcileAdditionalStorageJobEnableLiveMode = "pm-28265-reconcile-additional-storage-job-enable-live-mode"; /* Key Management Team */ - public const string ReturnErrorOnExistingKeypair = "return-error-on-existing-keypair"; - public const string PM4154BulkEncryptionService = "PM-4154-bulk-encryption-service"; public const string PrivateKeyRegeneration = "pm-12241-private-key-regeneration"; public const string Argon2Default = "argon2-default"; - public const string UserkeyRotationV2 = "userkey-rotation-v2"; public const string SSHKeyItemVaultItem = "ssh-key-vault-item"; - public const string UserSdkForDecryption = "use-sdk-for-decryption"; - public const string PM17987_BlockType0 = "pm-17987-block-type-0"; + public const string EnrollAeadOnKeyRotation = "enroll-aead-on-key-rotation"; public const string ForceUpdateKDFSettings = "pm-18021-force-update-kdf-settings"; public const string UnlockWithMasterPasswordUnlockData = "pm-23246-unlock-with-master-password-unlock-data"; public const string WindowsBiometricsV2 = "pm-25373-windows-biometrics-v2"; + public const string LinuxBiometricsV2 = "pm-26340-linux-biometrics-v2"; public const string NoLogoutOnKdfChange = "pm-23995-no-logout-on-kdf-change"; + public const string DisableType0Decryption = "pm-25174-disable-type-0-decryption"; + public const string ConsolidatedSessionTimeoutComponent = "pm-26056-consolidated-session-timeout-component"; + public const string V2RegistrationTDEJIT = "pm-27279-v2-registration-tde-jit"; + public const string DataRecoveryTool = "pm-28813-data-recovery-tool"; + public const string EnableAccountEncryptionV2KeyConnectorRegistration = "enable-account-encryption-v2-key-connector-registration"; /* Mobile Team */ - public const string NativeCarouselFlow = "native-carousel-flow"; - public const string NativeCreateAccountFlow = "native-create-account-flow"; public const string AndroidImportLoginsFlow = "import-logins-flow"; - public const string AppReviewPrompt = "app-review-prompt"; public const string AndroidMutualTls = "mutual-tls"; public const string SingleTapPasskeyCreation = "single-tap-passkey-creation"; public const string SingleTapPasskeyAuthentication = "single-tap-passkey-authentication"; - public const string EnablePMAuthenticatorSync = "enable-pm-bwa-sync"; public const string PM3503_MobileAnonAddySelfHostAlias = "anon-addy-self-host-alias"; public const string PM3553_MobileSimpleLoginSelfHostAlias = "simple-login-self-host-alias"; - public const string EnablePMFlightRecorder = "enable-pm-flight-recorder"; public const string MobileErrorReporting = "mobile-error-reporting"; public const string AndroidChromeAutofill = "android-chrome-autofill"; public const string UserManagedPrivilegedApps = "pm-18970-user-managed-privileged-apps"; - public const string EnablePMPreloginSettings = "enable-pm-prelogin-settings"; - public const string AppIntents = "app-intents"; public const string SendAccess = "pm-19394-send-access-control"; public const string CxpImportMobile = "cxp-import-mobile"; public const string CxpExportMobile = "cxp-export-mobile"; /* Platform Team */ + public const string WebPush = "web-push"; public const string IpcChannelFramework = "ipc-channel-framework"; public const string PushNotificationsWhenLocked = "pm-19388-push-notifications-when-locked"; public const string PushNotificationsWhenInactive = "pm-25130-receive-push-notifications-for-inactive-users"; @@ -227,25 +237,32 @@ public static class FeatureFlagKeys public const string DesktopSendUIRefresh = "desktop-send-ui-refresh"; public const string UseSdkPasswordGenerators = "pm-19976-use-sdk-password-generators"; public const string UseChromiumImporter = "pm-23982-chromium-importer"; + public const string ChromiumImporterWithABE = "pm-25855-chromium-importer-abe"; + public const string SendUIRefresh = "pm-28175-send-ui-refresh"; + public const string SendEmailOTP = "pm-19051-send-email-verification"; /* Vault Team */ - public const string PM8851_BrowserOnboardingNudge = "pm-8851-browser-onboarding-nudge"; - public const string PM9111ExtensionPersistAddEditForm = "pm-9111-extension-persist-add-edit-form"; public const string CipherKeyEncryption = "cipher-key-encryption"; - public const string DesktopCipherForms = "pm-18520-desktop-cipher-forms"; public const string PM19941MigrateCipherDomainToSdk = "pm-19941-migrate-cipher-domain-to-sdk"; - public const string EndUserNotifications = "pm-10609-end-user-notifications"; public const string PhishingDetection = "phishing-detection"; - public const string RemoveCardItemTypePolicy = "pm-16442-remove-card-item-type-policy"; public const string PM22134SdkCipherListView = "pm-22134-sdk-cipher-list-view"; - public const string PM19315EndUserActivationMvp = "pm-19315-end-user-activation-mvp"; public const string PM22136_SdkCipherEncryption = "pm-22136-sdk-cipher-encryption"; + public const string PM23904_RiskInsightsForPremium = "pm-23904-risk-insights-for-premium"; + public const string PM25083_AutofillConfirmFromSearch = "pm-25083-autofill-confirm-from-search"; + public const string VaultLoadingSkeletons = "pm-25081-vault-skeleton-loaders"; + public const string BrowserPremiumSpotlight = "pm-23384-browser-premium-spotlight"; + public const string MigrateMyVaultToMyItems = "pm-20558-migrate-myvault-to-myitems"; /* Innovation Team */ public const string ArchiveVaultItems = "pm-19148-innovation-archive"; /* DIRT Team */ public const string PM22887_RiskInsightsActivityTab = "pm-22887-risk-insights-activity-tab"; + public const string EventManagementForDataDogAndCrowdStrike = "event-management-for-datadog-and-crowdstrike"; + public const string EventDiagnosticLogging = "pm-27666-siem-event-log-debugging"; + + /* UIF Team */ + public const string RouterFocusManagement = "router-focus-management"; public static List GetAllKeys() { diff --git a/src/Core/Context/CurrentContext.cs b/src/Core/Context/CurrentContext.cs index 5d9b5a1759..6067c60556 100644 --- a/src/Core/Context/CurrentContext.cs +++ b/src/Core/Context/CurrentContext.cs @@ -38,10 +38,6 @@ public class CurrentContext( public virtual List Providers { get; set; } public virtual Guid? InstallationId { get; set; } public virtual Guid? OrganizationId { get; set; } - public virtual bool CloudflareWorkerProxied { get; set; } - public virtual bool IsBot { get; set; } - public virtual bool MaybeBot { get; set; } - public virtual int? BotScore { get; set; } public virtual string ClientId { get; set; } public virtual Version ClientVersion { get; set; } public virtual bool ClientVersionIsPrerelease { get; set; } @@ -70,27 +66,6 @@ public class CurrentContext( DeviceType = dType; } - if (!BotScore.HasValue && httpContext.Request.Headers.TryGetValue("X-Cf-Bot-Score", out var cfBotScore) && - int.TryParse(cfBotScore, out var parsedBotScore)) - { - BotScore = parsedBotScore; - } - - if (httpContext.Request.Headers.TryGetValue("X-Cf-Worked-Proxied", out var cfWorkedProxied)) - { - CloudflareWorkerProxied = cfWorkedProxied == "1"; - } - - if (httpContext.Request.Headers.TryGetValue("X-Cf-Is-Bot", out var cfIsBot)) - { - IsBot = cfIsBot == "1"; - } - - if (httpContext.Request.Headers.TryGetValue("X-Cf-Maybe-Bot", out var cfMaybeBot)) - { - MaybeBot = cfMaybeBot == "1"; - } - if (httpContext.Request.Headers.TryGetValue("Bitwarden-Client-Version", out var bitWardenClientVersion) && Version.TryParse(bitWardenClientVersion, out var cVersion)) { ClientVersion = cVersion; diff --git a/src/Core/Context/ICurrentContext.cs b/src/Core/Context/ICurrentContext.cs index 417e220ba2..d527cdd363 100644 --- a/src/Core/Context/ICurrentContext.cs +++ b/src/Core/Context/ICurrentContext.cs @@ -1,6 +1,4 @@ -#nullable enable - -using System.Security.Claims; +using System.Security.Claims; using Bit.Core.AdminConsole.Context; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.Identity; @@ -12,6 +10,14 @@ using Microsoft.AspNetCore.Http; namespace Bit.Core.Context; +/// +/// Provides information about the current HTTP request and the currently authenticated user (if any). +/// This is often (but not exclusively) parsed from the JWT in the current request. +/// +/// +/// This interface suffers from having too much responsibility; consider whether any new code can go in a more +/// specific class rather than adding it here. +/// public interface ICurrentContext { HttpContext HttpContext { get; set; } @@ -25,9 +31,6 @@ public interface ICurrentContext Guid? InstallationId { get; set; } Guid? OrganizationId { get; set; } IdentityClientType IdentityClientType { get; set; } - bool IsBot { get; set; } - bool MaybeBot { get; set; } - int? BotScore { get; set; } string ClientId { get; set; } Version ClientVersion { get; set; } bool ClientVersionIsPrerelease { get; set; } @@ -59,8 +62,20 @@ public interface ICurrentContext Task EditSubscription(Guid orgId); Task EditPaymentMethods(Guid orgId); Task ViewBillingHistory(Guid orgId); + /// + /// Returns true if the current user is a member of a provider that manages the specified organization. + /// This generally gives the user administrative privileges for the organization. + /// + /// + /// Task ProviderUserForOrgAsync(Guid orgId); + /// + /// Returns true if the current user is a Provider Admin of the specified provider. + /// bool ProviderProviderAdmin(Guid providerId); + /// + /// Returns true if the current user is a member of the specified provider (with any role). + /// bool ProviderUser(Guid providerId); bool ProviderManageUsers(Guid providerId); bool ProviderAccessEventLogs(Guid providerId); diff --git a/src/Core/Core.csproj b/src/Core/Core.csproj index e9bf1b1807..52c0a641ab 100644 --- a/src/Core/Core.csproj +++ b/src/Core/Core.csproj @@ -16,19 +16,21 @@ - + + + - - - + + + - - - + + + @@ -38,6 +40,9 @@ + + + @@ -46,31 +51,30 @@ - - - - - - + + - - - - + + + + + + + - + diff --git a/src/Core/AdminConsole/Entities/Event.cs b/src/Core/Dirt/Entities/Event.cs similarity index 100% rename from src/Core/AdminConsole/Entities/Event.cs rename to src/Core/Dirt/Entities/Event.cs diff --git a/src/Core/AdminConsole/Entities/OrganizationIntegration.cs b/src/Core/Dirt/Entities/OrganizationIntegration.cs similarity index 80% rename from src/Core/AdminConsole/Entities/OrganizationIntegration.cs rename to src/Core/Dirt/Entities/OrganizationIntegration.cs index 86de25ce9a..42b4e89e27 100644 --- a/src/Core/AdminConsole/Entities/OrganizationIntegration.cs +++ b/src/Core/Dirt/Entities/OrganizationIntegration.cs @@ -1,10 +1,8 @@ -using Bit.Core.Entities; -using Bit.Core.Enums; +using Bit.Core.Dirt.Enums; +using Bit.Core.Entities; using Bit.Core.Utilities; -#nullable enable - -namespace Bit.Core.AdminConsole.Entities; +namespace Bit.Core.Dirt.Entities; public class OrganizationIntegration : ITableObject { diff --git a/src/Core/AdminConsole/Entities/OrganizationIntegrationConfiguration.cs b/src/Core/Dirt/Entities/OrganizationIntegrationConfiguration.cs similarity index 91% rename from src/Core/AdminConsole/Entities/OrganizationIntegrationConfiguration.cs rename to src/Core/Dirt/Entities/OrganizationIntegrationConfiguration.cs index 52934cf7f3..2b8dbf9220 100644 --- a/src/Core/AdminConsole/Entities/OrganizationIntegrationConfiguration.cs +++ b/src/Core/Dirt/Entities/OrganizationIntegrationConfiguration.cs @@ -2,9 +2,7 @@ using Bit.Core.Enums; using Bit.Core.Utilities; -#nullable enable - -namespace Bit.Core.AdminConsole.Entities; +namespace Bit.Core.Dirt.Entities; public class OrganizationIntegrationConfiguration : ITableObject { diff --git a/src/Core/Dirt/Entities/OrganizationReport.cs b/src/Core/Dirt/Entities/OrganizationReport.cs index a776648b35..9d04180c8d 100644 --- a/src/Core/Dirt/Entities/OrganizationReport.cs +++ b/src/Core/Dirt/Entities/OrganizationReport.cs @@ -11,12 +11,24 @@ public class OrganizationReport : ITableObject public Guid OrganizationId { get; set; } public string ReportData { get; set; } = string.Empty; public DateTime CreationDate { get; set; } = DateTime.UtcNow; - public string ContentEncryptionKey { get; set; } = string.Empty; - - public string? SummaryData { get; set; } = null; - public string? ApplicationData { get; set; } = null; + public string? SummaryData { get; set; } + public string? ApplicationData { get; set; } public DateTime RevisionDate { get; set; } = DateTime.UtcNow; + public int? ApplicationCount { get; set; } + public int? ApplicationAtRiskCount { get; set; } + public int? CriticalApplicationCount { get; set; } + public int? CriticalApplicationAtRiskCount { get; set; } + public int? MemberCount { get; set; } + public int? MemberAtRiskCount { get; set; } + public int? CriticalMemberCount { get; set; } + public int? CriticalMemberAtRiskCount { get; set; } + public int? PasswordCount { get; set; } + public int? PasswordAtRiskCount { get; set; } + public int? CriticalPasswordCount { get; set; } + public int? CriticalPasswordAtRiskCount { get; set; } + + public void SetNewId() { diff --git a/src/Core/AdminConsole/Enums/EventSystemUser.cs b/src/Core/Dirt/Enums/EventSystemUser.cs similarity index 100% rename from src/Core/AdminConsole/Enums/EventSystemUser.cs rename to src/Core/Dirt/Enums/EventSystemUser.cs diff --git a/src/Core/AdminConsole/Enums/EventType.cs b/src/Core/Dirt/Enums/EventType.cs similarity index 96% rename from src/Core/AdminConsole/Enums/EventType.cs rename to src/Core/Dirt/Enums/EventType.cs index 8073938fc5..916f408fe6 100644 --- a/src/Core/AdminConsole/Enums/EventType.cs +++ b/src/Core/Dirt/Enums/EventType.cs @@ -60,6 +60,7 @@ public enum EventType : int OrganizationUser_RejectedAuthRequest = 1514, OrganizationUser_Deleted = 1515, // Both user and organization user data were deleted OrganizationUser_Left = 1516, // User voluntarily left the organization + OrganizationUser_AutomaticallyConfirmed = 1517, Organization_Updated = 1600, Organization_PurgedVault = 1601, @@ -80,6 +81,8 @@ public enum EventType : int Organization_CollectionManagement_LimitItemDeletionDisabled = 1615, Organization_CollectionManagement_AllowAdminAccessToAllCollectionItemsEnabled = 1616, Organization_CollectionManagement_AllowAdminAccessToAllCollectionItemsDisabled = 1617, + Organization_ItemOrganization_Accepted = 1618, + Organization_ItemOrganization_Declined = 1619, Policy_Updated = 1700, diff --git a/src/Core/AdminConsole/Enums/IntegrationType.cs b/src/Core/Dirt/Enums/IntegrationType.cs similarity index 83% rename from src/Core/AdminConsole/Enums/IntegrationType.cs rename to src/Core/Dirt/Enums/IntegrationType.cs index 34edc71fbe..767f2feb06 100644 --- a/src/Core/AdminConsole/Enums/IntegrationType.cs +++ b/src/Core/Dirt/Enums/IntegrationType.cs @@ -1,4 +1,4 @@ -namespace Bit.Core.Enums; +namespace Bit.Core.Dirt.Enums; public enum IntegrationType : int { @@ -7,7 +7,8 @@ public enum IntegrationType : int Slack = 3, Webhook = 4, Hec = 5, - Datadog = 6 + Datadog = 6, + Teams = 7 } public static class IntegrationTypeExtensions @@ -24,6 +25,8 @@ public static class IntegrationTypeExtensions return "hec"; case IntegrationType.Datadog: return "datadog"; + case IntegrationType.Teams: + return "teams"; default: throw new ArgumentOutOfRangeException(nameof(type), $"Unsupported integration type: {type}"); } diff --git a/src/Core/AdminConsole/Enums/OrganizationIntegrationStatus.cs b/src/Core/Dirt/Enums/OrganizationIntegrationStatus.cs similarity index 66% rename from src/Core/AdminConsole/Enums/OrganizationIntegrationStatus.cs rename to src/Core/Dirt/Enums/OrganizationIntegrationStatus.cs index 78a7bc6d63..aad0530971 100644 --- a/src/Core/AdminConsole/Enums/OrganizationIntegrationStatus.cs +++ b/src/Core/Dirt/Enums/OrganizationIntegrationStatus.cs @@ -1,4 +1,4 @@ -namespace Bit.Api.AdminConsole.Models.Response.Organizations; +namespace Bit.Core.Dirt.Enums; public enum OrganizationIntegrationStatus : int { diff --git a/src/Core/Dirt/EventIntegrations/EventIntegrationsServiceCollectionExtensions.cs b/src/Core/Dirt/EventIntegrations/EventIntegrationsServiceCollectionExtensions.cs new file mode 100644 index 0000000000..b03a68cfa6 --- /dev/null +++ b/src/Core/Dirt/EventIntegrations/EventIntegrationsServiceCollectionExtensions.cs @@ -0,0 +1,569 @@ +using Azure.Messaging.ServiceBus; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrationConfigurations; +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrations; +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrations.Interfaces; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Models.Data.Teams; +using Bit.Core.Dirt.Repositories; +using Bit.Core.Dirt.Services; +using Bit.Core.Dirt.Services.Implementations; +using Bit.Core.Dirt.Services.NoopImplementations; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Settings; +using Bit.Core.Utilities; +using Microsoft.Bot.Builder; +using Microsoft.Bot.Builder.Integration.AspNet.Core; +using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using ZiggyCreatures.Caching.Fusion; +using TableStorageRepos = Bit.Core.Repositories.TableStorage; + +namespace Microsoft.Extensions.DependencyInjection; + +public static class EventIntegrationsServiceCollectionExtensions +{ + /// + /// Adds all event integrations commands, queries, and required cache infrastructure. + /// This method is idempotent and can be called multiple times safely. + /// + public static IServiceCollection AddEventIntegrationsCommandsQueries( + this IServiceCollection services, + GlobalSettings globalSettings) + { + // Ensure cache is registered first - commands depend on this keyed cache. + // This is idempotent for the same named cache, so it's safe to call. + services.AddExtendedCache(EventIntegrationsCacheConstants.CacheName, globalSettings); + + // Add Validator + services.TryAddSingleton(); + + // Add all commands/queries + services.AddOrganizationIntegrationCommandsQueries(); + services.AddOrganizationIntegrationConfigurationCommandsQueries(); + + return services; + } + + /// + /// Registers event write services based on available configuration. + /// + /// The service collection to add services to. + /// The global settings containing event logging configuration. + /// The service collection for chaining. + /// + /// + /// This method registers the appropriate IEventWriteService implementation based on the available + /// configuration, checking in the following priority order: + /// + /// + /// 1. Azure Service Bus - If all Azure Service Bus settings are present, registers + /// EventIntegrationEventWriteService with AzureServiceBusService as the publisher + /// + /// + /// 2. RabbitMQ - If all RabbitMQ settings are present, registers EventIntegrationEventWriteService with + /// RabbitMqService as the publisher + /// + /// + /// 3. Azure Queue Storage - If Events.ConnectionString is present, registers AzureQueueEventWriteService + /// + /// + /// 4. Repository (Self-Hosted) - If SelfHosted is true, registers RepositoryEventWriteService + /// + /// + /// 5. Noop - If none of the above are configured, registers NoopEventWriteService (no-op implementation) + /// + /// + public static IServiceCollection AddEventWriteServices(this IServiceCollection services, GlobalSettings globalSettings) + { + if (IsAzureServiceBusEnabled(globalSettings)) + { + services.TryAddSingleton(); + services.TryAddSingleton(); + return services; + } + + if (IsRabbitMqEnabled(globalSettings)) + { + services.TryAddSingleton(); + services.TryAddSingleton(); + return services; + } + + if (CoreHelpers.SettingHasValue(globalSettings.Events.ConnectionString) && + CoreHelpers.SettingHasValue(globalSettings.Events.QueueName)) + { + services.TryAddSingleton(); + return services; + } + + if (globalSettings.SelfHosted) + { + services.TryAddSingleton(); + return services; + } + + services.TryAddSingleton(); + return services; + } + + /// + /// Registers Azure Service Bus-based event integration listeners and supporting infrastructure. + /// + /// The service collection to add services to. + /// The global settings containing Azure Service Bus configuration. + /// The service collection for chaining. + /// + /// + /// If Azure Service Bus is not enabled (missing required settings), this method returns immediately + /// without registering any services. + /// + /// + /// When Azure Service Bus is enabled, this method registers: + /// - IAzureServiceBusService and IEventIntegrationPublisher implementations + /// - Table Storage event repository + /// - Azure Table Storage event handler + /// - All event integration services via AddEventIntegrationServices + /// + /// + /// PREREQUISITE: Callers must ensure AddDistributedCache has been called before this method, + /// as it is required to create the event integrations extended cache. + /// + /// + public static IServiceCollection AddAzureServiceBusListeners(this IServiceCollection services, GlobalSettings globalSettings) + { + if (!IsAzureServiceBusEnabled(globalSettings)) + { + return services; + } + + services.TryAddSingleton(); + services.TryAddSingleton(); + services.TryAddSingleton(); + services.TryAddKeyedSingleton("persistent"); + services.TryAddSingleton(); + + services.AddEventIntegrationServices(globalSettings); + + return services; + } + + /// + /// Registers RabbitMQ-based event integration listeners and supporting infrastructure. + /// + /// The service collection to add services to. + /// The global settings containing RabbitMQ configuration. + /// The service collection for chaining. + /// + /// + /// If RabbitMQ is not enabled (missing required settings), this method returns immediately + /// without registering any services. + /// + /// + /// When RabbitMQ is enabled, this method registers: + /// - IRabbitMqService and IEventIntegrationPublisher implementations + /// - Event repository handler + /// - All event integration services via AddEventIntegrationServices + /// + /// + /// PREREQUISITE: Callers must ensure AddDistributedCache has been called before this method, + /// as it is required to create the event integrations extended cache. + /// + /// + public static IServiceCollection AddRabbitMqListeners(this IServiceCollection services, GlobalSettings globalSettings) + { + if (!IsRabbitMqEnabled(globalSettings)) + { + return services; + } + + services.TryAddSingleton(); + services.TryAddSingleton(); + services.TryAddSingleton(); + + services.AddEventIntegrationServices(globalSettings); + + return services; + } + + /// + /// Registers Slack integration services based on configuration settings. + /// + /// The service collection to add services to. + /// The global settings containing Slack configuration. + /// The service collection for chaining. + /// + /// If all required Slack settings are configured (ClientId, ClientSecret, Scopes), registers the full SlackService, + /// including an HttpClient for Slack API calls. Otherwise, registers a NoopSlackService that performs no operations. + /// + public static IServiceCollection AddSlackService(this IServiceCollection services, GlobalSettings globalSettings) + { + if (CoreHelpers.SettingHasValue(globalSettings.Slack.ClientId) && + CoreHelpers.SettingHasValue(globalSettings.Slack.ClientSecret) && + CoreHelpers.SettingHasValue(globalSettings.Slack.Scopes)) + { + services.AddHttpClient(SlackService.HttpClientName); + services.TryAddSingleton(); + } + else + { + services.TryAddSingleton(); + } + + return services; + } + + /// + /// Registers Microsoft Teams integration services based on configuration settings. + /// + /// The service collection to add services to. + /// The global settings containing Teams configuration. + /// The service collection for chaining. + /// + /// If all required Teams settings are configured (ClientId, ClientSecret, Scopes), registers: + /// - TeamsService and its interfaces (IBot, ITeamsService) + /// - IBotFrameworkHttpAdapter with Teams credentials + /// - HttpClient for Teams API calls + /// Otherwise, registers a NoopTeamsService that performs no operations. + /// + public static IServiceCollection AddTeamsService(this IServiceCollection services, GlobalSettings globalSettings) + { + if (CoreHelpers.SettingHasValue(globalSettings.Teams.ClientId) && + CoreHelpers.SettingHasValue(globalSettings.Teams.ClientSecret) && + CoreHelpers.SettingHasValue(globalSettings.Teams.Scopes)) + { + services.AddHttpClient(TeamsService.HttpClientName); + services.TryAddSingleton(); + services.TryAddSingleton(sp => sp.GetRequiredService()); + services.TryAddSingleton(sp => sp.GetRequiredService()); + services.TryAddSingleton(_ => + new BotFrameworkHttpAdapter( + new TeamsBotCredentialProvider( + clientId: globalSettings.Teams.ClientId, + clientSecret: globalSettings.Teams.ClientSecret + ) + ) + ); + } + else + { + services.TryAddSingleton(); + } + + return services; + } + + /// + /// Registers event integration services including handlers, listeners, and supporting infrastructure. + /// + /// The service collection to add services to. + /// The global settings containing integration configuration. + /// The service collection for chaining. + /// + /// + /// This method orchestrates the registration of all event integration components based on the enabled + /// message broker (Azure Service Bus or RabbitMQ). It is an internal method called by the public + /// entry points AddAzureServiceBusListeners and AddRabbitMqListeners. + /// + /// + /// NOTE: If both Azure Service Bus and RabbitMQ are configured, Azure Service Bus takes precedence. This means that + /// Azure Service Bus listeners will be registered (and RabbitMQ listeners will NOT) even if this event is called + /// from AddRabbitMqListeners when Azure Service Bus settings are configured. + /// + /// + /// PREREQUISITE: Callers must ensure AddDistributedCache has been called before invoking this method. + /// This method depends on distributed cache infrastructure being available for the keyed extended + /// cache registration. + /// + /// + /// Registered Services: + /// - Keyed ExtendedCache for event integrations + /// - Integration filter service + /// - Integration handlers for Slack, Webhook, Hec, Datadog, and Teams + /// - Hosted services for event and integration listeners (based on enabled message broker) + /// + /// + internal static IServiceCollection AddEventIntegrationServices(this IServiceCollection services, + GlobalSettings globalSettings) + { + // Add common services + // NOTE: AddDistributedCache must be called by the caller before this method + services.AddExtendedCache(EventIntegrationsCacheConstants.CacheName, globalSettings); + services.TryAddSingleton(); + services.TryAddKeyedSingleton("persistent"); + + // Add services in support of handlers + services.AddSlackService(globalSettings); + services.AddTeamsService(globalSettings); + services.TryAddSingleton(TimeProvider.System); + services.AddHttpClient(WebhookIntegrationHandler.HttpClientName); + services.AddHttpClient(DatadogIntegrationHandler.HttpClientName); + + // Add integration handlers + services.TryAddSingleton, SlackIntegrationHandler>(); + services.TryAddSingleton, WebhookIntegrationHandler>(); + services.TryAddSingleton, DatadogIntegrationHandler>(); + services.TryAddSingleton, TeamsIntegrationHandler>(); + + var repositoryConfiguration = new RepositoryListenerConfiguration(globalSettings); + var slackConfiguration = new SlackListenerConfiguration(globalSettings); + var webhookConfiguration = new WebhookListenerConfiguration(globalSettings); + var hecConfiguration = new HecListenerConfiguration(globalSettings); + var datadogConfiguration = new DatadogListenerConfiguration(globalSettings); + var teamsConfiguration = new TeamsListenerConfiguration(globalSettings); + + if (IsAzureServiceBusEnabled(globalSettings)) + { + services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => + new AzureServiceBusEventListenerService( + configuration: repositoryConfiguration, + handler: provider.GetRequiredService(), + serviceBusService: provider.GetRequiredService(), + serviceBusOptions: new ServiceBusProcessorOptions() + { + PrefetchCount = repositoryConfiguration.EventPrefetchCount, + MaxConcurrentCalls = repositoryConfiguration.EventMaxConcurrentCalls + }, + loggerFactory: provider.GetRequiredService() + ) + ) + ); + services.AddAzureServiceBusIntegration(slackConfiguration); + services.AddAzureServiceBusIntegration(webhookConfiguration); + services.AddAzureServiceBusIntegration(hecConfiguration); + services.AddAzureServiceBusIntegration(datadogConfiguration); + services.AddAzureServiceBusIntegration(teamsConfiguration); + + return services; + } + + if (IsRabbitMqEnabled(globalSettings)) + { + services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => + new RabbitMqEventListenerService( + handler: provider.GetRequiredService(), + configuration: repositoryConfiguration, + rabbitMqService: provider.GetRequiredService(), + loggerFactory: provider.GetRequiredService() + ) + ) + ); + services.AddRabbitMqIntegration(slackConfiguration); + services.AddRabbitMqIntegration(webhookConfiguration); + services.AddRabbitMqIntegration(hecConfiguration); + services.AddRabbitMqIntegration(datadogConfiguration); + services.AddRabbitMqIntegration(teamsConfiguration); + } + + return services; + } + + /// + /// Registers Azure Service Bus-based event integration listeners for a specific integration type. + /// + /// The integration configuration details type (e.g., SlackIntegrationConfigurationDetails). + /// The listener configuration type implementing IIntegrationListenerConfiguration. + /// The service collection to add services to. + /// The listener configuration containing routing keys and message processing settings. + /// The service collection for chaining. + /// + /// + /// This method registers three key components: + /// 1. EventIntegrationHandler - Keyed singleton for processing integration events + /// 2. AzureServiceBusEventListenerService - Hosted service for listening to event messages from Azure Service Bus + /// for this integration type + /// 3. AzureServiceBusIntegrationListenerService - Hosted service for listening to integration messages from + /// Azure Service Bus for this integration type + /// + /// + /// The handler uses the listener configuration's routing key as its service key, allowing multiple + /// handlers to be registered for different integration types. + /// + /// + /// Service Bus processor options (PrefetchCount and MaxConcurrentCalls) are configured from the listener + /// configuration to optimize message throughput and concurrency. + /// + /// + internal static IServiceCollection AddAzureServiceBusIntegration(this IServiceCollection services, + TListenerConfig listenerConfiguration) + where TConfig : class + where TListenerConfig : IIntegrationListenerConfiguration + { + services.TryAddKeyedSingleton(serviceKey: listenerConfiguration.RoutingKey, implementationFactory: (provider, _) => + new EventIntegrationHandler( + integrationType: listenerConfiguration.IntegrationType, + eventIntegrationPublisher: provider.GetRequiredService(), + integrationFilterService: provider.GetRequiredService(), + cache: provider.GetRequiredKeyedService(EventIntegrationsCacheConstants.CacheName), + configurationRepository: provider.GetRequiredService(), + groupRepository: provider.GetRequiredService(), + organizationRepository: provider.GetRequiredService(), + organizationUserRepository: provider.GetRequiredService(), logger: provider.GetRequiredService>>()) + ); + services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => + new AzureServiceBusEventListenerService( + configuration: listenerConfiguration, + handler: provider.GetRequiredKeyedService(serviceKey: listenerConfiguration.RoutingKey), + serviceBusService: provider.GetRequiredService(), + serviceBusOptions: new ServiceBusProcessorOptions() + { + PrefetchCount = listenerConfiguration.EventPrefetchCount, + MaxConcurrentCalls = listenerConfiguration.EventMaxConcurrentCalls + }, + loggerFactory: provider.GetRequiredService() + ) + ) + ); + services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => + new AzureServiceBusIntegrationListenerService( + configuration: listenerConfiguration, + handler: provider.GetRequiredService>(), + serviceBusService: provider.GetRequiredService(), + serviceBusOptions: new ServiceBusProcessorOptions() + { + PrefetchCount = listenerConfiguration.IntegrationPrefetchCount, + MaxConcurrentCalls = listenerConfiguration.IntegrationMaxConcurrentCalls + }, + loggerFactory: provider.GetRequiredService() + ) + ) + ); + + return services; + } + + /// + /// Registers RabbitMQ-based event integration listeners for a specific integration type. + /// + /// The integration configuration details type (e.g., SlackIntegrationConfigurationDetails). + /// The listener configuration type implementing IIntegrationListenerConfiguration. + /// The service collection to add services to. + /// The listener configuration containing routing keys and message processing settings. + /// The service collection for chaining. + /// + /// + /// This method registers three key components: + /// 1. EventIntegrationHandler - Keyed singleton for processing integration events + /// 2. RabbitMqEventListenerService - Hosted service for listening to event messages from RabbitMQ for + /// this integration type + /// 3. RabbitMqIntegrationListenerService - Hosted service for listening to integration messages from RabbitMQ for + /// this integration type + /// + /// + /// + /// The handler uses the listener configuration's routing key as its service key, allowing multiple + /// handlers to be registered for different integration types. + /// + /// + internal static IServiceCollection AddRabbitMqIntegration(this IServiceCollection services, + TListenerConfig listenerConfiguration) + where TConfig : class + where TListenerConfig : IIntegrationListenerConfiguration + { + services.TryAddKeyedSingleton(serviceKey: listenerConfiguration.RoutingKey, implementationFactory: (provider, _) => + new EventIntegrationHandler( + integrationType: listenerConfiguration.IntegrationType, + eventIntegrationPublisher: provider.GetRequiredService(), + integrationFilterService: provider.GetRequiredService(), + cache: provider.GetRequiredKeyedService(EventIntegrationsCacheConstants.CacheName), + configurationRepository: provider.GetRequiredService(), + groupRepository: provider.GetRequiredService(), + organizationRepository: provider.GetRequiredService(), + organizationUserRepository: provider.GetRequiredService(), logger: provider.GetRequiredService>>()) + ); + services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => + new RabbitMqEventListenerService( + handler: provider.GetRequiredKeyedService(serviceKey: listenerConfiguration.RoutingKey), + configuration: listenerConfiguration, + rabbitMqService: provider.GetRequiredService(), + loggerFactory: provider.GetRequiredService() + ) + ) + ); + services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => + new RabbitMqIntegrationListenerService( + handler: provider.GetRequiredService>(), + configuration: listenerConfiguration, + rabbitMqService: provider.GetRequiredService(), + loggerFactory: provider.GetRequiredService(), + timeProvider: provider.GetRequiredService() + ) + ) + ); + + return services; + } + + internal static IServiceCollection AddOrganizationIntegrationCommandsQueries(this IServiceCollection services) + { + services.TryAddScoped(); + services.TryAddScoped(); + services.TryAddScoped(); + services.TryAddScoped(); + + return services; + } + + internal static IServiceCollection AddOrganizationIntegrationConfigurationCommandsQueries(this IServiceCollection services) + { + services.TryAddScoped(); + services.TryAddScoped(); + services.TryAddScoped(); + services.TryAddScoped(); + + return services; + } + + /// + /// Determines if RabbitMQ is enabled for event integrations based on configuration settings. + /// + /// The global settings containing RabbitMQ configuration. + /// True if all required RabbitMQ settings are present; otherwise, false. + /// + /// Requires all the following settings to be configured: + /// + /// EventLogging.RabbitMq.HostName + /// EventLogging.RabbitMq.Username + /// EventLogging.RabbitMq.Password + /// EventLogging.RabbitMq.EventExchangeName + /// EventLogging.RabbitMq.IntegrationExchangeName + /// + /// + internal static bool IsRabbitMqEnabled(GlobalSettings settings) + { + return CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.HostName) && + CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.Username) && + CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.Password) && + CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.EventExchangeName) && + CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.IntegrationExchangeName); + } + + /// + /// Determines if Azure Service Bus is enabled for event integrations based on configuration settings. + /// + /// The global settings containing Azure Service Bus configuration. + /// True if all required Azure Service Bus settings are present; otherwise, false. + /// + /// Requires all of the following settings to be configured: + /// + /// EventLogging.AzureServiceBus.ConnectionString + /// EventLogging.AzureServiceBus.EventTopicName + /// EventLogging.AzureServiceBus.IntegrationTopicName + /// + /// + internal static bool IsAzureServiceBusEnabled(GlobalSettings settings) + { + return CoreHelpers.SettingHasValue(settings.EventLogging.AzureServiceBus.ConnectionString) && + CoreHelpers.SettingHasValue(settings.EventLogging.AzureServiceBus.EventTopicName) && + CoreHelpers.SettingHasValue(settings.EventLogging.AzureServiceBus.IntegrationTopicName); + } +} diff --git a/src/Core/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/CreateOrganizationIntegrationConfigurationCommand.cs b/src/Core/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/CreateOrganizationIntegrationConfigurationCommand.cs new file mode 100644 index 0000000000..478b43bb7e --- /dev/null +++ b/src/Core/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/CreateOrganizationIntegrationConfigurationCommand.cs @@ -0,0 +1,64 @@ +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; +using Bit.Core.Dirt.Repositories; +using Bit.Core.Dirt.Services; +using Bit.Core.Exceptions; +using Bit.Core.Utilities; +using Microsoft.Extensions.DependencyInjection; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Dirt.EventIntegrations.OrganizationIntegrationConfigurations; + +/// +/// Command implementation for creating organization integration configurations with validation and cache invalidation support. +/// +public class CreateOrganizationIntegrationConfigurationCommand( + IOrganizationIntegrationRepository integrationRepository, + IOrganizationIntegrationConfigurationRepository configurationRepository, + [FromKeyedServices(EventIntegrationsCacheConstants.CacheName)] IFusionCache cache, + IOrganizationIntegrationConfigurationValidator validator) + : ICreateOrganizationIntegrationConfigurationCommand +{ + public async Task CreateAsync( + Guid organizationId, + Guid integrationId, + OrganizationIntegrationConfiguration configuration) + { + var integration = await integrationRepository.GetByIdAsync(integrationId); + if (integration == null || integration.OrganizationId != organizationId) + { + throw new NotFoundException(); + } + if (!validator.ValidateConfiguration(integration.Type, configuration)) + { + throw new BadRequestException( + $"Invalid Configuration and/or Filters for integration type {integration.Type}"); + } + + var created = await configurationRepository.CreateAsync(configuration); + + // Invalidate the cached configuration details + // Even though this is a new record, the cache could hold a stale empty list for this + if (created.EventType == null) + { + // Wildcard configuration - invalidate all cached results for this org/integration + await cache.RemoveByTagAsync( + EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId: organizationId, + integrationType: integration.Type + )); + } + else + { + // Specific event type - only invalidate that specific cache entry + await cache.RemoveAsync( + EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId: organizationId, + integrationType: integration.Type, + eventType: created.EventType.Value + )); + } + + return created; + } +} diff --git a/src/Core/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/DeleteOrganizationIntegrationConfigurationCommand.cs b/src/Core/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/DeleteOrganizationIntegrationConfigurationCommand.cs new file mode 100644 index 0000000000..d6369f1b1b --- /dev/null +++ b/src/Core/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/DeleteOrganizationIntegrationConfigurationCommand.cs @@ -0,0 +1,54 @@ +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; +using Bit.Core.Dirt.Repositories; +using Bit.Core.Exceptions; +using Bit.Core.Utilities; +using Microsoft.Extensions.DependencyInjection; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Dirt.EventIntegrations.OrganizationIntegrationConfigurations; + +/// +/// Command implementation for deleting organization integration configurations with cache invalidation support. +/// +public class DeleteOrganizationIntegrationConfigurationCommand( + IOrganizationIntegrationRepository integrationRepository, + IOrganizationIntegrationConfigurationRepository configurationRepository, + [FromKeyedServices(EventIntegrationsCacheConstants.CacheName)] IFusionCache cache) + : IDeleteOrganizationIntegrationConfigurationCommand +{ + public async Task DeleteAsync(Guid organizationId, Guid integrationId, Guid configurationId) + { + var integration = await integrationRepository.GetByIdAsync(integrationId); + if (integration == null || integration.OrganizationId != organizationId) + { + throw new NotFoundException(); + } + var configuration = await configurationRepository.GetByIdAsync(configurationId); + if (configuration is null || configuration.OrganizationIntegrationId != integrationId) + { + throw new NotFoundException(); + } + + await configurationRepository.DeleteAsync(configuration); + + if (configuration.EventType == null) + { + // Wildcard configuration - invalidate all cached results for this org/integration + await cache.RemoveByTagAsync( + EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId: organizationId, + integrationType: integration.Type + )); + } + else + { + // Specific event type - only invalidate that specific cache entry + await cache.RemoveAsync( + EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId: organizationId, + integrationType: integration.Type, + eventType: configuration.EventType.Value + )); + } + } +} diff --git a/src/Core/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/GetOrganizationIntegrationConfigurationsQuery.cs b/src/Core/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/GetOrganizationIntegrationConfigurationsQuery.cs new file mode 100644 index 0000000000..6dfe2949a4 --- /dev/null +++ b/src/Core/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/GetOrganizationIntegrationConfigurationsQuery.cs @@ -0,0 +1,29 @@ +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; +using Bit.Core.Dirt.Repositories; +using Bit.Core.Exceptions; + +namespace Bit.Core.Dirt.EventIntegrations.OrganizationIntegrationConfigurations; + +/// +/// Query implementation for retrieving organization integration configurations. +/// +public class GetOrganizationIntegrationConfigurationsQuery( + IOrganizationIntegrationRepository integrationRepository, + IOrganizationIntegrationConfigurationRepository configurationRepository) + : IGetOrganizationIntegrationConfigurationsQuery +{ + public async Task> GetManyByIntegrationAsync( + Guid organizationId, + Guid integrationId) + { + var integration = await integrationRepository.GetByIdAsync(integrationId); + if (integration == null || integration.OrganizationId != organizationId) + { + throw new NotFoundException(); + } + + var configurations = await configurationRepository.GetManyByIntegrationAsync(integrationId); + return configurations.ToList(); + } +} diff --git a/src/Core/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/ICreateOrganizationIntegrationConfigurationCommand.cs b/src/Core/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/ICreateOrganizationIntegrationConfigurationCommand.cs new file mode 100644 index 0000000000..629a1ee8ed --- /dev/null +++ b/src/Core/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/ICreateOrganizationIntegrationConfigurationCommand.cs @@ -0,0 +1,22 @@ +using Bit.Core.Dirt.Entities; + +namespace Bit.Core.Dirt.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; + +/// +/// Command interface for creating organization integration configurations. +/// +public interface ICreateOrganizationIntegrationConfigurationCommand +{ + /// + /// Creates a new configuration for an organization integration. + /// + /// The unique identifier of the organization. + /// The unique identifier of the integration. + /// The configuration to create. + /// The created configuration. + /// Thrown when the integration does not exist + /// or does not belong to the specified organization. + /// Thrown when the configuration or filters + /// are invalid for the integration type. + Task CreateAsync(Guid organizationId, Guid integrationId, OrganizationIntegrationConfiguration configuration); +} diff --git a/src/Core/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IDeleteOrganizationIntegrationConfigurationCommand.cs b/src/Core/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IDeleteOrganizationIntegrationConfigurationCommand.cs new file mode 100644 index 0000000000..d6866443c2 --- /dev/null +++ b/src/Core/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IDeleteOrganizationIntegrationConfigurationCommand.cs @@ -0,0 +1,19 @@ +namespace Bit.Core.Dirt.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; + +/// +/// Command interface for deleting organization integration configurations. +/// +public interface IDeleteOrganizationIntegrationConfigurationCommand +{ + /// + /// Deletes a configuration from an organization integration. + /// + /// The unique identifier of the organization. + /// The unique identifier of the integration. + /// The unique identifier of the configuration to delete. + /// + /// Thrown when the integration or configuration does not exist, + /// or the integration does not belong to the specified organization, + /// or the configuration does not belong to the specified integration. + Task DeleteAsync(Guid organizationId, Guid integrationId, Guid configurationId); +} diff --git a/src/Core/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IGetOrganizationIntegrationConfigurationsQuery.cs b/src/Core/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IGetOrganizationIntegrationConfigurationsQuery.cs new file mode 100644 index 0000000000..a6635cb3be --- /dev/null +++ b/src/Core/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IGetOrganizationIntegrationConfigurationsQuery.cs @@ -0,0 +1,19 @@ +using Bit.Core.Dirt.Entities; + +namespace Bit.Core.Dirt.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; + +/// +/// Query interface for retrieving organization integration configurations. +/// +public interface IGetOrganizationIntegrationConfigurationsQuery +{ + /// + /// Retrieves all configurations for a specific organization integration. + /// + /// The unique identifier of the organization. + /// The unique identifier of the integration. + /// A list of configurations associated with the integration. + /// Thrown when the integration does not exist + /// or does not belong to the specified organization. + Task> GetManyByIntegrationAsync(Guid organizationId, Guid integrationId); +} diff --git a/src/Core/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IUpdateOrganizationIntegrationConfigurationCommand.cs b/src/Core/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IUpdateOrganizationIntegrationConfigurationCommand.cs new file mode 100644 index 0000000000..3ed680b808 --- /dev/null +++ b/src/Core/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/Interfaces/IUpdateOrganizationIntegrationConfigurationCommand.cs @@ -0,0 +1,25 @@ +using Bit.Core.Dirt.Entities; + +namespace Bit.Core.Dirt.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; + +/// +/// Command interface for updating organization integration configurations. +/// +public interface IUpdateOrganizationIntegrationConfigurationCommand +{ + /// + /// Updates an existing configuration for an organization integration. + /// + /// The unique identifier of the organization. + /// The unique identifier of the integration. + /// The unique identifier of the configuration to update. + /// The updated configuration data. + /// The updated configuration. + /// + /// Thrown when the integration or the configuration does not exist, + /// or the integration does not belong to the specified organization, + /// or the configuration does not belong to the specified integration. + /// Thrown when the configuration or filters + /// are invalid for the integration type. + Task UpdateAsync(Guid organizationId, Guid integrationId, Guid configurationId, OrganizationIntegrationConfiguration updatedConfiguration); +} diff --git a/src/Core/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/UpdateOrganizationIntegrationConfigurationCommand.cs b/src/Core/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/UpdateOrganizationIntegrationConfigurationCommand.cs new file mode 100644 index 0000000000..69c28f3e7e --- /dev/null +++ b/src/Core/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/UpdateOrganizationIntegrationConfigurationCommand.cs @@ -0,0 +1,82 @@ +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; +using Bit.Core.Dirt.Repositories; +using Bit.Core.Dirt.Services; +using Bit.Core.Exceptions; +using Bit.Core.Utilities; +using Microsoft.Extensions.DependencyInjection; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Dirt.EventIntegrations.OrganizationIntegrationConfigurations; + +/// +/// Command implementation for updating organization integration configurations with validation and cache invalidation support. +/// +public class UpdateOrganizationIntegrationConfigurationCommand( + IOrganizationIntegrationRepository integrationRepository, + IOrganizationIntegrationConfigurationRepository configurationRepository, + [FromKeyedServices(EventIntegrationsCacheConstants.CacheName)] IFusionCache cache, + IOrganizationIntegrationConfigurationValidator validator) + : IUpdateOrganizationIntegrationConfigurationCommand +{ + public async Task UpdateAsync( + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegrationConfiguration updatedConfiguration) + { + var integration = await integrationRepository.GetByIdAsync(integrationId); + if (integration == null || integration.OrganizationId != organizationId) + { + throw new NotFoundException(); + } + var configuration = await configurationRepository.GetByIdAsync(configurationId); + if (configuration is null || configuration.OrganizationIntegrationId != integrationId) + { + throw new NotFoundException(); + } + if (!validator.ValidateConfiguration(integration.Type, updatedConfiguration)) + { + throw new BadRequestException($"Invalid Configuration and/or Filters for integration type {integration.Type}"); + } + + updatedConfiguration.Id = configuration.Id; + updatedConfiguration.CreationDate = configuration.CreationDate; + await configurationRepository.ReplaceAsync(updatedConfiguration); + + // If either old or new EventType is null (wildcard), invalidate all cached results + // for the specific integration + if (configuration.EventType == null || updatedConfiguration.EventType == null) + { + // Wildcard involved - invalidate all cached results for this org/integration + await cache.RemoveByTagAsync( + EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId: organizationId, + integrationType: integration.Type + )); + + return updatedConfiguration; + } + + // Both are specific event types - invalidate specific cache entries + await cache.RemoveAsync( + EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId: organizationId, + integrationType: integration.Type, + eventType: configuration.EventType.Value + )); + + // If event type changed, also clear the new event type's cache + if (configuration.EventType != updatedConfiguration.EventType) + { + await cache.RemoveAsync( + EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId: organizationId, + integrationType: integration.Type, + eventType: updatedConfiguration.EventType.Value + )); + } + + return updatedConfiguration; + } +} diff --git a/src/Core/Dirt/EventIntegrations/OrganizationIntegrations/CreateOrganizationIntegrationCommand.cs b/src/Core/Dirt/EventIntegrations/OrganizationIntegrations/CreateOrganizationIntegrationCommand.cs new file mode 100644 index 0000000000..4423c103f9 --- /dev/null +++ b/src/Core/Dirt/EventIntegrations/OrganizationIntegrations/CreateOrganizationIntegrationCommand.cs @@ -0,0 +1,38 @@ +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrations.Interfaces; +using Bit.Core.Dirt.Repositories; +using Bit.Core.Exceptions; +using Bit.Core.Utilities; +using Microsoft.Extensions.DependencyInjection; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Dirt.EventIntegrations.OrganizationIntegrations; + +/// +/// Command implementation for creating organization integrations with cache invalidation support. +/// +public class CreateOrganizationIntegrationCommand( + IOrganizationIntegrationRepository integrationRepository, + [FromKeyedServices(EventIntegrationsCacheConstants.CacheName)] + IFusionCache cache) + : ICreateOrganizationIntegrationCommand +{ + public async Task CreateAsync(OrganizationIntegration integration) + { + var existingIntegrations = await integrationRepository + .GetManyByOrganizationAsync(integration.OrganizationId); + if (existingIntegrations.Any(i => i.Type == integration.Type)) + { + throw new BadRequestException("An integration of this type already exists for this organization."); + } + + var created = await integrationRepository.CreateAsync(integration); + await cache.RemoveByTagAsync( + EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId: integration.OrganizationId, + integrationType: integration.Type + )); + + return created; + } +} diff --git a/src/Core/Dirt/EventIntegrations/OrganizationIntegrations/DeleteOrganizationIntegrationCommand.cs b/src/Core/Dirt/EventIntegrations/OrganizationIntegrations/DeleteOrganizationIntegrationCommand.cs new file mode 100644 index 0000000000..dc1e7fb1dc --- /dev/null +++ b/src/Core/Dirt/EventIntegrations/OrganizationIntegrations/DeleteOrganizationIntegrationCommand.cs @@ -0,0 +1,33 @@ +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrations.Interfaces; +using Bit.Core.Dirt.Repositories; +using Bit.Core.Exceptions; +using Bit.Core.Utilities; +using Microsoft.Extensions.DependencyInjection; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Dirt.EventIntegrations.OrganizationIntegrations; + +/// +/// Command implementation for deleting organization integrations with cache invalidation support. +/// +public class DeleteOrganizationIntegrationCommand( + IOrganizationIntegrationRepository integrationRepository, + [FromKeyedServices(EventIntegrationsCacheConstants.CacheName)] IFusionCache cache) + : IDeleteOrganizationIntegrationCommand +{ + public async Task DeleteAsync(Guid organizationId, Guid integrationId) + { + var integration = await integrationRepository.GetByIdAsync(integrationId); + if (integration is null || integration.OrganizationId != organizationId) + { + throw new NotFoundException(); + } + + await integrationRepository.DeleteAsync(integration); + await cache.RemoveByTagAsync( + EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId: organizationId, + integrationType: integration.Type + )); + } +} diff --git a/src/Core/Dirt/EventIntegrations/OrganizationIntegrations/GetOrganizationIntegrationsQuery.cs b/src/Core/Dirt/EventIntegrations/OrganizationIntegrations/GetOrganizationIntegrationsQuery.cs new file mode 100644 index 0000000000..807f0b0b59 --- /dev/null +++ b/src/Core/Dirt/EventIntegrations/OrganizationIntegrations/GetOrganizationIntegrationsQuery.cs @@ -0,0 +1,18 @@ +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrations.Interfaces; +using Bit.Core.Dirt.Repositories; + +namespace Bit.Core.Dirt.EventIntegrations.OrganizationIntegrations; + +/// +/// Query implementation for retrieving organization integrations. +/// +public class GetOrganizationIntegrationsQuery(IOrganizationIntegrationRepository integrationRepository) + : IGetOrganizationIntegrationsQuery +{ + public async Task> GetManyByOrganizationAsync(Guid organizationId) + { + var integrations = await integrationRepository.GetManyByOrganizationAsync(organizationId); + return integrations.ToList(); + } +} diff --git a/src/Core/Dirt/EventIntegrations/OrganizationIntegrations/Interfaces/ICreateOrganizationIntegrationCommand.cs b/src/Core/Dirt/EventIntegrations/OrganizationIntegrations/Interfaces/ICreateOrganizationIntegrationCommand.cs new file mode 100644 index 0000000000..0b06d79bdb --- /dev/null +++ b/src/Core/Dirt/EventIntegrations/OrganizationIntegrations/Interfaces/ICreateOrganizationIntegrationCommand.cs @@ -0,0 +1,18 @@ +using Bit.Core.Dirt.Entities; + +namespace Bit.Core.Dirt.EventIntegrations.OrganizationIntegrations.Interfaces; + +/// +/// Command interface for creating an OrganizationIntegration. +/// +public interface ICreateOrganizationIntegrationCommand +{ + /// + /// Creates a new organization integration. + /// + /// The OrganizationIntegration to create. + /// The created OrganizationIntegration. + /// Thrown when an integration + /// of the same type already exists for the organization. + Task CreateAsync(OrganizationIntegration integration); +} diff --git a/src/Core/Dirt/EventIntegrations/OrganizationIntegrations/Interfaces/IDeleteOrganizationIntegrationCommand.cs b/src/Core/Dirt/EventIntegrations/OrganizationIntegrations/Interfaces/IDeleteOrganizationIntegrationCommand.cs new file mode 100644 index 0000000000..8640f03ec8 --- /dev/null +++ b/src/Core/Dirt/EventIntegrations/OrganizationIntegrations/Interfaces/IDeleteOrganizationIntegrationCommand.cs @@ -0,0 +1,16 @@ +namespace Bit.Core.Dirt.EventIntegrations.OrganizationIntegrations.Interfaces; + +/// +/// Command interface for deleting organization integrations. +/// +public interface IDeleteOrganizationIntegrationCommand +{ + /// + /// Deletes an organization integration. + /// + /// The unique identifier of the organization. + /// The unique identifier of the integration to delete. + /// Thrown when the integration does not exist + /// or does not belong to the specified organization. + Task DeleteAsync(Guid organizationId, Guid integrationId); +} diff --git a/src/Core/Dirt/EventIntegrations/OrganizationIntegrations/Interfaces/IGetOrganizationIntegrationsQuery.cs b/src/Core/Dirt/EventIntegrations/OrganizationIntegrations/Interfaces/IGetOrganizationIntegrationsQuery.cs new file mode 100644 index 0000000000..1f378abe9b --- /dev/null +++ b/src/Core/Dirt/EventIntegrations/OrganizationIntegrations/Interfaces/IGetOrganizationIntegrationsQuery.cs @@ -0,0 +1,16 @@ +using Bit.Core.Dirt.Entities; + +namespace Bit.Core.Dirt.EventIntegrations.OrganizationIntegrations.Interfaces; + +/// +/// Query interface for retrieving organization integrations. +/// +public interface IGetOrganizationIntegrationsQuery +{ + /// + /// Retrieves all organization integrations for a specific organization. + /// + /// The unique identifier of the organization. + /// A list of organization integrations associated with the organization. + Task> GetManyByOrganizationAsync(Guid organizationId); +} diff --git a/src/Core/Dirt/EventIntegrations/OrganizationIntegrations/Interfaces/IUpdateOrganizationIntegrationCommand.cs b/src/Core/Dirt/EventIntegrations/OrganizationIntegrations/Interfaces/IUpdateOrganizationIntegrationCommand.cs new file mode 100644 index 0000000000..ddba2bd233 --- /dev/null +++ b/src/Core/Dirt/EventIntegrations/OrganizationIntegrations/Interfaces/IUpdateOrganizationIntegrationCommand.cs @@ -0,0 +1,20 @@ +using Bit.Core.Dirt.Entities; + +namespace Bit.Core.Dirt.EventIntegrations.OrganizationIntegrations.Interfaces; + +/// +/// Command interface for updating organization integrations. +/// +public interface IUpdateOrganizationIntegrationCommand +{ + /// + /// Updates an existing organization integration. + /// + /// The unique identifier of the organization. + /// The unique identifier of the integration to update. + /// The updated organization integration data. + /// The updated organization integration. + /// Thrown when the integration does not exist, + /// does not belong to the specified organization, or the integration type does not match. + Task UpdateAsync(Guid organizationId, Guid integrationId, OrganizationIntegration updatedIntegration); +} diff --git a/src/Core/Dirt/EventIntegrations/OrganizationIntegrations/UpdateOrganizationIntegrationCommand.cs b/src/Core/Dirt/EventIntegrations/OrganizationIntegrations/UpdateOrganizationIntegrationCommand.cs new file mode 100644 index 0000000000..77a3448276 --- /dev/null +++ b/src/Core/Dirt/EventIntegrations/OrganizationIntegrations/UpdateOrganizationIntegrationCommand.cs @@ -0,0 +1,45 @@ +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrations.Interfaces; +using Bit.Core.Dirt.Repositories; +using Bit.Core.Exceptions; +using Bit.Core.Utilities; +using Microsoft.Extensions.DependencyInjection; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Dirt.EventIntegrations.OrganizationIntegrations; + +/// +/// Command implementation for updating organization integrations with cache invalidation support. +/// +public class UpdateOrganizationIntegrationCommand( + IOrganizationIntegrationRepository integrationRepository, + [FromKeyedServices(EventIntegrationsCacheConstants.CacheName)] + IFusionCache cache) + : IUpdateOrganizationIntegrationCommand +{ + public async Task UpdateAsync( + Guid organizationId, + Guid integrationId, + OrganizationIntegration updatedIntegration) + { + var integration = await integrationRepository.GetByIdAsync(integrationId); + if (integration is null || + integration.OrganizationId != organizationId || + integration.Type != updatedIntegration.Type) + { + throw new NotFoundException(); + } + + updatedIntegration.Id = integration.Id; + updatedIntegration.OrganizationId = integration.OrganizationId; + updatedIntegration.CreationDate = integration.CreationDate; + await integrationRepository.ReplaceAsync(updatedIntegration); + await cache.RemoveByTagAsync( + EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId: organizationId, + integrationType: integration.Type + )); + + return updatedIntegration; + } +} diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/README.md b/src/Core/Dirt/EventIntegrations/README.md similarity index 74% rename from src/Core/AdminConsole/Services/Implementations/EventIntegrations/README.md rename to src/Core/Dirt/EventIntegrations/README.md index de7ce3f7fd..f9de5b9778 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/README.md +++ b/src/Core/Dirt/EventIntegrations/README.md @@ -203,31 +203,17 @@ Currently, there are integrations / handlers for Slack, webhooks, and HTTP Event - The top-level object that enables a specific integration for the organization. - Includes any properties that apply to the entire integration across all events. - - For Slack, it consists of the token: `{ "Token": "xoxb-token-from-slack" }`. - - For webhooks, it is optional. Webhooks can either be configured at this level or the configuration level, - but the configuration level takes precedence. However, even though it is optional, an organization must - have a webhook `OrganizationIntegration` (even will a `null` `Configuration`) to enable configuration - via `OrganizationIntegrationConfiguration`. - - For HEC, it consists of the scheme, token, and URI: - -```json - { - "Scheme": "Bearer", - "Token": "Auth-token-from-HEC-service", - "Uri": "https://example.com/api" - } -``` + - For example, Slack stores the token in the `Configuration` which applies to every event, but stores the +channel id in the `Configuration` of the `OrganizationIntegrationConfiguration`. The token applies to the entire Slack +integration, but the channel could be configured differently depending on event type. + - See the table below for more examples / details on what is stored at which level. ### `OrganizationIntegrationConfiguration` - This contains the configurations specific to each `EventType` for the integration. - `Configuration` contains the event-specific configuration. - - For Slack, this would contain what channel to send the message to: `{ "channelId": "C123456" }` - - For webhooks, this is the URL the request should be sent to: `{ "url": "https://api.example.com" }` - - Optionally this also can include a `Scheme` and `Token` if this webhook needs Authentication. - - As stated above, all of this information can be specified here or at the `OrganizationIntegration` - level, but any properties declared here will take precedence over the ones above. - - For HEC, this must be null. HEC is configured only at the `OrganizationIntegration` level. + - Any properties at this level override the `Configuration` form the `OrganizationIntegration`. + - See the table below for examples of specific integrations. - `Template` contains a template string that is expected to be filled in with the contents of the actual event. - The tokens in the string are wrapped in `#` characters. For instance, the UserId would be `#UserId#`. - The `IntegrationTemplateProcessor` does the actual work of replacing these tokens with introspected values from @@ -245,6 +231,23 @@ Currently, there are integrations / handlers for Slack, webhooks, and HTTP Event - An array of `OrganizationIntegrationConfigurationDetails` is what the `EventIntegrationHandler` fetches from the database to determine what to publish at the integration level. +### Existing integrations and the configurations at each level + +The following table illustrates how each integration is configured and what exactly is stored in the `Configuration` +property at each level (`OrganizationIntegration` or `OrganizationIntegrationConfiguration`). Under +`OrganizationIntegration` the valid `OrganizationIntegrationStatus` are in bold, with an example of what would be +stored at each status. + +| **Integration** | **OrganizationIntegration** | **OrganizationIntegrationConfiguration** | +|------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------| +| CloudBillingSync | **Not Applicable** (not yet used) | **Not Applicable** (not yet used) | +| Scim | **Not Applicable** (not yet used) | **Not Applicable** (not yet used) | +| Slack | **Initiated**: `null`
    **Completed**:
    `{ "Token": "xoxb-token-from-slack" }` | `{ "channelId": "C123456" }` | +| Webhook | `null` or `{ "Scheme": "Bearer", "Token": "AUTH-TOKEN", "Uri": "https://example.com" }` | `null` or `{ "Scheme": "Bearer", "Token":"AUTH-TOKEN", "Uri": "https://example.com" }`

    Whatever is defined at this level takes precedence | +| Hec | `{ "Scheme": "Bearer", "Token": "AUTH-TOKEN", "Uri": "https://example.com" }` | Always `null` | +| Datadog | `{ "ApiKey": "TheKey12345", "Uri": "https://api.us5.datadoghq.com/api/v1/events"}` | Always `null` | +| Teams | **Initiated**: `null`
    **In Progress**:
    `{ "TenantID": "tenant", "Teams": ["Id": "team", DisplayName: "MyTeam"]}`
    **Completed**:
    `{ "TenantID": "tenant", "Teams": ["Id": "team", DisplayName: "MyTeam"], "ServiceUrl":"https://example.com", ChannelId: "channel-1234"}` | Always `null` | + ## Filtering In addition to the ability to configure integrations mentioned above, organization admins can @@ -292,33 +295,60 @@ graph TD ``` ## Caching -To reduce database load and improve performance, integration configurations are cached in-memory as a Dictionary -with a periodic load of all configurations. Without caching, each incoming `EventMessage` would trigger a database +To reduce database load and improve performance, event integrations uses its own named extended cache (see +[CACHING in Utilities](https://github.com/bitwarden/server/blob/main/src/Core/Utilities/CACHING.md) +for more information). Without caching, for instance, each incoming `EventMessage` would trigger a database query to retrieve the relevant `OrganizationIntegrationConfigurationDetails`. -By loading all configurations into memory on a fixed interval, we ensure: +### `EventIntegrationsCacheConstants` -- Consistent performance for reads. -- Reduced database pressure. -- Predictable refresh timing, independent of event activity. +`EventIntegrationsCacheConstants` allows the code to have strongly typed references to a number of cache-related +details when working with the extended cache. The cache name and all cache keys and tags are programmatically accessed +from `EventIntegrationsCacheConstants` rather than simple strings. For instance, +`EventIntegrationsCacheConstants.CacheName` is used in the cache setup, keyed services, dependency injection, etc., +rather than using a string literal (i.e. "EventIntegrations") in code. -### Architecture / Design +### `OrganizationIntegrationConfigurationDetails` -- The cache is read-only for consumers. It is only updated in bulk by a background refresh process. -- The cache is fully replaced on each refresh to avoid locking or partial state. +- This is one of the most actively used portions of the architecture because any event that has an associated + organization requires a check of the configurations to determine if we need to fire off an integration. +- By using the extended cache, all reads are hitting the L1 or L2 cache before needing to access the database. - Reads return a `List` for a given key or an empty list if no match exists. -- Failures or delays in the loading process do not affect the existing cache state. The cache will continue serving - the last known good state until the update replaces the whole cache. +- The TTL is set very high on these records (1 day). This is because when the admin API makes any changes, it + tells the cache to remove that key. This propagates to the event listening code via the extended cache backplane, + which means that the cache is then expired and the next read will fetch the new values. This allows us to have + a high TTL and avoid needing to refresh values except when necessary. -### Background Refresh +#### Tagging per integration -A hosted service (`IntegrationConfigurationDetailsCacheService`) runs in the background and: +- Each entry in the cache (which again, returns `List`) is tagged with + the organization id and the integration type. +- This allows us to remove all of a given organization's configuration details for an integration when the admin + makes changes at the integration level. + - For instance, if there were 5 events configured for a given organization's webhook and the admin changed the URL + at the integration level, the updates would need to be propagated or else the cache will continue returning the + stale URL. + - By tagging each of the entries, the API can ask the extended cache to remove all the entries for a given + organization integration in one call. The cache will handle dropping / refreshing these entries in a + performant way. +- There are two places in the code that are both aware of the tagging functionality + - The `EventIntegrationHandler` must use the tag when fetching relevant configuration details. This tells the cache + to store the entry with the tag when it successfully loads from the repository. + - The `CreateOrganizationIntegrationCommand`, `UpdateOrganizationIntegrationCommand`, and + `DeleteOrganizationIntegrationCommand` commands need to use the tag to remove all the tagged entries when an admin + creates, updates, or deletes an integration. + - To ensure both places are synchronized on how to tag entries, they both use + `EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration` to build the tag. -- Loads all configuration records at application startup. -- Refreshes the cache on a configurable interval. -- Logs timing and entry count on success. -- Logs exceptions on failure without disrupting application flow. +### Template Properties + +- The `IntegrationTemplateProcessor` supports some properties that require an additional lookup. For instance, + the `UserId` is provided as part of the `EventMessage`, but `UserName` means an additional lookup to map the user + id to the actual name. +- The properties for a `User` (which includes `ActingUser`), `Group`, and `Organization` are cached via the + extended cache with a default TTL of 30 minutes. +- This is cached in both the L1 (Memory) and L2 (Redis) and will be automatically refreshed as needed. # Building a new integration @@ -349,10 +379,20 @@ and event type. - This will be the deserialized version of the `MergedConfiguration` in `OrganizationIntegrationConfigurationDetails`. +A new row with the new integration should be added to this doc in the table above [Existing integrations +and the configurations at each level](#existing-integrations-and-the-configurations-at-each-level). + ## Request Models 1. Add a new case to the switch method in `OrganizationIntegrationRequestModel.Validate`. + - Additionally, add tests in `OrganizationIntegrationRequestModelTests` 2. Add a new case to the switch method in `OrganizationIntegrationConfigurationRequestModel.IsValidForType`. + - Additionally, add / update tests in `OrganizationIntegrationConfigurationRequestModelTests` + +## Response Model + +1. Add a new case to the switch method in `OrganizationIntegrationResponseModel.Status`. + - Additionally, add / update tests in `OrganizationIntegrationResponseModelTests` ## Integration Handler diff --git a/src/Core/Dirt/Models/Data/EventIntegrations/DatadogIntegration.cs b/src/Core/Dirt/Models/Data/EventIntegrations/DatadogIntegration.cs new file mode 100644 index 0000000000..69a4deb66b --- /dev/null +++ b/src/Core/Dirt/Models/Data/EventIntegrations/DatadogIntegration.cs @@ -0,0 +1,3 @@ +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; + +public record DatadogIntegration(string ApiKey, Uri Uri); diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/DatadogIntegrationConfigurationDetails.cs b/src/Core/Dirt/Models/Data/EventIntegrations/DatadogIntegrationConfigurationDetails.cs similarity index 54% rename from src/Core/AdminConsole/Models/Data/EventIntegrations/DatadogIntegrationConfigurationDetails.cs rename to src/Core/Dirt/Models/Data/EventIntegrations/DatadogIntegrationConfigurationDetails.cs index 07aafa4bd8..ed91c3828b 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/DatadogIntegrationConfigurationDetails.cs +++ b/src/Core/Dirt/Models/Data/EventIntegrations/DatadogIntegrationConfigurationDetails.cs @@ -1,3 +1,3 @@ -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; public record DatadogIntegrationConfigurationDetails(string ApiKey, Uri Uri); diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/DatadogListenerConfiguration.cs b/src/Core/Dirt/Models/Data/EventIntegrations/DatadogListenerConfiguration.cs similarity index 91% rename from src/Core/AdminConsole/Models/Data/EventIntegrations/DatadogListenerConfiguration.cs rename to src/Core/Dirt/Models/Data/EventIntegrations/DatadogListenerConfiguration.cs index 1c74826791..ce35e29927 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/DatadogListenerConfiguration.cs +++ b/src/Core/Dirt/Models/Data/EventIntegrations/DatadogListenerConfiguration.cs @@ -1,7 +1,7 @@ -using Bit.Core.Enums; +using Bit.Core.Dirt.Enums; using Bit.Core.Settings; -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; public class DatadogListenerConfiguration(GlobalSettings globalSettings) : ListenerConfiguration(globalSettings), IIntegrationListenerConfiguration diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/HecIntegration.cs b/src/Core/Dirt/Models/Data/EventIntegrations/HecIntegration.cs similarity index 58% rename from src/Core/AdminConsole/Models/Data/EventIntegrations/HecIntegration.cs rename to src/Core/Dirt/Models/Data/EventIntegrations/HecIntegration.cs index 33ae5dadbe..df943e0bfc 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/HecIntegration.cs +++ b/src/Core/Dirt/Models/Data/EventIntegrations/HecIntegration.cs @@ -1,3 +1,3 @@ -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; public record HecIntegration(Uri Uri, string Scheme, string Token, string? Service = null); diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/HecListenerConfiguration.cs b/src/Core/Dirt/Models/Data/EventIntegrations/HecListenerConfiguration.cs similarity index 91% rename from src/Core/AdminConsole/Models/Data/EventIntegrations/HecListenerConfiguration.cs rename to src/Core/Dirt/Models/Data/EventIntegrations/HecListenerConfiguration.cs index 37a0d68beb..5ceb42be64 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/HecListenerConfiguration.cs +++ b/src/Core/Dirt/Models/Data/EventIntegrations/HecListenerConfiguration.cs @@ -1,7 +1,7 @@ -using Bit.Core.Enums; +using Bit.Core.Dirt.Enums; using Bit.Core.Settings; -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; public class HecListenerConfiguration(GlobalSettings globalSettings) : ListenerConfiguration(globalSettings), IIntegrationListenerConfiguration diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/IEventListenerConfiguration.cs b/src/Core/Dirt/Models/Data/EventIntegrations/IEventListenerConfiguration.cs similarity index 80% rename from src/Core/AdminConsole/Models/Data/EventIntegrations/IEventListenerConfiguration.cs rename to src/Core/Dirt/Models/Data/EventIntegrations/IEventListenerConfiguration.cs index 7df1459941..206dc2cc0b 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/IEventListenerConfiguration.cs +++ b/src/Core/Dirt/Models/Data/EventIntegrations/IEventListenerConfiguration.cs @@ -1,4 +1,4 @@ -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; public interface IEventListenerConfiguration { diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/IIntegrationListenerConfiguration.cs b/src/Core/Dirt/Models/Data/EventIntegrations/IIntegrationListenerConfiguration.cs similarity index 86% rename from src/Core/AdminConsole/Models/Data/EventIntegrations/IIntegrationListenerConfiguration.cs rename to src/Core/Dirt/Models/Data/EventIntegrations/IIntegrationListenerConfiguration.cs index 30401bb072..1fbfefa420 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/IIntegrationListenerConfiguration.cs +++ b/src/Core/Dirt/Models/Data/EventIntegrations/IIntegrationListenerConfiguration.cs @@ -1,6 +1,6 @@ -using Bit.Core.Enums; +using Bit.Core.Dirt.Enums; -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; public interface IIntegrationListenerConfiguration : IEventListenerConfiguration { diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/IIntegrationMessage.cs b/src/Core/Dirt/Models/Data/EventIntegrations/IIntegrationMessage.cs similarity index 67% rename from src/Core/AdminConsole/Models/Data/EventIntegrations/IIntegrationMessage.cs rename to src/Core/Dirt/Models/Data/EventIntegrations/IIntegrationMessage.cs index 7a0962d89a..2d333dfee4 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/IIntegrationMessage.cs +++ b/src/Core/Dirt/Models/Data/EventIntegrations/IIntegrationMessage.cs @@ -1,11 +1,12 @@ -using Bit.Core.Enums; +using Bit.Core.Dirt.Enums; -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; public interface IIntegrationMessage { IntegrationType IntegrationType { get; } string MessageId { get; set; } + string? OrganizationId { get; set; } int RetryCount { get; } DateTime? DelayUntilDate { get; } void ApplyRetry(DateTime? handlerDelayUntilDate); diff --git a/src/Core/Dirt/Models/Data/EventIntegrations/IntegrationFailureCategory.cs b/src/Core/Dirt/Models/Data/EventIntegrations/IntegrationFailureCategory.cs new file mode 100644 index 0000000000..f9d8f2ab68 --- /dev/null +++ b/src/Core/Dirt/Models/Data/EventIntegrations/IntegrationFailureCategory.cs @@ -0,0 +1,37 @@ +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; + +/// +/// Categories of event integration failures used for classification and retry logic. +/// +public enum IntegrationFailureCategory +{ + /// + /// Service is temporarily unavailable (503, upstream outage, maintenance). + /// + ServiceUnavailable, + + /// + /// Authentication failed (401, 403, invalid_auth, token issues). + /// + AuthenticationFailed, + + /// + /// Configuration error (invalid config, channel_not_found, etc.). + /// + ConfigurationError, + + /// + /// Rate limited (429, rate_limited). + /// + RateLimited, + + /// + /// Transient error (timeouts, 500, network errors). + /// + TransientError, + + /// + /// Permanent failure unrelated to authentication/config (e.g., unrecoverable payload/format issue). + /// + PermanentFailure +} diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationFilterGroup.cs b/src/Core/Dirt/Models/Data/EventIntegrations/IntegrationFilterGroup.cs similarity index 76% rename from src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationFilterGroup.cs rename to src/Core/Dirt/Models/Data/EventIntegrations/IntegrationFilterGroup.cs index 276ca3a14b..0c129883cf 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationFilterGroup.cs +++ b/src/Core/Dirt/Models/Data/EventIntegrations/IntegrationFilterGroup.cs @@ -1,4 +1,4 @@ -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; public class IntegrationFilterGroup { diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationFilterOperation.cs b/src/Core/Dirt/Models/Data/EventIntegrations/IntegrationFilterOperation.cs similarity index 61% rename from src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationFilterOperation.cs rename to src/Core/Dirt/Models/Data/EventIntegrations/IntegrationFilterOperation.cs index fddf630e26..d98ab1e13e 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationFilterOperation.cs +++ b/src/Core/Dirt/Models/Data/EventIntegrations/IntegrationFilterOperation.cs @@ -1,4 +1,4 @@ -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; public enum IntegrationFilterOperation { diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationFilterRule.cs b/src/Core/Dirt/Models/Data/EventIntegrations/IntegrationFilterRule.cs similarity index 76% rename from src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationFilterRule.cs rename to src/Core/Dirt/Models/Data/EventIntegrations/IntegrationFilterRule.cs index b5f90f5e63..9ac3ef753e 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationFilterRule.cs +++ b/src/Core/Dirt/Models/Data/EventIntegrations/IntegrationFilterRule.cs @@ -1,4 +1,4 @@ -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; public class IntegrationFilterRule { diff --git a/src/Core/Dirt/Models/Data/EventIntegrations/IntegrationHandlerResult.cs b/src/Core/Dirt/Models/Data/EventIntegrations/IntegrationHandlerResult.cs new file mode 100644 index 0000000000..bbdce50ec0 --- /dev/null +++ b/src/Core/Dirt/Models/Data/EventIntegrations/IntegrationHandlerResult.cs @@ -0,0 +1,84 @@ +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; + +/// +/// Represents the result of an integration handler operation, including success status, +/// failure categorization, and retry metadata. Use the factory method +/// for successful operations or for failures with automatic retry-ability +/// determination based on the failure category. +/// +public class IntegrationHandlerResult +{ + /// + /// True if the integration send succeeded, false otherwise. + /// + public bool Success { get; } + + /// + /// The integration message that was processed. + /// + public IIntegrationMessage Message { get; } + + /// + /// Optional UTC date/time indicating when a failed operation should be retried. + /// Will be used by the retry queue to delay re-sending the message. + /// Usually set based on the Retry-After header from rate-limited responses. + /// + public DateTime? DelayUntilDate { get; private init; } + + /// + /// Category of the failure. Null for successful results. + /// + public IntegrationFailureCategory? Category { get; private init; } + + /// + /// Detailed failure reason or error message. Empty for successful results. + /// + public string? FailureReason { get; private init; } + + /// + /// Indicates whether the operation is retryable. + /// Computed from the failure category. + /// + public bool Retryable => Category switch + { + IntegrationFailureCategory.RateLimited => true, + IntegrationFailureCategory.TransientError => true, + IntegrationFailureCategory.ServiceUnavailable => true, + IntegrationFailureCategory.AuthenticationFailed => false, + IntegrationFailureCategory.ConfigurationError => false, + IntegrationFailureCategory.PermanentFailure => false, + null => false, + _ => false + }; + + /// + /// Creates a successful result. + /// + public static IntegrationHandlerResult Succeed(IIntegrationMessage message) + { + return new IntegrationHandlerResult(success: true, message: message); + } + + /// + /// Creates a failed result with a failure category and reason. + /// + public static IntegrationHandlerResult Fail( + IIntegrationMessage message, + IntegrationFailureCategory category, + string failureReason, + DateTime? delayUntil = null) + { + return new IntegrationHandlerResult(success: false, message: message) + { + Category = category, + FailureReason = failureReason, + DelayUntilDate = delayUntil + }; + } + + private IntegrationHandlerResult(bool success, IIntegrationMessage message) + { + Success = success; + Message = message; + } +} diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationMessage.cs b/src/Core/Dirt/Models/Data/EventIntegrations/IntegrationMessage.cs similarity index 89% rename from src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationMessage.cs rename to src/Core/Dirt/Models/Data/EventIntegrations/IntegrationMessage.cs index 11a5229f8c..edf31a2a1f 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationMessage.cs +++ b/src/Core/Dirt/Models/Data/EventIntegrations/IntegrationMessage.cs @@ -1,12 +1,13 @@ using System.Text.Json; -using Bit.Core.Enums; +using Bit.Core.Dirt.Enums; -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; public class IntegrationMessage : IIntegrationMessage { public IntegrationType IntegrationType { get; set; } public required string MessageId { get; set; } + public string? OrganizationId { get; set; } public required string RenderedTemplate { get; set; } public int RetryCount { get; set; } = 0; public DateTime? DelayUntilDate { get; set; } diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationOAuthState.cs b/src/Core/Dirt/Models/Data/EventIntegrations/IntegrationOAuthState.cs similarity index 95% rename from src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationOAuthState.cs rename to src/Core/Dirt/Models/Data/EventIntegrations/IntegrationOAuthState.cs index 3b29bbebb4..d75780d6c6 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationOAuthState.cs +++ b/src/Core/Dirt/Models/Data/EventIntegrations/IntegrationOAuthState.cs @@ -1,8 +1,8 @@ using System.Security.Cryptography; using System.Text; -using Bit.Core.AdminConsole.Entities; +using Bit.Core.Dirt.Entities; -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; public class IntegrationOAuthState { diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationTemplateContext.cs b/src/Core/Dirt/Models/Data/EventIntegrations/IntegrationTemplateContext.cs similarity index 55% rename from src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationTemplateContext.cs rename to src/Core/Dirt/Models/Data/EventIntegrations/IntegrationTemplateContext.cs index 79a30c3a02..3b527469fa 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationTemplateContext.cs +++ b/src/Core/Dirt/Models/Data/EventIntegrations/IntegrationTemplateContext.cs @@ -1,10 +1,10 @@ using System.Text.Json; using Bit.Core.AdminConsole.Entities; -using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Data; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; public class IntegrationTemplateContext(EventMessage eventMessage) { @@ -23,16 +23,31 @@ public class IntegrationTemplateContext(EventMessage eventMessage) public Guid? CollectionId => Event.CollectionId; public Guid? GroupId => Event.GroupId; public Guid? PolicyId => Event.PolicyId; + public Guid? IdempotencyId => Event.IdempotencyId; + public Guid? ProviderId => Event.ProviderId; + public Guid? ProviderUserId => Event.ProviderUserId; + public Guid? ProviderOrganizationId => Event.ProviderOrganizationId; + public Guid? InstallationId => Event.InstallationId; + public Guid? SecretId => Event.SecretId; + public Guid? ProjectId => Event.ProjectId; + public Guid? ServiceAccountId => Event.ServiceAccountId; + public Guid? GrantedServiceAccountId => Event.GrantedServiceAccountId; + public string DateIso8601 => Date.ToString("o"); public string EventMessage => JsonSerializer.Serialize(Event); - public User? User { get; set; } + public OrganizationUserUserDetails? User { get; set; } public string? UserName => User?.Name; public string? UserEmail => User?.Email; + public OrganizationUserType? UserType => User?.Type; - public User? ActingUser { get; set; } + public OrganizationUserUserDetails? ActingUser { get; set; } public string? ActingUserName => ActingUser?.Name; public string? ActingUserEmail => ActingUser?.Email; + public OrganizationUserType? ActingUserType => ActingUser?.Type; + + public Group? Group { get; set; } + public string? GroupName => Group?.Name; public Organization? Organization { get; set; } public string? OrganizationName => Organization?.DisplayName(); diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/ListenerConfiguration.cs b/src/Core/Dirt/Models/Data/EventIntegrations/ListenerConfiguration.cs similarity index 94% rename from src/Core/AdminConsole/Models/Data/EventIntegrations/ListenerConfiguration.cs rename to src/Core/Dirt/Models/Data/EventIntegrations/ListenerConfiguration.cs index 40eb2b3e77..2a970ce670 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/ListenerConfiguration.cs +++ b/src/Core/Dirt/Models/Data/EventIntegrations/ListenerConfiguration.cs @@ -1,6 +1,6 @@ using Bit.Core.Settings; -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; public abstract class ListenerConfiguration { diff --git a/src/Core/AdminConsole/Models/Data/Organizations/OrganizationIntegrationConfigurationDetails.cs b/src/Core/Dirt/Models/Data/EventIntegrations/OrganizationIntegrationConfigurationDetails.cs similarity index 95% rename from src/Core/AdminConsole/Models/Data/Organizations/OrganizationIntegrationConfigurationDetails.cs rename to src/Core/Dirt/Models/Data/EventIntegrations/OrganizationIntegrationConfigurationDetails.cs index 5fdc760c90..6517ceccf0 100644 --- a/src/Core/AdminConsole/Models/Data/Organizations/OrganizationIntegrationConfigurationDetails.cs +++ b/src/Core/Dirt/Models/Data/EventIntegrations/OrganizationIntegrationConfigurationDetails.cs @@ -1,9 +1,8 @@ using System.Text.Json.Nodes; +using Bit.Core.Dirt.Enums; using Bit.Core.Enums; -#nullable enable - -namespace Bit.Core.Models.Data.Organizations; +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; public class OrganizationIntegrationConfigurationDetails { diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/RepositoryListenerConfiguration.cs b/src/Core/Dirt/Models/Data/EventIntegrations/RepositoryListenerConfiguration.cs similarity index 87% rename from src/Core/AdminConsole/Models/Data/EventIntegrations/RepositoryListenerConfiguration.cs rename to src/Core/Dirt/Models/Data/EventIntegrations/RepositoryListenerConfiguration.cs index 118b3a17fe..20299dd651 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/RepositoryListenerConfiguration.cs +++ b/src/Core/Dirt/Models/Data/EventIntegrations/RepositoryListenerConfiguration.cs @@ -1,6 +1,6 @@ using Bit.Core.Settings; -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; public class RepositoryListenerConfiguration(GlobalSettings globalSettings) : ListenerConfiguration(globalSettings), IEventListenerConfiguration diff --git a/src/Core/Dirt/Models/Data/EventIntegrations/SlackIntegration.cs b/src/Core/Dirt/Models/Data/EventIntegrations/SlackIntegration.cs new file mode 100644 index 0000000000..fcfd07f574 --- /dev/null +++ b/src/Core/Dirt/Models/Data/EventIntegrations/SlackIntegration.cs @@ -0,0 +1,3 @@ +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; + +public record SlackIntegration(string Token); diff --git a/src/Core/Dirt/Models/Data/EventIntegrations/SlackIntegrationConfiguration.cs b/src/Core/Dirt/Models/Data/EventIntegrations/SlackIntegrationConfiguration.cs new file mode 100644 index 0000000000..164a132e8c --- /dev/null +++ b/src/Core/Dirt/Models/Data/EventIntegrations/SlackIntegrationConfiguration.cs @@ -0,0 +1,3 @@ +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; + +public record SlackIntegrationConfiguration(string ChannelId); diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/SlackIntegrationConfigurationDetails.cs b/src/Core/Dirt/Models/Data/EventIntegrations/SlackIntegrationConfigurationDetails.cs similarity index 56% rename from src/Core/AdminConsole/Models/Data/EventIntegrations/SlackIntegrationConfigurationDetails.cs rename to src/Core/Dirt/Models/Data/EventIntegrations/SlackIntegrationConfigurationDetails.cs index d22f43bb92..b81617118d 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/SlackIntegrationConfigurationDetails.cs +++ b/src/Core/Dirt/Models/Data/EventIntegrations/SlackIntegrationConfigurationDetails.cs @@ -1,3 +1,3 @@ -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; public record SlackIntegrationConfigurationDetails(string ChannelId, string Token); diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/SlackListenerConfiguration.cs b/src/Core/Dirt/Models/Data/EventIntegrations/SlackListenerConfiguration.cs similarity index 91% rename from src/Core/AdminConsole/Models/Data/EventIntegrations/SlackListenerConfiguration.cs rename to src/Core/Dirt/Models/Data/EventIntegrations/SlackListenerConfiguration.cs index 7dd834f51e..ef2cf83837 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/SlackListenerConfiguration.cs +++ b/src/Core/Dirt/Models/Data/EventIntegrations/SlackListenerConfiguration.cs @@ -1,7 +1,7 @@ -using Bit.Core.Enums; +using Bit.Core.Dirt.Enums; using Bit.Core.Settings; -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; public class SlackListenerConfiguration(GlobalSettings globalSettings) : ListenerConfiguration(globalSettings), IIntegrationListenerConfiguration diff --git a/src/Core/Dirt/Models/Data/EventIntegrations/TeamsIntegration.cs b/src/Core/Dirt/Models/Data/EventIntegrations/TeamsIntegration.cs new file mode 100644 index 0000000000..fcb42a5261 --- /dev/null +++ b/src/Core/Dirt/Models/Data/EventIntegrations/TeamsIntegration.cs @@ -0,0 +1,12 @@ +using Bit.Core.Dirt.Models.Data.Teams; + +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; + +public record TeamsIntegration( + string TenantId, + IReadOnlyList Teams, + string? ChannelId = null, + Uri? ServiceUrl = null) +{ + public bool IsCompleted => !string.IsNullOrEmpty(ChannelId) && ServiceUrl is not null; +} diff --git a/src/Core/Dirt/Models/Data/EventIntegrations/TeamsIntegrationConfigurationDetails.cs b/src/Core/Dirt/Models/Data/EventIntegrations/TeamsIntegrationConfigurationDetails.cs new file mode 100644 index 0000000000..a890f553f5 --- /dev/null +++ b/src/Core/Dirt/Models/Data/EventIntegrations/TeamsIntegrationConfigurationDetails.cs @@ -0,0 +1,3 @@ +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; + +public record TeamsIntegrationConfigurationDetails(string ChannelId, Uri ServiceUrl); diff --git a/src/Core/Dirt/Models/Data/EventIntegrations/TeamsListenerConfiguration.cs b/src/Core/Dirt/Models/Data/EventIntegrations/TeamsListenerConfiguration.cs new file mode 100644 index 0000000000..4111c96601 --- /dev/null +++ b/src/Core/Dirt/Models/Data/EventIntegrations/TeamsListenerConfiguration.cs @@ -0,0 +1,38 @@ +using Bit.Core.Dirt.Enums; +using Bit.Core.Settings; + +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; + +public class TeamsListenerConfiguration(GlobalSettings globalSettings) : + ListenerConfiguration(globalSettings), IIntegrationListenerConfiguration +{ + public IntegrationType IntegrationType + { + get => IntegrationType.Teams; + } + + public string EventQueueName + { + get => _globalSettings.EventLogging.RabbitMq.TeamsEventsQueueName; + } + + public string IntegrationQueueName + { + get => _globalSettings.EventLogging.RabbitMq.TeamsIntegrationQueueName; + } + + public string IntegrationRetryQueueName + { + get => _globalSettings.EventLogging.RabbitMq.TeamsIntegrationRetryQueueName; + } + + public string EventSubscriptionName + { + get => _globalSettings.EventLogging.AzureServiceBus.TeamsEventSubscriptionName; + } + + public string IntegrationSubscriptionName + { + get => _globalSettings.EventLogging.AzureServiceBus.TeamsIntegrationSubscriptionName; + } +} diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/WebhookIntegration.cs b/src/Core/Dirt/Models/Data/EventIntegrations/WebhookIntegration.cs similarity index 57% rename from src/Core/AdminConsole/Models/Data/EventIntegrations/WebhookIntegration.cs rename to src/Core/Dirt/Models/Data/EventIntegrations/WebhookIntegration.cs index dcda4caa92..d12ea16ee1 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/WebhookIntegration.cs +++ b/src/Core/Dirt/Models/Data/EventIntegrations/WebhookIntegration.cs @@ -1,3 +1,3 @@ -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; public record WebhookIntegration(Uri Uri, string? Scheme = null, string? Token = null); diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/WebhookIntegrationConfiguration.cs b/src/Core/Dirt/Models/Data/EventIntegrations/WebhookIntegrationConfiguration.cs similarity index 60% rename from src/Core/AdminConsole/Models/Data/EventIntegrations/WebhookIntegrationConfiguration.cs rename to src/Core/Dirt/Models/Data/EventIntegrations/WebhookIntegrationConfiguration.cs index 851bd3f411..8d7bf90e2c 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/WebhookIntegrationConfiguration.cs +++ b/src/Core/Dirt/Models/Data/EventIntegrations/WebhookIntegrationConfiguration.cs @@ -1,3 +1,3 @@ -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; public record WebhookIntegrationConfiguration(Uri Uri, string? Scheme = null, string? Token = null); diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/WebhookIntegrationConfigurationDetails.cs b/src/Core/Dirt/Models/Data/EventIntegrations/WebhookIntegrationConfigurationDetails.cs similarity index 62% rename from src/Core/AdminConsole/Models/Data/EventIntegrations/WebhookIntegrationConfigurationDetails.cs rename to src/Core/Dirt/Models/Data/EventIntegrations/WebhookIntegrationConfigurationDetails.cs index dba9b1714d..49508f8454 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/WebhookIntegrationConfigurationDetails.cs +++ b/src/Core/Dirt/Models/Data/EventIntegrations/WebhookIntegrationConfigurationDetails.cs @@ -1,3 +1,3 @@ -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; public record WebhookIntegrationConfigurationDetails(Uri Uri, string? Scheme = null, string? Token = null); diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/WebhookListenerConfiguration.cs b/src/Core/Dirt/Models/Data/EventIntegrations/WebhookListenerConfiguration.cs similarity index 91% rename from src/Core/AdminConsole/Models/Data/EventIntegrations/WebhookListenerConfiguration.cs rename to src/Core/Dirt/Models/Data/EventIntegrations/WebhookListenerConfiguration.cs index 9d5bf811c7..9afc26168c 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/WebhookListenerConfiguration.cs +++ b/src/Core/Dirt/Models/Data/EventIntegrations/WebhookListenerConfiguration.cs @@ -1,7 +1,7 @@ -using Bit.Core.Enums; +using Bit.Core.Dirt.Enums; using Bit.Core.Settings; -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; +namespace Bit.Core.Dirt.Models.Data.EventIntegrations; public class WebhookListenerConfiguration(GlobalSettings globalSettings) : ListenerConfiguration(globalSettings), IIntegrationListenerConfiguration diff --git a/src/Core/AdminConsole/Models/Data/EventMessage.cs b/src/Core/Dirt/Models/Data/EventMessage.cs similarity index 100% rename from src/Core/AdminConsole/Models/Data/EventMessage.cs rename to src/Core/Dirt/Models/Data/EventMessage.cs diff --git a/src/Core/AdminConsole/Models/Data/EventTableEntity.cs b/src/Core/Dirt/Models/Data/EventTableEntity.cs similarity index 100% rename from src/Core/AdminConsole/Models/Data/EventTableEntity.cs rename to src/Core/Dirt/Models/Data/EventTableEntity.cs diff --git a/src/Core/AdminConsole/Models/Data/IEvent.cs b/src/Core/Dirt/Models/Data/IEvent.cs similarity index 100% rename from src/Core/AdminConsole/Models/Data/IEvent.cs rename to src/Core/Dirt/Models/Data/IEvent.cs diff --git a/src/Core/Dirt/Models/Data/OrganizationReportMetricsData.cs b/src/Core/Dirt/Models/Data/OrganizationReportMetricsData.cs new file mode 100644 index 0000000000..ffef91275a --- /dev/null +++ b/src/Core/Dirt/Models/Data/OrganizationReportMetricsData.cs @@ -0,0 +1,48 @@ +using Bit.Core.Dirt.Reports.ReportFeatures.Requests; + +namespace Bit.Core.Dirt.Reports.Models.Data; + +public class OrganizationReportMetricsData +{ + public Guid OrganizationId { get; set; } + public int? ApplicationCount { get; set; } + public int? ApplicationAtRiskCount { get; set; } + public int? CriticalApplicationCount { get; set; } + public int? CriticalApplicationAtRiskCount { get; set; } + public int? MemberCount { get; set; } + public int? MemberAtRiskCount { get; set; } + public int? CriticalMemberCount { get; set; } + public int? CriticalMemberAtRiskCount { get; set; } + public int? PasswordCount { get; set; } + public int? PasswordAtRiskCount { get; set; } + public int? CriticalPasswordCount { get; set; } + public int? CriticalPasswordAtRiskCount { get; set; } + + public static OrganizationReportMetricsData From(Guid organizationId, OrganizationReportMetricsRequest? request) + { + if (request == null) + { + return new OrganizationReportMetricsData + { + OrganizationId = organizationId + }; + } + + return new OrganizationReportMetricsData + { + OrganizationId = organizationId, + ApplicationCount = request.ApplicationCount, + ApplicationAtRiskCount = request.ApplicationAtRiskCount, + CriticalApplicationCount = request.CriticalApplicationCount, + CriticalApplicationAtRiskCount = request.CriticalApplicationAtRiskCount, + MemberCount = request.MemberCount, + MemberAtRiskCount = request.MemberAtRiskCount, + CriticalMemberCount = request.CriticalMemberCount, + CriticalMemberAtRiskCount = request.CriticalMemberAtRiskCount, + PasswordCount = request.PasswordCount, + PasswordAtRiskCount = request.PasswordAtRiskCount, + CriticalPasswordCount = request.CriticalPasswordCount, + CriticalPasswordAtRiskCount = request.CriticalPasswordAtRiskCount + }; + } +} diff --git a/src/Core/AdminConsole/Models/Slack/SlackApiResponse.cs b/src/Core/Dirt/Models/Data/Slack/SlackApiResponse.cs similarity index 84% rename from src/Core/AdminConsole/Models/Slack/SlackApiResponse.cs rename to src/Core/Dirt/Models/Data/Slack/SlackApiResponse.cs index ede2123f7e..a70e623ae3 100644 --- a/src/Core/AdminConsole/Models/Slack/SlackApiResponse.cs +++ b/src/Core/Dirt/Models/Data/Slack/SlackApiResponse.cs @@ -1,8 +1,6 @@ -#nullable enable +using System.Text.Json.Serialization; -using System.Text.Json.Serialization; - -namespace Bit.Core.Models.Slack; +namespace Bit.Core.Dirt.Models.Data.Slack; public abstract class SlackApiResponse { @@ -35,6 +33,12 @@ public class SlackOAuthResponse : SlackApiResponse public SlackTeam Team { get; set; } = new(); } +public class SlackSendMessageResponse : SlackApiResponse +{ + [JsonPropertyName("channel")] + public string Channel { get; set; } = string.Empty; +} + public class SlackTeam { public string Id { get; set; } = string.Empty; diff --git a/src/Core/Dirt/Models/Data/Teams/TeamsApiResponse.cs b/src/Core/Dirt/Models/Data/Teams/TeamsApiResponse.cs new file mode 100644 index 0000000000..b4b6a2542d --- /dev/null +++ b/src/Core/Dirt/Models/Data/Teams/TeamsApiResponse.cs @@ -0,0 +1,41 @@ +using System.Text.Json.Serialization; + +namespace Bit.Core.Dirt.Models.Data.Teams; + +/// Represents the response returned by the Microsoft OAuth 2.0 token endpoint. +/// See Microsoft identity platform and OAuth 2.0 +/// authorization code flow. +public class TeamsOAuthResponse +{ + /// The access token issued by Microsoft, used to call the Microsoft Graph API. + [JsonPropertyName("access_token")] + public string AccessToken { get; set; } = string.Empty; +} + +/// Represents the response from the /me/joinedTeams Microsoft Graph API call. +/// See List joined teams - +/// Microsoft Graph v1.0. +public class JoinedTeamsResponse +{ + /// The collection of teams that the user has joined. + [JsonPropertyName("value")] + public List Value { get; set; } = []; +} + +/// Represents a Microsoft Teams team returned by the Graph API. +/// See Team resource type - +/// Microsoft Graph v1.0. +public class TeamInfo +{ + /// The unique identifier of the team. + [JsonPropertyName("id")] + public string Id { get; set; } = string.Empty; + + /// The name of the team. + [JsonPropertyName("displayName")] + public string DisplayName { get; set; } = string.Empty; + + /// The ID of the Microsoft Entra tenant for this team. + [JsonPropertyName("tenantId")] + public string TenantId { get; set; } = string.Empty; +} diff --git a/src/Core/Dirt/Models/Data/Teams/TeamsBotCredentialProvider.cs b/src/Core/Dirt/Models/Data/Teams/TeamsBotCredentialProvider.cs new file mode 100644 index 0000000000..d8740f9e90 --- /dev/null +++ b/src/Core/Dirt/Models/Data/Teams/TeamsBotCredentialProvider.cs @@ -0,0 +1,28 @@ +using Microsoft.Bot.Connector.Authentication; + +namespace Bit.Core.Dirt.Models.Data.Teams; + +public class TeamsBotCredentialProvider(string clientId, string clientSecret) : ICredentialProvider +{ + private const string _microsoftBotFrameworkIssuer = AuthenticationConstants.ToBotFromChannelTokenIssuer; + + public Task IsValidAppIdAsync(string appId) + { + return Task.FromResult(appId == clientId); + } + + public Task GetAppPasswordAsync(string appId) + { + return Task.FromResult(appId == clientId ? clientSecret : null); + } + + public Task IsAuthenticationDisabledAsync() + { + return Task.FromResult(false); + } + + public Task ValidateIssuerAsync(string issuer) + { + return Task.FromResult(issuer == _microsoftBotFrameworkIssuer); + } +} diff --git a/src/Core/Dirt/Reports/ReportFeatures/AddOrganizationReportCommand.cs b/src/Core/Dirt/Reports/ReportFeatures/AddOrganizationReportCommand.cs index f0477806d8..236560487e 100644 --- a/src/Core/Dirt/Reports/ReportFeatures/AddOrganizationReportCommand.cs +++ b/src/Core/Dirt/Reports/ReportFeatures/AddOrganizationReportCommand.cs @@ -35,14 +35,28 @@ public class AddOrganizationReportCommand : IAddOrganizationReportCommand throw new BadRequestException(errorMessage); } + var requestMetrics = request.Metrics ?? new OrganizationReportMetricsRequest(); + var organizationReport = new OrganizationReport { OrganizationId = request.OrganizationId, - ReportData = request.ReportData, + ReportData = request.ReportData ?? string.Empty, CreationDate = DateTime.UtcNow, - ContentEncryptionKey = request.ContentEncryptionKey, + ContentEncryptionKey = request.ContentEncryptionKey ?? string.Empty, SummaryData = request.SummaryData, ApplicationData = request.ApplicationData, + ApplicationCount = requestMetrics.ApplicationCount, + ApplicationAtRiskCount = requestMetrics.ApplicationAtRiskCount, + CriticalApplicationCount = requestMetrics.CriticalApplicationCount, + CriticalApplicationAtRiskCount = requestMetrics.CriticalApplicationAtRiskCount, + MemberCount = requestMetrics.MemberCount, + MemberAtRiskCount = requestMetrics.MemberAtRiskCount, + CriticalMemberCount = requestMetrics.CriticalMemberCount, + CriticalMemberAtRiskCount = requestMetrics.CriticalMemberAtRiskCount, + PasswordCount = requestMetrics.PasswordCount, + PasswordAtRiskCount = requestMetrics.PasswordAtRiskCount, + CriticalPasswordCount = requestMetrics.CriticalPasswordCount, + CriticalPasswordAtRiskCount = requestMetrics.CriticalPasswordAtRiskCount, RevisionDate = DateTime.UtcNow }; diff --git a/src/Core/Dirt/Reports/ReportFeatures/Requests/AddOrganizationReportRequest.cs b/src/Core/Dirt/Reports/ReportFeatures/Requests/AddOrganizationReportRequest.cs index 2a8c0203f9..eecc84c522 100644 --- a/src/Core/Dirt/Reports/ReportFeatures/Requests/AddOrganizationReportRequest.cs +++ b/src/Core/Dirt/Reports/ReportFeatures/Requests/AddOrganizationReportRequest.cs @@ -1,16 +1,15 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -namespace Bit.Core.Dirt.Reports.ReportFeatures.Requests; +namespace Bit.Core.Dirt.Reports.ReportFeatures.Requests; public class AddOrganizationReportRequest { public Guid OrganizationId { get; set; } - public string ReportData { get; set; } + public string? ReportData { get; set; } - public string ContentEncryptionKey { get; set; } + public string? ContentEncryptionKey { get; set; } - public string SummaryData { get; set; } + public string? SummaryData { get; set; } - public string ApplicationData { get; set; } + public string? ApplicationData { get; set; } + + public OrganizationReportMetricsRequest? Metrics { get; set; } } diff --git a/src/Core/Dirt/Reports/ReportFeatures/Requests/OrganizationReportMetricsRequest.cs b/src/Core/Dirt/Reports/ReportFeatures/Requests/OrganizationReportMetricsRequest.cs new file mode 100644 index 0000000000..9403a5f1c2 --- /dev/null +++ b/src/Core/Dirt/Reports/ReportFeatures/Requests/OrganizationReportMetricsRequest.cs @@ -0,0 +1,31 @@ +using System.Text.Json.Serialization; + +namespace Bit.Core.Dirt.Reports.ReportFeatures.Requests; + +public class OrganizationReportMetricsRequest +{ + [JsonPropertyName("totalApplicationCount")] + public int? ApplicationCount { get; set; } = null; + [JsonPropertyName("totalAtRiskApplicationCount")] + public int? ApplicationAtRiskCount { get; set; } = null; + [JsonPropertyName("totalCriticalApplicationCount")] + public int? CriticalApplicationCount { get; set; } = null; + [JsonPropertyName("totalCriticalAtRiskApplicationCount")] + public int? CriticalApplicationAtRiskCount { get; set; } = null; + [JsonPropertyName("totalMemberCount")] + public int? MemberCount { get; set; } = null; + [JsonPropertyName("totalAtRiskMemberCount")] + public int? MemberAtRiskCount { get; set; } = null; + [JsonPropertyName("totalCriticalMemberCount")] + public int? CriticalMemberCount { get; set; } = null; + [JsonPropertyName("totalCriticalAtRiskMemberCount")] + public int? CriticalMemberAtRiskCount { get; set; } = null; + [JsonPropertyName("totalPasswordCount")] + public int? PasswordCount { get; set; } = null; + [JsonPropertyName("totalAtRiskPasswordCount")] + public int? PasswordAtRiskCount { get; set; } = null; + [JsonPropertyName("totalCriticalPasswordCount")] + public int? CriticalPasswordCount { get; set; } = null; + [JsonPropertyName("totalCriticalAtRiskPasswordCount")] + public int? CriticalPasswordAtRiskCount { get; set; } = null; +} diff --git a/src/Core/Dirt/Reports/ReportFeatures/Requests/UpdateOrganizationReportApplicationDataRequest.cs b/src/Core/Dirt/Reports/ReportFeatures/Requests/UpdateOrganizationReportApplicationDataRequest.cs index ab4fcc5921..e549a3f120 100644 --- a/src/Core/Dirt/Reports/ReportFeatures/Requests/UpdateOrganizationReportApplicationDataRequest.cs +++ b/src/Core/Dirt/Reports/ReportFeatures/Requests/UpdateOrganizationReportApplicationDataRequest.cs @@ -1,11 +1,8 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -namespace Bit.Core.Dirt.Reports.ReportFeatures.Requests; +namespace Bit.Core.Dirt.Reports.ReportFeatures.Requests; public class UpdateOrganizationReportApplicationDataRequest { public Guid Id { get; set; } public Guid OrganizationId { get; set; } - public string ApplicationData { get; set; } + public string? ApplicationData { get; set; } } diff --git a/src/Core/Dirt/Reports/ReportFeatures/Requests/UpdateOrganizationReportSummaryRequest.cs b/src/Core/Dirt/Reports/ReportFeatures/Requests/UpdateOrganizationReportSummaryRequest.cs index b0e555fcef..27358537c2 100644 --- a/src/Core/Dirt/Reports/ReportFeatures/Requests/UpdateOrganizationReportSummaryRequest.cs +++ b/src/Core/Dirt/Reports/ReportFeatures/Requests/UpdateOrganizationReportSummaryRequest.cs @@ -1,11 +1,9 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -namespace Bit.Core.Dirt.Reports.ReportFeatures.Requests; +namespace Bit.Core.Dirt.Reports.ReportFeatures.Requests; public class UpdateOrganizationReportSummaryRequest { public Guid OrganizationId { get; set; } public Guid ReportId { get; set; } - public string SummaryData { get; set; } + public string? SummaryData { get; set; } + public OrganizationReportMetricsRequest? Metrics { get; set; } } diff --git a/src/Core/Dirt/Reports/ReportFeatures/UpdateOrganizationReportApplicationDataCommand.cs b/src/Core/Dirt/Reports/ReportFeatures/UpdateOrganizationReportApplicationDataCommand.cs index 67ec49d004..375b766a0e 100644 --- a/src/Core/Dirt/Reports/ReportFeatures/UpdateOrganizationReportApplicationDataCommand.cs +++ b/src/Core/Dirt/Reports/ReportFeatures/UpdateOrganizationReportApplicationDataCommand.cs @@ -53,7 +53,7 @@ public class UpdateOrganizationReportApplicationDataCommand : IUpdateOrganizatio throw new BadRequestException("Organization report does not belong to the specified organization"); } - var updatedReport = await _organizationReportRepo.UpdateApplicationDataAsync(request.OrganizationId, request.Id, request.ApplicationData); + var updatedReport = await _organizationReportRepo.UpdateApplicationDataAsync(request.OrganizationId, request.Id, request.ApplicationData ?? string.Empty); _logger.LogInformation(Constants.BypassFiltersEventId, "Successfully updated organization report application data {reportId} for organization {organizationId}", request.Id, request.OrganizationId); diff --git a/src/Core/Dirt/Reports/ReportFeatures/UpdateOrganizationReportSummaryCommand.cs b/src/Core/Dirt/Reports/ReportFeatures/UpdateOrganizationReportSummaryCommand.cs index 6859814d65..5d0f2670ca 100644 --- a/src/Core/Dirt/Reports/ReportFeatures/UpdateOrganizationReportSummaryCommand.cs +++ b/src/Core/Dirt/Reports/ReportFeatures/UpdateOrganizationReportSummaryCommand.cs @@ -1,4 +1,5 @@ using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Reports.Models.Data; using Bit.Core.Dirt.Reports.ReportFeatures.Interfaces; using Bit.Core.Dirt.Reports.ReportFeatures.Requests; using Bit.Core.Dirt.Repositories; @@ -53,7 +54,8 @@ public class UpdateOrganizationReportSummaryCommand : IUpdateOrganizationReportS throw new BadRequestException("Organization report does not belong to the specified organization"); } - var updatedReport = await _organizationReportRepo.UpdateSummaryDataAsync(request.OrganizationId, request.ReportId, request.SummaryData); + await _organizationReportRepo.UpdateMetricsAsync(request.ReportId, OrganizationReportMetricsData.From(request.OrganizationId, request.Metrics)); + var updatedReport = await _organizationReportRepo.UpdateSummaryDataAsync(request.OrganizationId, request.ReportId, request.SummaryData ?? string.Empty); _logger.LogInformation(Constants.BypassFiltersEventId, "Successfully updated organization report summary {reportId} for organization {organizationId}", request.ReportId, request.OrganizationId); diff --git a/src/Core/AdminConsole/Repositories/IEventRepository.cs b/src/Core/Dirt/Repositories/IEventRepository.cs similarity index 100% rename from src/Core/AdminConsole/Repositories/IEventRepository.cs rename to src/Core/Dirt/Repositories/IEventRepository.cs diff --git a/src/Core/Dirt/Repositories/IOrganizationIntegrationConfigurationRepository.cs b/src/Core/Dirt/Repositories/IOrganizationIntegrationConfigurationRepository.cs new file mode 100644 index 0000000000..f6f90c7c9f --- /dev/null +++ b/src/Core/Dirt/Repositories/IOrganizationIntegrationConfigurationRepository.cs @@ -0,0 +1,32 @@ +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Enums; +using Bit.Core.Repositories; + +namespace Bit.Core.Dirt.Repositories; + +public interface IOrganizationIntegrationConfigurationRepository : IRepository +{ + /// + /// Retrieve the list of available configuration details for a specific event for the organization and + /// integration type.
    + ///
    + /// Note: This returns all configurations that match the event type explicitly and + /// all the configurations that have a null event type - null event type is considered a + /// wildcard that matches all events. + /// + ///
    + /// The specific event type + /// The id of the organization + /// The integration type + /// A List of that match + Task> GetManyByEventTypeOrganizationIdIntegrationType( + EventType eventType, + Guid organizationId, + IntegrationType integrationType); + + Task> GetAllConfigurationDetailsAsync(); + + Task> GetManyByIntegrationAsync(Guid organizationIntegrationId); +} diff --git a/src/Core/Dirt/Repositories/IOrganizationIntegrationRepository.cs b/src/Core/Dirt/Repositories/IOrganizationIntegrationRepository.cs new file mode 100644 index 0000000000..03775e8d20 --- /dev/null +++ b/src/Core/Dirt/Repositories/IOrganizationIntegrationRepository.cs @@ -0,0 +1,11 @@ +using Bit.Core.Dirt.Entities; +using Bit.Core.Repositories; + +namespace Bit.Core.Dirt.Repositories; + +public interface IOrganizationIntegrationRepository : IRepository +{ + Task> GetManyByOrganizationAsync(Guid organizationId); + + Task GetByTeamsConfigurationTenantIdTeamId(string tenantId, string teamId); +} diff --git a/src/Core/Dirt/Repositories/IOrganizationReportRepository.cs b/src/Core/Dirt/Repositories/IOrganizationReportRepository.cs index 9687173716..b4c2f90566 100644 --- a/src/Core/Dirt/Repositories/IOrganizationReportRepository.cs +++ b/src/Core/Dirt/Repositories/IOrganizationReportRepository.cs @@ -1,5 +1,6 @@ using Bit.Core.Dirt.Entities; using Bit.Core.Dirt.Models.Data; +using Bit.Core.Dirt.Reports.Models.Data; using Bit.Core.Repositories; namespace Bit.Core.Dirt.Repositories; @@ -21,5 +22,8 @@ public interface IOrganizationReportRepository : IRepository GetApplicationDataAsync(Guid reportId); Task UpdateApplicationDataAsync(Guid orgId, Guid reportId, string applicationData); + + // Metrics methods + Task UpdateMetricsAsync(Guid reportId, OrganizationReportMetricsData metrics); } diff --git a/src/Core/AdminConsole/Repositories/TableStorage/EventRepository.cs b/src/Core/Dirt/Repositories/TableStorage/EventRepository.cs similarity index 100% rename from src/Core/AdminConsole/Repositories/TableStorage/EventRepository.cs rename to src/Core/Dirt/Repositories/TableStorage/EventRepository.cs diff --git a/src/Core/AdminConsole/Services/IAzureServiceBusService.cs b/src/Core/Dirt/Services/IAzureServiceBusService.cs similarity index 77% rename from src/Core/AdminConsole/Services/IAzureServiceBusService.cs rename to src/Core/Dirt/Services/IAzureServiceBusService.cs index 75864255c2..6b425511ab 100644 --- a/src/Core/AdminConsole/Services/IAzureServiceBusService.cs +++ b/src/Core/Dirt/Services/IAzureServiceBusService.cs @@ -1,7 +1,7 @@ using Azure.Messaging.ServiceBus; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Models.Data.EventIntegrations; -namespace Bit.Core.Services; +namespace Bit.Core.Dirt.Services; public interface IAzureServiceBusService : IEventIntegrationPublisher, IAsyncDisposable { diff --git a/src/Core/Dirt/Services/IEventIntegrationPublisher.cs b/src/Core/Dirt/Services/IEventIntegrationPublisher.cs new file mode 100644 index 0000000000..583c2448fe --- /dev/null +++ b/src/Core/Dirt/Services/IEventIntegrationPublisher.cs @@ -0,0 +1,9 @@ +using Bit.Core.Dirt.Models.Data.EventIntegrations; + +namespace Bit.Core.Dirt.Services; + +public interface IEventIntegrationPublisher : IAsyncDisposable +{ + Task PublishAsync(IIntegrationMessage message); + Task PublishEventAsync(string body, string? organizationId); +} diff --git a/src/Core/AdminConsole/Services/IEventMessageHandler.cs b/src/Core/Dirt/Services/IEventMessageHandler.cs similarity index 85% rename from src/Core/AdminConsole/Services/IEventMessageHandler.cs rename to src/Core/Dirt/Services/IEventMessageHandler.cs index 83c5e33ecb..9b1385129b 100644 --- a/src/Core/AdminConsole/Services/IEventMessageHandler.cs +++ b/src/Core/Dirt/Services/IEventMessageHandler.cs @@ -1,6 +1,6 @@ using Bit.Core.Models.Data; -namespace Bit.Core.Services; +namespace Bit.Core.Dirt.Services; public interface IEventMessageHandler { diff --git a/src/Core/AdminConsole/Services/IEventWriteService.cs b/src/Core/Dirt/Services/IEventWriteService.cs similarity index 100% rename from src/Core/AdminConsole/Services/IEventWriteService.cs rename to src/Core/Dirt/Services/IEventWriteService.cs diff --git a/src/Core/AdminConsole/Services/IIntegrationFilterService.cs b/src/Core/Dirt/Services/IIntegrationFilterService.cs similarity index 67% rename from src/Core/AdminConsole/Services/IIntegrationFilterService.cs rename to src/Core/Dirt/Services/IIntegrationFilterService.cs index 5bc035d468..f46ab83f54 100644 --- a/src/Core/AdminConsole/Services/IIntegrationFilterService.cs +++ b/src/Core/Dirt/Services/IIntegrationFilterService.cs @@ -1,9 +1,9 @@ #nullable enable -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Models.Data.EventIntegrations; using Bit.Core.Models.Data; -namespace Bit.Core.Services; +namespace Bit.Core.Dirt.Services; public interface IIntegrationFilterService { diff --git a/src/Core/Dirt/Services/IIntegrationHandler.cs b/src/Core/Dirt/Services/IIntegrationHandler.cs new file mode 100644 index 0000000000..81103b453d --- /dev/null +++ b/src/Core/Dirt/Services/IIntegrationHandler.cs @@ -0,0 +1,115 @@ +using System.Globalization; +using System.Net; +using Bit.Core.Dirt.Models.Data.EventIntegrations; + +namespace Bit.Core.Dirt.Services; + +public interface IIntegrationHandler +{ + Task HandleAsync(string json); +} + +public interface IIntegrationHandler : IIntegrationHandler +{ + Task HandleAsync(IntegrationMessage message); +} + +public abstract class IntegrationHandlerBase : IIntegrationHandler +{ + public async Task HandleAsync(string json) + { + var message = IntegrationMessage.FromJson(json); + return await HandleAsync(message ?? throw new ArgumentException("IntegrationMessage was null when created from the provided JSON")); + } + + public abstract Task HandleAsync(IntegrationMessage message); + + protected IntegrationHandlerResult ResultFromHttpResponse( + HttpResponseMessage response, + IntegrationMessage message, + TimeProvider timeProvider) + { + if (response.IsSuccessStatusCode) + { + return IntegrationHandlerResult.Succeed(message); + } + + var category = ClassifyHttpStatusCode(response.StatusCode); + var failureReason = response.ReasonPhrase ?? $"Failure with status code {(int)response.StatusCode}"; + + if (category is not (IntegrationFailureCategory.RateLimited + or IntegrationFailureCategory.TransientError + or IntegrationFailureCategory.ServiceUnavailable) || + !response.Headers.TryGetValues("Retry-After", out var values) + ) + { + return IntegrationHandlerResult.Fail(message: message, category: category, failureReason: failureReason); + } + + // Handle Retry-After header for rate-limited and retryable errors + DateTime? delayUntil = null; + var value = values.FirstOrDefault(); + if (int.TryParse(value, out var seconds)) + { + // Retry-after was specified in seconds + delayUntil = timeProvider.GetUtcNow().AddSeconds(seconds).UtcDateTime; + } + else if (DateTimeOffset.TryParseExact(value, + "r", // "r" is the round-trip format: RFC1123 + CultureInfo.InvariantCulture, + DateTimeStyles.AssumeUniversal | DateTimeStyles.AdjustToUniversal, + out var retryDate)) + { + // Retry-after was specified as a date + delayUntil = retryDate.UtcDateTime; + } + + return IntegrationHandlerResult.Fail( + message, + category, + failureReason, + delayUntil + ); + } + + /// + /// Classifies an as an to drive + /// retry behavior and operator-facing failure reporting. + /// + /// The HTTP status code. + /// The corresponding . + protected static IntegrationFailureCategory ClassifyHttpStatusCode(HttpStatusCode statusCode) + { + var explicitCategory = statusCode switch + { + HttpStatusCode.Unauthorized => IntegrationFailureCategory.AuthenticationFailed, + HttpStatusCode.Forbidden => IntegrationFailureCategory.AuthenticationFailed, + HttpStatusCode.NotFound => IntegrationFailureCategory.ConfigurationError, + HttpStatusCode.Gone => IntegrationFailureCategory.ConfigurationError, + HttpStatusCode.MovedPermanently => IntegrationFailureCategory.ConfigurationError, + HttpStatusCode.TemporaryRedirect => IntegrationFailureCategory.ConfigurationError, + HttpStatusCode.PermanentRedirect => IntegrationFailureCategory.ConfigurationError, + HttpStatusCode.TooManyRequests => IntegrationFailureCategory.RateLimited, + HttpStatusCode.RequestTimeout => IntegrationFailureCategory.TransientError, + HttpStatusCode.InternalServerError => IntegrationFailureCategory.TransientError, + HttpStatusCode.BadGateway => IntegrationFailureCategory.TransientError, + HttpStatusCode.GatewayTimeout => IntegrationFailureCategory.TransientError, + HttpStatusCode.ServiceUnavailable => IntegrationFailureCategory.ServiceUnavailable, + HttpStatusCode.NotImplemented => IntegrationFailureCategory.PermanentFailure, + _ => (IntegrationFailureCategory?)null + }; + + if (explicitCategory is not null) + { + return explicitCategory.Value; + } + + return (int)statusCode switch + { + >= 300 and <= 399 => IntegrationFailureCategory.ConfigurationError, + >= 400 and <= 499 => IntegrationFailureCategory.ConfigurationError, + >= 500 and <= 599 => IntegrationFailureCategory.ServiceUnavailable, + _ => IntegrationFailureCategory.ServiceUnavailable + }; + } +} diff --git a/src/Core/Dirt/Services/IOrganizationIntegrationConfigurationValidator.cs b/src/Core/Dirt/Services/IOrganizationIntegrationConfigurationValidator.cs new file mode 100644 index 0000000000..4a3a089f26 --- /dev/null +++ b/src/Core/Dirt/Services/IOrganizationIntegrationConfigurationValidator.cs @@ -0,0 +1,17 @@ +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Enums; + +namespace Bit.Core.Dirt.Services; + +public interface IOrganizationIntegrationConfigurationValidator +{ + /// + /// Validates that the configuration is valid for the given integration type. The configuration must + /// include a Configuration that is valid for the type, valid Filters, and a non-empty Template + /// to pass validation. + /// + /// The type of integration + /// The OrganizationIntegrationConfiguration to validate + /// True if valid, false otherwise + bool ValidateConfiguration(IntegrationType integrationType, OrganizationIntegrationConfiguration configuration); +} diff --git a/src/Core/AdminConsole/Services/IRabbitMqService.cs b/src/Core/Dirt/Services/IRabbitMqService.cs similarity index 89% rename from src/Core/AdminConsole/Services/IRabbitMqService.cs rename to src/Core/Dirt/Services/IRabbitMqService.cs index 12c40c3b98..b9f824506f 100644 --- a/src/Core/AdminConsole/Services/IRabbitMqService.cs +++ b/src/Core/Dirt/Services/IRabbitMqService.cs @@ -1,8 +1,8 @@ -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Models.Data.EventIntegrations; using RabbitMQ.Client; using RabbitMQ.Client.Events; -namespace Bit.Core.Services; +namespace Bit.Core.Dirt.Services; public interface IRabbitMqService : IEventIntegrationPublisher { diff --git a/src/Core/Dirt/Services/ISlackService.cs b/src/Core/Dirt/Services/ISlackService.cs new file mode 100644 index 0000000000..111fcb5440 --- /dev/null +++ b/src/Core/Dirt/Services/ISlackService.cs @@ -0,0 +1,62 @@ +using Bit.Core.Dirt.Models.Data.Slack; +using Bit.Core.Dirt.Services.Implementations; + +namespace Bit.Core.Dirt.Services; + +/// Defines operations for interacting with Slack, including OAuth authentication, channel discovery, +/// and sending messages. +public interface ISlackService +{ + /// Note: This API is not currently used (yet) by any server code. It is here to provide functionality if + /// the UI needs to be able to look up channels for a user. + /// Retrieves the ID of a Slack channel by name. + /// See conversations.list API. + /// A valid Slack OAuth access token. + /// The name of the channel to look up. + /// The channel ID if found; otherwise, an empty string. + Task GetChannelIdAsync(string token, string channelName); + + /// Note: This API is not currently used (yet) by any server code. It is here to provide functionality if + /// the UI needs to be able to look up channels for a user. + /// Retrieves the IDs of multiple Slack channels by name. + /// See conversations.list API. + /// A valid Slack OAuth access token. + /// A list of channel names to look up. + /// A list of matching channel IDs. Channels that cannot be found are omitted. + Task> GetChannelIdsAsync(string token, List channelNames); + + /// Note: This API is not currently used (yet) by any server code. It is here to provide functionality if + /// the UI needs to be able to look up a user by their email address. + /// Retrieves the DM channel ID for a Slack user by email. + /// See users.lookupByEmail API and + /// conversations.open API. + /// A valid Slack OAuth access token. + /// The email address of the user to open a DM with. + /// The DM channel ID if successful; otherwise, an empty string. + Task GetDmChannelByEmailAsync(string token, string email); + + /// Builds the Slack OAuth 2.0 authorization URL for the app. + /// See Slack OAuth v2 documentation. + /// The absolute redirect URI that Slack will call after user authorization. + /// Must match the URI registered with the app configuration. + /// A state token used to correlate the request and callback and prevent CSRF attacks. + /// The full authorization URL to which the user should be redirected to begin the sign-in process. + string GetRedirectUrl(string callbackUrl, string state); + + /// Exchanges a Slack OAuth code for an access token. + /// See oauth.v2.access API. + /// The authorization code returned by Slack via the callback URL after user authorization. + /// The redirect URI that was used in the authorization request. + /// A valid Slack access token if successful; otherwise, an empty string. + Task ObtainTokenViaOAuth(string code, string redirectUrl); + + /// Sends a message to a Slack channel by ID. + /// See chat.postMessage API. + /// This is used primarily by the to send events to the + /// Slack channel. + /// A valid Slack OAuth access token. + /// The message text to send. + /// The channel ID to send the message to. + /// The response from Slack after sending the message. + Task SendSlackMessageByChannelIdAsync(string token, string message, string channelId); +} diff --git a/src/Core/Dirt/Services/ITeamsService.cs b/src/Core/Dirt/Services/ITeamsService.cs new file mode 100644 index 0000000000..30a324f9a4 --- /dev/null +++ b/src/Core/Dirt/Services/ITeamsService.cs @@ -0,0 +1,50 @@ +using Bit.Core.Dirt.Models.Data.Teams; +using Bit.Core.Dirt.Services.Implementations; + +namespace Bit.Core.Dirt.Services; + +/// +/// Service that provides functionality relating to the Microsoft Teams integration including OAuth, +/// team discovery and sending a message to a channel in Teams. +/// +public interface ITeamsService +{ + /// + /// Generate the Microsoft Teams OAuth 2.0 authorization URL used to begin the sign-in flow. + /// + /// The absolute redirect URI that Microsoft will call after user authorization. + /// Must match the URI registered with the app configuration. + /// A state token used to correlate the request and callback and prevent CSRF attacks. + /// The full authorization URL to which the user should be redirected to begin the sign-in process. + string GetRedirectUrl(string callbackUrl, string state); + + /// + /// Exchange the OAuth code for a Microsoft Graph API access token. + /// + /// The code returned from Microsoft via the OAuth callback Url. + /// The same redirect URI that was passed to the authorization request. + /// A valid Microsoft Graph access token if the exchange succeeds; otherwise, an empty string. + Task ObtainTokenViaOAuth(string code, string redirectUrl); + + /// + /// Get the Teams to which the authenticated user belongs via Microsoft Graph API. + /// + /// A valid Microsoft Graph access token for the user (obtained via OAuth). + /// A read-only list of objects representing the user’s joined teams. + /// Returns an empty list if the request fails or if the token is invalid. + Task> GetJoinedTeamsAsync(string accessToken); + + /// + /// Send a message to a specific channel in Teams. + /// + /// This is used primarily by the to send events to the + /// Teams channel. + /// The service URI associated with the Microsoft Bot Framework connector for the target + /// team. Obtained via the bot framework callback. + /// The conversation or channel ID where the message should be delivered. Obtained via + /// the bot framework callback. + /// The message text to post to the channel. + /// A task that completes when the message has been sent. Errors during message delivery are surfaced + /// as exceptions from the underlying connector client. + Task SendMessageToChannelAsync(Uri serviceUri, string channelId, string message); +} diff --git a/src/Core/AdminConsole/Services/Implementations/AzureQueueEventWriteService.cs b/src/Core/Dirt/Services/Implementations/AzureQueueEventWriteService.cs similarity index 93% rename from src/Core/AdminConsole/Services/Implementations/AzureQueueEventWriteService.cs rename to src/Core/Dirt/Services/Implementations/AzureQueueEventWriteService.cs index f81175f7b5..4f48b64b5a 100644 --- a/src/Core/AdminConsole/Services/Implementations/AzureQueueEventWriteService.cs +++ b/src/Core/Dirt/Services/Implementations/AzureQueueEventWriteService.cs @@ -8,7 +8,7 @@ namespace Bit.Core.Services; public class AzureQueueEventWriteService : AzureQueueService, IEventWriteService { public AzureQueueEventWriteService(GlobalSettings globalSettings) : base( - new QueueClient(globalSettings.Events.ConnectionString, "event"), + new QueueClient(globalSettings.Events.ConnectionString, globalSettings.Events.QueueName), JsonHelpers.IgnoreWritingNull) { } diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusEventListenerService.cs b/src/Core/Dirt/Services/Implementations/AzureServiceBusEventListenerService.cs similarity index 89% rename from src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusEventListenerService.cs rename to src/Core/Dirt/Services/Implementations/AzureServiceBusEventListenerService.cs index a589211687..6175374e2f 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusEventListenerService.cs +++ b/src/Core/Dirt/Services/Implementations/AzureServiceBusEventListenerService.cs @@ -1,9 +1,9 @@ using System.Text; using Azure.Messaging.ServiceBus; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Models.Data.EventIntegrations; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; +namespace Bit.Core.Dirt.Services.Implementations; public class AzureServiceBusEventListenerService : EventLoggingListenerService where TConfiguration : IEventListenerConfiguration @@ -42,7 +42,7 @@ public class AzureServiceBusEventListenerService : EventLoggingL private static ILogger CreateLogger(ILoggerFactory loggerFactory, TConfiguration configuration) { return loggerFactory.CreateLogger( - categoryName: $"Bit.Core.Services.AzureServiceBusEventListenerService.{configuration.EventSubscriptionName}"); + categoryName: $"Bit.Core.Dirt.Services.Implementations.AzureServiceBusEventListenerService.{configuration.EventSubscriptionName}"); } internal Task ProcessErrorAsync(ProcessErrorEventArgs args) diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusIntegrationListenerService.cs b/src/Core/Dirt/Services/Implementations/AzureServiceBusIntegrationListenerService.cs similarity index 80% rename from src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusIntegrationListenerService.cs rename to src/Core/Dirt/Services/Implementations/AzureServiceBusIntegrationListenerService.cs index 633a53296b..32132ddb37 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusIntegrationListenerService.cs +++ b/src/Core/Dirt/Services/Implementations/AzureServiceBusIntegrationListenerService.cs @@ -1,9 +1,9 @@ using Azure.Messaging.ServiceBus; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Models.Data.EventIntegrations; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; +namespace Bit.Core.Dirt.Services.Implementations; public class AzureServiceBusIntegrationListenerService : BackgroundService where TConfiguration : IIntegrationListenerConfiguration @@ -23,7 +23,7 @@ public class AzureServiceBusIntegrationListenerService : Backgro { _handler = handler; _logger = loggerFactory.CreateLogger( - categoryName: $"Bit.Core.Services.AzureServiceBusIntegrationListenerService.{configuration.IntegrationSubscriptionName}"); + categoryName: $"Bit.Core.Dirt.Services.Implementations.AzureServiceBusIntegrationListenerService.{configuration.IntegrationSubscriptionName}"); _maxRetries = configuration.MaxRetries; _serviceBusService = serviceBusService; @@ -85,6 +85,17 @@ public class AzureServiceBusIntegrationListenerService : Backgro { // Non-recoverable failure or exceeded the max number of retries // Return false to indicate this message should be dead-lettered + _logger.LogWarning( + "Integration failure - non-recoverable error or max retries exceeded. " + + "MessageId: {MessageId}, IntegrationType: {IntegrationType}, OrganizationId: {OrgId}, " + + "FailureCategory: {Category}, Reason: {Reason}, RetryCount: {RetryCount}, MaxRetries: {MaxRetries}", + message.MessageId, + message.IntegrationType, + message.OrganizationId, + result.Category, + result.FailureReason, + message.RetryCount, + _maxRetries); return false; } } diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusService.cs b/src/Core/Dirt/Services/Implementations/AzureServiceBusService.cs similarity index 80% rename from src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusService.cs rename to src/Core/Dirt/Services/Implementations/AzureServiceBusService.cs index 4887aa3a7f..7b87850fe3 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusService.cs +++ b/src/Core/Dirt/Services/Implementations/AzureServiceBusService.cs @@ -1,9 +1,9 @@ using Azure.Messaging.ServiceBus; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; -using Bit.Core.Enums; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.Models.Data.EventIntegrations; using Bit.Core.Settings; -namespace Bit.Core.Services; +namespace Bit.Core.Dirt.Services.Implementations; public class AzureServiceBusService : IAzureServiceBusService { @@ -30,7 +30,8 @@ public class AzureServiceBusService : IAzureServiceBusService var serviceBusMessage = new ServiceBusMessage(json) { Subject = message.IntegrationType.ToRoutingKey(), - MessageId = message.MessageId + MessageId = message.MessageId, + PartitionKey = message.OrganizationId }; await _integrationSender.SendMessageAsync(serviceBusMessage); @@ -44,18 +45,20 @@ public class AzureServiceBusService : IAzureServiceBusService { Subject = message.IntegrationType.ToRoutingKey(), ScheduledEnqueueTime = message.DelayUntilDate ?? DateTime.UtcNow, - MessageId = message.MessageId + MessageId = message.MessageId, + PartitionKey = message.OrganizationId }; await _integrationSender.SendMessageAsync(serviceBusMessage); } - public async Task PublishEventAsync(string body) + public async Task PublishEventAsync(string body, string? organizationId) { var message = new ServiceBusMessage(body) { ContentType = "application/json", - MessageId = Guid.NewGuid().ToString() + MessageId = Guid.NewGuid().ToString(), + PartitionKey = organizationId }; await _eventSender.SendMessageAsync(message); diff --git a/src/Core/AdminConsole/Services/Implementations/AzureTableStorageEventHandler.cs b/src/Core/Dirt/Services/Implementations/AzureTableStorageEventHandler.cs similarity index 84% rename from src/Core/AdminConsole/Services/Implementations/AzureTableStorageEventHandler.cs rename to src/Core/Dirt/Services/Implementations/AzureTableStorageEventHandler.cs index 578dde9485..73d22b21a7 100644 --- a/src/Core/AdminConsole/Services/Implementations/AzureTableStorageEventHandler.cs +++ b/src/Core/Dirt/Services/Implementations/AzureTableStorageEventHandler.cs @@ -1,9 +1,8 @@ -#nullable enable - -using Bit.Core.Models.Data; +using Bit.Core.Models.Data; +using Bit.Core.Services; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Core.Services; +namespace Bit.Core.Dirt.Services.Implementations; public class AzureTableStorageEventHandler( [FromKeyedServices("persistent")] IEventWriteService eventWriteService) diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/DatadogIntegrationHandler.cs b/src/Core/Dirt/Services/Implementations/DatadogIntegrationHandler.cs similarity index 90% rename from src/Core/AdminConsole/Services/Implementations/EventIntegrations/DatadogIntegrationHandler.cs rename to src/Core/Dirt/Services/Implementations/DatadogIntegrationHandler.cs index 45bb5b6d7d..e5c684ceec 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/DatadogIntegrationHandler.cs +++ b/src/Core/Dirt/Services/Implementations/DatadogIntegrationHandler.cs @@ -1,7 +1,7 @@ using System.Text; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Models.Data.EventIntegrations; -namespace Bit.Core.Services; +namespace Bit.Core.Dirt.Services.Implementations; public class DatadogIntegrationHandler( IHttpClientFactory httpClientFactory, diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationEventWriteService.cs b/src/Core/Dirt/Services/Implementations/EventIntegrationEventWriteService.cs similarity index 65% rename from src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationEventWriteService.cs rename to src/Core/Dirt/Services/Implementations/EventIntegrationEventWriteService.cs index 309b4a8409..44e0513ee0 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationEventWriteService.cs +++ b/src/Core/Dirt/Services/Implementations/EventIntegrationEventWriteService.cs @@ -1,7 +1,8 @@ using System.Text.Json; using Bit.Core.Models.Data; +using Bit.Core.Services; -namespace Bit.Core.Services; +namespace Bit.Core.Dirt.Services.Implementations; public class EventIntegrationEventWriteService : IEventWriteService, IAsyncDisposable { private readonly IEventIntegrationPublisher _eventIntegrationPublisher; @@ -14,15 +15,21 @@ public class EventIntegrationEventWriteService : IEventWriteService, IAsyncDispo public async Task CreateAsync(IEvent e) { var body = JsonSerializer.Serialize(e); - await _eventIntegrationPublisher.PublishEventAsync(body: body); + await _eventIntegrationPublisher.PublishEventAsync(body: body, organizationId: e.OrganizationId?.ToString()); } public async Task CreateManyAsync(IEnumerable events) { - var body = JsonSerializer.Serialize(events); - await _eventIntegrationPublisher.PublishEventAsync(body: body); - } + var eventList = events as IList ?? events.ToList(); + if (eventList.Count == 0) + { + return; + } + var organizationId = eventList[0].OrganizationId?.ToString(); + var body = JsonSerializer.Serialize(eventList); + await _eventIntegrationPublisher.PublishEventAsync(body: body, organizationId: organizationId); + } public async ValueTask DisposeAsync() { await _eventIntegrationPublisher.DisposeAsync(); diff --git a/src/Core/Dirt/Services/Implementations/EventIntegrationHandler.cs b/src/Core/Dirt/Services/Implementations/EventIntegrationHandler.cs new file mode 100644 index 0000000000..bcd1f1dd8c --- /dev/null +++ b/src/Core/Dirt/Services/Implementations/EventIntegrationHandler.cs @@ -0,0 +1,165 @@ +using System.Text.Json; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Utilities; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Repositories; +using Bit.Core.Models.Data; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; +using Bit.Core.Repositories; +using Bit.Core.Utilities; +using Microsoft.Extensions.Logging; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Dirt.Services.Implementations; + +public class EventIntegrationHandler( + IntegrationType integrationType, + IEventIntegrationPublisher eventIntegrationPublisher, + IIntegrationFilterService integrationFilterService, + IFusionCache cache, + IOrganizationIntegrationConfigurationRepository configurationRepository, + IGroupRepository groupRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + ILogger> logger) + : IEventMessageHandler +{ + public async Task HandleEventAsync(EventMessage eventMessage) + { + foreach (var configuration in await GetConfigurationDetailsListAsync(eventMessage)) + { + try + { + if (configuration.Filters is string filterJson) + { + // Evaluate filters - if false, then discard and do not process + var filters = JsonSerializer.Deserialize(filterJson) + ?? throw new InvalidOperationException($"Failed to deserialize Filters to FilterGroup"); + if (!integrationFilterService.EvaluateFilterGroup(filters, eventMessage)) + { + continue; + } + } + + // Valid filter - assemble message and publish to Integration topic/exchange + var template = configuration.Template ?? string.Empty; + var context = await BuildContextAsync(eventMessage, template); + var renderedTemplate = IntegrationTemplateProcessor.ReplaceTokens(template, context); + var messageId = eventMessage.IdempotencyId ?? Guid.NewGuid(); + var config = configuration.MergedConfiguration.Deserialize() + ?? throw new InvalidOperationException($"Failed to deserialize to {typeof(T).Name} - bad Configuration"); + + var message = new IntegrationMessage + { + IntegrationType = integrationType, + MessageId = messageId.ToString(), + OrganizationId = eventMessage.OrganizationId?.ToString(), + Configuration = config, + RenderedTemplate = renderedTemplate, + RetryCount = 0, + DelayUntilDate = null + }; + + await eventIntegrationPublisher.PublishAsync(message); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to publish Integration Message for {Type}, check Id {RecordId} for error in Configuration or Filters", + typeof(T).Name, + configuration.Id); + } + } + } + + public async Task HandleManyEventsAsync(IEnumerable eventMessages) + { + foreach (var eventMessage in eventMessages) + { + await HandleEventAsync(eventMessage); + } + } + + internal async Task BuildContextAsync(EventMessage eventMessage, string template) + { + // Note: All of these cache calls use the default options, including TTL of 30 minutes + + var context = new IntegrationTemplateContext(eventMessage); + + if (IntegrationTemplateProcessor.TemplateRequiresGroup(template) && eventMessage.GroupId.HasValue) + { + context.Group = await cache.GetOrSetAsync( + key: EventIntegrationsCacheConstants.BuildCacheKeyForGroup(eventMessage.GroupId.Value), + factory: async _ => await groupRepository.GetByIdAsync(eventMessage.GroupId.Value) + ); + } + + if (eventMessage.OrganizationId is not Guid organizationId) + { + return context; + } + + if (IntegrationTemplateProcessor.TemplateRequiresUser(template) && eventMessage.UserId.HasValue) + { + context.User = await GetUserFromCacheAsync(organizationId, eventMessage.UserId.Value); + } + + if (IntegrationTemplateProcessor.TemplateRequiresActingUser(template) && eventMessage.ActingUserId.HasValue) + { + context.ActingUser = await GetUserFromCacheAsync(organizationId, eventMessage.ActingUserId.Value); + } + + if (IntegrationTemplateProcessor.TemplateRequiresOrganization(template)) + { + context.Organization = await cache.GetOrSetAsync( + key: EventIntegrationsCacheConstants.BuildCacheKeyForOrganization(organizationId), + factory: async _ => await organizationRepository.GetByIdAsync(organizationId) + ); + } + + return context; + } + + private async Task> GetConfigurationDetailsListAsync(EventMessage eventMessage) + { + if (eventMessage.OrganizationId is not Guid organizationId) + { + return []; + } + + List configurations = []; + + var integrationTag = EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId, + integrationType + ); + + configurations.AddRange(await cache.GetOrSetAsync>( + key: EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId: organizationId, + integrationType: integrationType, + eventType: eventMessage.Type), + factory: async _ => await configurationRepository.GetManyByEventTypeOrganizationIdIntegrationType( + eventType: eventMessage.Type, + organizationId: organizationId, + integrationType: integrationType), + options: new FusionCacheEntryOptions( + duration: EventIntegrationsCacheConstants.DurationForOrganizationIntegrationConfigurationDetails), + tags: [integrationTag] + )); + + return configurations; + } + + private async Task GetUserFromCacheAsync(Guid organizationId, Guid userId) => + await cache.GetOrSetAsync( + key: EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationUser(organizationId, userId), + factory: async _ => await organizationUserRepository.GetDetailsByOrganizationIdUserIdAsync( + organizationId: organizationId, + userId: userId + ) + ); +} diff --git a/src/Core/AdminConsole/Services/EventLoggingListenerService.cs b/src/Core/Dirt/Services/Implementations/EventLoggingListenerService.cs similarity index 97% rename from src/Core/AdminConsole/Services/EventLoggingListenerService.cs rename to src/Core/Dirt/Services/Implementations/EventLoggingListenerService.cs index 84a862ce94..29e3f8dec3 100644 --- a/src/Core/AdminConsole/Services/EventLoggingListenerService.cs +++ b/src/Core/Dirt/Services/Implementations/EventLoggingListenerService.cs @@ -1,11 +1,9 @@ -#nullable enable - -using System.Text.Json; +using System.Text.Json; using Bit.Core.Models.Data; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; +namespace Bit.Core.Dirt.Services.Implementations; public abstract class EventLoggingListenerService : BackgroundService { diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventRepositoryHandler.cs b/src/Core/Dirt/Services/Implementations/EventRepositoryHandler.cs similarity index 87% rename from src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventRepositoryHandler.cs rename to src/Core/Dirt/Services/Implementations/EventRepositoryHandler.cs index ee3a2d5db2..32173b8da0 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventRepositoryHandler.cs +++ b/src/Core/Dirt/Services/Implementations/EventRepositoryHandler.cs @@ -1,7 +1,8 @@ using Bit.Core.Models.Data; +using Bit.Core.Services; using Microsoft.Extensions.DependencyInjection; -namespace Bit.Core.Services; +namespace Bit.Core.Dirt.Services.Implementations; public class EventRepositoryHandler( [FromKeyedServices("persistent")] IEventWriteService eventWriteService) diff --git a/src/Core/AdminConsole/Services/Implementations/EventService.cs b/src/Core/Dirt/Services/Implementations/EventService.cs similarity index 100% rename from src/Core/AdminConsole/Services/Implementations/EventService.cs rename to src/Core/Dirt/Services/Implementations/EventService.cs diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/IntegrationFilterFactory.cs b/src/Core/Dirt/Services/Implementations/IntegrationFilterFactory.cs similarity index 97% rename from src/Core/AdminConsole/Services/Implementations/EventIntegrations/IntegrationFilterFactory.cs rename to src/Core/Dirt/Services/Implementations/IntegrationFilterFactory.cs index d28ac910b7..8c25c80208 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/IntegrationFilterFactory.cs +++ b/src/Core/Dirt/Services/Implementations/IntegrationFilterFactory.cs @@ -1,7 +1,7 @@ using System.Linq.Expressions; using Bit.Core.Models.Data; -namespace Bit.Core.Services; +namespace Bit.Core.Dirt.Services.Implementations; public delegate bool IntegrationFilter(EventMessage message, object? value); diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/IntegrationFilterService.cs b/src/Core/Dirt/Services/Implementations/IntegrationFilterService.cs similarity index 97% rename from src/Core/AdminConsole/Services/Implementations/EventIntegrations/IntegrationFilterService.cs rename to src/Core/Dirt/Services/Implementations/IntegrationFilterService.cs index 1c8fae4000..7d56b7c7ce 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/IntegrationFilterService.cs +++ b/src/Core/Dirt/Services/Implementations/IntegrationFilterService.cs @@ -1,8 +1,8 @@ using System.Text.Json; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Models.Data.EventIntegrations; using Bit.Core.Models.Data; -namespace Bit.Core.Services; +namespace Bit.Core.Dirt.Services.Implementations; public class IntegrationFilterService : IIntegrationFilterService { diff --git a/src/Core/Dirt/Services/Implementations/OrganizationIntegrationConfigurationValidator.cs b/src/Core/Dirt/Services/Implementations/OrganizationIntegrationConfigurationValidator.cs new file mode 100644 index 0000000000..7b6ab320b8 --- /dev/null +++ b/src/Core/Dirt/Services/Implementations/OrganizationIntegrationConfigurationValidator.cs @@ -0,0 +1,76 @@ +using System.Text.Json; +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.Models.Data.EventIntegrations; + +namespace Bit.Core.Dirt.Services.Implementations; + +public class OrganizationIntegrationConfigurationValidator : IOrganizationIntegrationConfigurationValidator +{ + public bool ValidateConfiguration(IntegrationType integrationType, + OrganizationIntegrationConfiguration configuration) + { + // Validate template is present + if (string.IsNullOrWhiteSpace(configuration.Template)) + { + return false; + } + // If Filters are present, they must be valid + if (!IsFiltersValid(configuration.Filters)) + { + return false; + } + + switch (integrationType) + { + case IntegrationType.CloudBillingSync or IntegrationType.Scim: + return false; + case IntegrationType.Slack: + return IsConfigurationValid(configuration.Configuration); + case IntegrationType.Webhook: + return IsConfigurationValid(configuration.Configuration); + case IntegrationType.Hec: + case IntegrationType.Datadog: + case IntegrationType.Teams: + return configuration.Configuration is null; + default: + return false; + } + } + + private static bool IsConfigurationValid(string? configuration) + { + if (string.IsNullOrWhiteSpace(configuration)) + { + return false; + } + + try + { + var config = JsonSerializer.Deserialize(configuration); + return config is not null; + } + catch + { + return false; + } + } + + private static bool IsFiltersValid(string? filters) + { + if (filters is null) + { + return true; + } + + try + { + var filterGroup = JsonSerializer.Deserialize(filters); + return filterGroup is not null; + } + catch + { + return false; + } + } +} diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/RabbitMqEventListenerService.cs b/src/Core/Dirt/Services/Implementations/RabbitMqEventListenerService.cs similarity index 91% rename from src/Core/AdminConsole/Services/Implementations/EventIntegrations/RabbitMqEventListenerService.cs rename to src/Core/Dirt/Services/Implementations/RabbitMqEventListenerService.cs index 430540a2f7..ca7cd5ef16 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/RabbitMqEventListenerService.cs +++ b/src/Core/Dirt/Services/Implementations/RabbitMqEventListenerService.cs @@ -1,10 +1,10 @@ using System.Text; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Models.Data.EventIntegrations; using Microsoft.Extensions.Logging; using RabbitMQ.Client; using RabbitMQ.Client.Events; -namespace Bit.Core.Services; +namespace Bit.Core.Dirt.Services.Implementations; public class RabbitMqEventListenerService : EventLoggingListenerService where TConfiguration : IEventListenerConfiguration @@ -69,6 +69,6 @@ public class RabbitMqEventListenerService : EventLoggingListener private static ILogger CreateLogger(ILoggerFactory loggerFactory, TConfiguration configuration) { return loggerFactory.CreateLogger( - categoryName: $"Bit.Core.Services.RabbitMqEventListenerService.{configuration.EventQueueName}"); + categoryName: $"Bit.Core.Dirt.Services.Implementations.RabbitMqEventListenerService.{configuration.EventQueueName}"); } } diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/RabbitMqIntegrationListenerService.cs b/src/Core/Dirt/Services/Implementations/RabbitMqIntegrationListenerService.cs similarity index 80% rename from src/Core/AdminConsole/Services/Implementations/EventIntegrations/RabbitMqIntegrationListenerService.cs rename to src/Core/Dirt/Services/Implementations/RabbitMqIntegrationListenerService.cs index b426032c92..eced9131bb 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/RabbitMqIntegrationListenerService.cs +++ b/src/Core/Dirt/Services/Implementations/RabbitMqIntegrationListenerService.cs @@ -1,12 +1,12 @@ using System.Text; using System.Text.Json; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Models.Data.EventIntegrations; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using RabbitMQ.Client; using RabbitMQ.Client.Events; -namespace Bit.Core.Services; +namespace Bit.Core.Dirt.Services.Implementations; public class RabbitMqIntegrationListenerService : BackgroundService where TConfiguration : IIntegrationListenerConfiguration @@ -37,7 +37,7 @@ public class RabbitMqIntegrationListenerService : BackgroundServ _timeProvider = timeProvider; _lazyChannel = new Lazy>(() => _rabbitMqService.CreateChannelAsync()); _logger = loggerFactory.CreateLogger( - categoryName: $"Bit.Core.Services.RabbitMqIntegrationListenerService.{configuration.IntegrationQueueName}"); ; + categoryName: $"Bit.Core.Dirt.Services.Implementations.RabbitMqIntegrationListenerService.{configuration.IntegrationQueueName}"); ; } public override async Task StartAsync(CancellationToken cancellationToken) @@ -106,14 +106,32 @@ public class RabbitMqIntegrationListenerService : BackgroundServ { // Exceeded the max number of retries; fail and send to dead letter queue await _rabbitMqService.PublishToDeadLetterAsync(channel, message, cancellationToken); - _logger.LogWarning("Max retry attempts reached. Sent to DLQ."); + _logger.LogWarning( + "Integration failure - max retries exceeded. " + + "MessageId: {MessageId}, IntegrationType: {IntegrationType}, OrganizationId: {OrgId}, " + + "FailureCategory: {Category}, Reason: {Reason}, RetryCount: {RetryCount}, MaxRetries: {MaxRetries}", + message.MessageId, + message.IntegrationType, + message.OrganizationId, + result.Category, + result.FailureReason, + message.RetryCount, + _maxRetries); } } else { // Fatal error (i.e. not retryable) occurred. Send message to dead letter queue without any retries await _rabbitMqService.PublishToDeadLetterAsync(channel, message, cancellationToken); - _logger.LogWarning("Non-retryable failure. Sent to DLQ."); + _logger.LogWarning( + "Integration failure - non-retryable. " + + "MessageId: {MessageId}, IntegrationType: {IntegrationType}, OrganizationId: {OrgId}, " + + "FailureCategory: {Category}, Reason: {Reason}", + message.MessageId, + message.IntegrationType, + message.OrganizationId, + result.Category, + result.FailureReason); } // Message has been sent to retry or dead letter queues. diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/RabbitMqService.cs b/src/Core/Dirt/Services/Implementations/RabbitMqService.cs similarity index 97% rename from src/Core/AdminConsole/Services/Implementations/EventIntegrations/RabbitMqService.cs rename to src/Core/Dirt/Services/Implementations/RabbitMqService.cs index 3e20e34200..c27fb37d08 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/RabbitMqService.cs +++ b/src/Core/Dirt/Services/Implementations/RabbitMqService.cs @@ -1,11 +1,11 @@ using System.Text; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; -using Bit.Core.Enums; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.Models.Data.EventIntegrations; using Bit.Core.Settings; using RabbitMQ.Client; using RabbitMQ.Client.Events; -namespace Bit.Core.Services; +namespace Bit.Core.Dirt.Services.Implementations; public class RabbitMqService : IRabbitMqService { @@ -122,7 +122,7 @@ public class RabbitMqService : IRabbitMqService body: body); } - public async Task PublishEventAsync(string body) + public async Task PublishEventAsync(string body, string? organizationId) { await using var channel = await CreateChannelAsync(); var properties = new BasicProperties diff --git a/src/Core/Services/Implementations/RepositoryEventWriteService.cs b/src/Core/Dirt/Services/Implementations/RepositoryEventWriteService.cs similarity index 100% rename from src/Core/Services/Implementations/RepositoryEventWriteService.cs rename to src/Core/Dirt/Services/Implementations/RepositoryEventWriteService.cs diff --git a/src/Core/Dirt/Services/Implementations/SlackIntegrationHandler.cs b/src/Core/Dirt/Services/Implementations/SlackIntegrationHandler.cs new file mode 100644 index 0000000000..6c6a4dd356 --- /dev/null +++ b/src/Core/Dirt/Services/Implementations/SlackIntegrationHandler.cs @@ -0,0 +1,76 @@ +using Bit.Core.Dirt.Models.Data.EventIntegrations; + +namespace Bit.Core.Dirt.Services.Implementations; + +public class SlackIntegrationHandler( + ISlackService slackService) + : IntegrationHandlerBase +{ + public override async Task HandleAsync(IntegrationMessage message) + { + var slackResponse = await slackService.SendSlackMessageByChannelIdAsync( + message.Configuration.Token, + message.RenderedTemplate, + message.Configuration.ChannelId + ); + + if (slackResponse is null) + { + return IntegrationHandlerResult.Fail( + message, + IntegrationFailureCategory.TransientError, + "Slack response was null" + ); + } + + if (slackResponse.Ok) + { + return IntegrationHandlerResult.Succeed(message); + } + + var category = ClassifySlackError(slackResponse.Error); + return IntegrationHandlerResult.Fail( + message, + category, + slackResponse.Error + ); + } + + /// + /// Classifies a Slack API error code string as an to drive + /// retry behavior and operator-facing failure reporting. + /// + /// + /// + /// Slack responses commonly return an error string when ok is false. This method maps + /// known Slack error codes to failure categories. + /// + /// + /// Any unrecognized error codes default to to avoid + /// incorrectly marking new/unknown Slack failures as non-retryable. + /// + /// + /// The Slack error code string (e.g. invalid_auth, rate_limited). + /// The corresponding . + private static IntegrationFailureCategory ClassifySlackError(string error) + { + return error switch + { + "invalid_auth" => IntegrationFailureCategory.AuthenticationFailed, + "access_denied" => IntegrationFailureCategory.AuthenticationFailed, + "token_expired" => IntegrationFailureCategory.AuthenticationFailed, + "token_revoked" => IntegrationFailureCategory.AuthenticationFailed, + "account_inactive" => IntegrationFailureCategory.AuthenticationFailed, + "not_authed" => IntegrationFailureCategory.AuthenticationFailed, + "channel_not_found" => IntegrationFailureCategory.ConfigurationError, + "is_archived" => IntegrationFailureCategory.ConfigurationError, + "rate_limited" => IntegrationFailureCategory.RateLimited, + "ratelimited" => IntegrationFailureCategory.RateLimited, + "message_limit_exceeded" => IntegrationFailureCategory.RateLimited, + "internal_error" => IntegrationFailureCategory.TransientError, + "service_unavailable" => IntegrationFailureCategory.ServiceUnavailable, + "fatal_error" => IntegrationFailureCategory.ServiceUnavailable, + _ => IntegrationFailureCategory.TransientError + }; + } +} diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/SlackService.cs b/src/Core/Dirt/Services/Implementations/SlackService.cs similarity index 78% rename from src/Core/AdminConsole/Services/Implementations/EventIntegrations/SlackService.cs rename to src/Core/Dirt/Services/Implementations/SlackService.cs index 4fb74f1f44..7683f718b5 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/SlackService.cs +++ b/src/Core/Dirt/Services/Implementations/SlackService.cs @@ -1,11 +1,12 @@ using System.Net.Http.Headers; using System.Net.Http.Json; +using System.Text.Json; using System.Web; -using Bit.Core.Models.Slack; +using Bit.Core.Dirt.Models.Data.Slack; using Bit.Core.Settings; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; +namespace Bit.Core.Dirt.Services.Implementations; public class SlackService( IHttpClientFactory httpClientFactory, @@ -71,7 +72,7 @@ public class SlackService( public async Task GetDmChannelByEmailAsync(string token, string email) { var userId = await GetUserIdByEmailAsync(token, email); - return await OpenDmChannel(token, userId); + return await OpenDmChannelAsync(token, userId); } public string GetRedirectUrl(string callbackUrl, string state) @@ -90,22 +91,28 @@ public class SlackService( public async Task ObtainTokenViaOAuth(string code, string redirectUrl) { + if (string.IsNullOrEmpty(code) || string.IsNullOrWhiteSpace(redirectUrl)) + { + logger.LogError("Error obtaining token via OAuth: Code and/or RedirectUrl were empty"); + return string.Empty; + } + var tokenResponse = await _httpClient.PostAsync($"{_slackApiBaseUrl}/oauth.v2.access", - new FormUrlEncodedContent(new[] - { + new FormUrlEncodedContent([ new KeyValuePair("client_id", _clientId), new KeyValuePair("client_secret", _clientSecret), new KeyValuePair("code", code), new KeyValuePair("redirect_uri", redirectUrl) - })); + ])); SlackOAuthResponse? result; try { result = await tokenResponse.Content.ReadFromJsonAsync(); } - catch + catch (JsonException ex) { + logger.LogError(ex, "Error parsing SlackOAuthResponse: invalid JSON"); result = null; } @@ -123,14 +130,25 @@ public class SlackService( return result.AccessToken; } - public async Task SendSlackMessageByChannelIdAsync(string token, string message, string channelId) + public async Task SendSlackMessageByChannelIdAsync(string token, string message, + string channelId) { var payload = JsonContent.Create(new { channel = channelId, text = message }); var request = new HttpRequestMessage(HttpMethod.Post, $"{_slackApiBaseUrl}/chat.postMessage"); request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", token); request.Content = payload; - await _httpClient.SendAsync(request); + var response = await _httpClient.SendAsync(request); + + try + { + return await response.Content.ReadFromJsonAsync(); + } + catch (JsonException ex) + { + logger.LogError(ex, "Error parsing Slack message response: invalid JSON"); + return null; + } } private async Task GetUserIdByEmailAsync(string token, string email) @@ -138,7 +156,16 @@ public class SlackService( var request = new HttpRequestMessage(HttpMethod.Get, $"{_slackApiBaseUrl}/users.lookupByEmail?email={email}"); request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", token); var response = await _httpClient.SendAsync(request); - var result = await response.Content.ReadFromJsonAsync(); + SlackUserResponse? result; + try + { + result = await response.Content.ReadFromJsonAsync(); + } + catch (JsonException ex) + { + logger.LogError(ex, "Error parsing SlackUserResponse: invalid JSON"); + result = null; + } if (result is null) { @@ -154,7 +181,7 @@ public class SlackService( return result.User.Id; } - private async Task OpenDmChannel(string token, string userId) + private async Task OpenDmChannelAsync(string token, string userId) { if (string.IsNullOrEmpty(userId)) return string.Empty; @@ -164,7 +191,16 @@ public class SlackService( request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", token); request.Content = payload; var response = await _httpClient.SendAsync(request); - var result = await response.Content.ReadFromJsonAsync(); + SlackDmResponse? result; + try + { + result = await response.Content.ReadFromJsonAsync(); + } + catch (JsonException ex) + { + logger.LogError(ex, "Error parsing SlackDmResponse: invalid JSON"); + result = null; + } if (result is null) { diff --git a/src/Core/Dirt/Services/Implementations/TeamsIntegrationHandler.cs b/src/Core/Dirt/Services/Implementations/TeamsIntegrationHandler.cs new file mode 100644 index 0000000000..7aaed6c647 --- /dev/null +++ b/src/Core/Dirt/Services/Implementations/TeamsIntegrationHandler.cs @@ -0,0 +1,66 @@ +using System.Text.Json; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Microsoft.Rest; + +namespace Bit.Core.Dirt.Services.Implementations; + +public class TeamsIntegrationHandler( + ITeamsService teamsService) + : IntegrationHandlerBase +{ + public override async Task HandleAsync( + IntegrationMessage message) + { + try + { + await teamsService.SendMessageToChannelAsync( + serviceUri: message.Configuration.ServiceUrl, + message: message.RenderedTemplate, + channelId: message.Configuration.ChannelId + ); + + return IntegrationHandlerResult.Succeed(message); + } + catch (HttpOperationException ex) + { + var category = ClassifyHttpStatusCode(ex.Response.StatusCode); + return IntegrationHandlerResult.Fail( + message, + category, + ex.Message + ); + } + catch (ArgumentException ex) + { + return IntegrationHandlerResult.Fail( + message, + IntegrationFailureCategory.ConfigurationError, + ex.Message + ); + } + catch (UriFormatException ex) + { + return IntegrationHandlerResult.Fail( + message, + IntegrationFailureCategory.ConfigurationError, + ex.Message + ); + } + catch (JsonException ex) + { + return IntegrationHandlerResult.Fail( + message, + IntegrationFailureCategory.PermanentFailure, + ex.Message + ); + } + catch (Exception ex) + { + return IntegrationHandlerResult.Fail( + message, + IntegrationFailureCategory.TransientError, + ex.Message + ); + } + } +} diff --git a/src/Core/Dirt/Services/Implementations/TeamsService.cs b/src/Core/Dirt/Services/Implementations/TeamsService.cs new file mode 100644 index 0000000000..edb43bf85e --- /dev/null +++ b/src/Core/Dirt/Services/Implementations/TeamsService.cs @@ -0,0 +1,182 @@ +using System.Net.Http.Headers; +using System.Net.Http.Json; +using System.Text.Json; +using System.Web; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Models.Data.Teams; +using Bit.Core.Dirt.Repositories; +using Bit.Core.Settings; +using Microsoft.Bot.Builder; +using Microsoft.Bot.Builder.Teams; +using Microsoft.Bot.Connector; +using Microsoft.Bot.Connector.Authentication; +using Microsoft.Bot.Schema; +using Microsoft.Extensions.Logging; +using TeamInfo = Bit.Core.Dirt.Models.Data.Teams.TeamInfo; + +namespace Bit.Core.Dirt.Services.Implementations; + +public class TeamsService( + IHttpClientFactory httpClientFactory, + IOrganizationIntegrationRepository integrationRepository, + GlobalSettings globalSettings, + ILogger logger) : ActivityHandler, ITeamsService +{ + private readonly HttpClient _httpClient = httpClientFactory.CreateClient(HttpClientName); + private readonly string _clientId = globalSettings.Teams.ClientId; + private readonly string _clientSecret = globalSettings.Teams.ClientSecret; + private readonly string _scopes = globalSettings.Teams.Scopes; + private readonly string _graphBaseUrl = globalSettings.Teams.GraphBaseUrl; + private readonly string _loginBaseUrl = globalSettings.Teams.LoginBaseUrl; + + public const string HttpClientName = "TeamsServiceHttpClient"; + + public string GetRedirectUrl(string redirectUrl, string state) + { + var query = HttpUtility.ParseQueryString(string.Empty); + query["client_id"] = _clientId; + query["response_type"] = "code"; + query["redirect_uri"] = redirectUrl; + query["response_mode"] = "query"; + query["scope"] = string.Join(" ", _scopes); + query["state"] = state; + + return $"{_loginBaseUrl}/common/oauth2/v2.0/authorize?{query}"; + } + + public async Task ObtainTokenViaOAuth(string code, string redirectUrl) + { + if (string.IsNullOrEmpty(code) || string.IsNullOrWhiteSpace(redirectUrl)) + { + logger.LogError("Error obtaining token via OAuth: Code and/or RedirectUrl were empty"); + return string.Empty; + } + + var request = new HttpRequestMessage(HttpMethod.Post, + $"{_loginBaseUrl}/common/oauth2/v2.0/token"); + + request.Content = new FormUrlEncodedContent(new Dictionary + { + { "client_id", _clientId }, + { "client_secret", _clientSecret }, + { "code", code }, + { "redirect_uri", redirectUrl }, + { "grant_type", "authorization_code" } + }); + + using var response = await _httpClient.SendAsync(request); + if (!response.IsSuccessStatusCode) + { + var errorText = await response.Content.ReadAsStringAsync(); + logger.LogError("Teams OAuth token exchange failed: {errorText}", errorText); + return string.Empty; + } + + TeamsOAuthResponse? result; + try + { + result = await response.Content.ReadFromJsonAsync(); + } + catch + { + result = null; + } + + if (result is null) + { + logger.LogError("Error obtaining token via OAuth: Unknown error"); + return string.Empty; + } + + return result.AccessToken; + } + + public async Task> GetJoinedTeamsAsync(string accessToken) + { + using var request = new HttpRequestMessage( + HttpMethod.Get, + $"{_graphBaseUrl}/me/joinedTeams"); + request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", accessToken); + + using var response = await _httpClient.SendAsync(request); + if (!response.IsSuccessStatusCode) + { + var errorText = await response.Content.ReadAsStringAsync(); + logger.LogError("Get Teams request failed: {errorText}", errorText); + return new List(); + } + + var result = await response.Content.ReadFromJsonAsync(); + + return result?.Value ?? []; + } + + public async Task SendMessageToChannelAsync(Uri serviceUri, string channelId, string message) + { + var credentials = new MicrosoftAppCredentials(_clientId, _clientSecret); + using var connectorClient = new ConnectorClient(serviceUri, credentials); + + var activity = new Activity + { + Type = ActivityTypes.Message, + Text = message + }; + + await connectorClient.Conversations.SendToConversationAsync(channelId, activity); + } + + protected override async Task OnInstallationUpdateAddAsync(ITurnContext turnContext, + CancellationToken cancellationToken) + { + var conversationId = turnContext.Activity.Conversation.Id; + var serviceUrl = turnContext.Activity.ServiceUrl; + var teamId = turnContext.Activity.TeamsGetTeamInfo().AadGroupId; + var tenantId = turnContext.Activity.Conversation.TenantId; + + if (!string.IsNullOrWhiteSpace(conversationId) && + !string.IsNullOrWhiteSpace(serviceUrl) && + Uri.TryCreate(serviceUrl, UriKind.Absolute, out var parsedUri) && + !string.IsNullOrWhiteSpace(teamId) && + !string.IsNullOrWhiteSpace(tenantId)) + { + await HandleIncomingAppInstallAsync( + conversationId: conversationId, + serviceUrl: parsedUri, + teamId: teamId, + tenantId: tenantId + ); + } + + await base.OnInstallationUpdateAddAsync(turnContext, cancellationToken); + } + + internal async Task HandleIncomingAppInstallAsync( + string conversationId, + Uri serviceUrl, + string teamId, + string tenantId) + { + var integration = await integrationRepository.GetByTeamsConfigurationTenantIdTeamId( + tenantId: tenantId, + teamId: teamId); + + if (integration?.Configuration is null) + { + return; + } + + var teamsConfig = JsonSerializer.Deserialize(integration.Configuration); + if (teamsConfig is null || teamsConfig.IsCompleted) + { + return; + } + + integration.Configuration = JsonSerializer.Serialize(teamsConfig with + { + ChannelId = conversationId, + ServiceUrl = serviceUrl + }); + + await integrationRepository.UpsertAsync(integration); + } +} diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/WebhookIntegrationHandler.cs b/src/Core/Dirt/Services/Implementations/WebhookIntegrationHandler.cs similarity index 92% rename from src/Core/AdminConsole/Services/Implementations/EventIntegrations/WebhookIntegrationHandler.cs rename to src/Core/Dirt/Services/Implementations/WebhookIntegrationHandler.cs index 0599f6e9d4..6caa1b9a6e 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/WebhookIntegrationHandler.cs +++ b/src/Core/Dirt/Services/Implementations/WebhookIntegrationHandler.cs @@ -1,8 +1,8 @@ using System.Net.Http.Headers; using System.Text; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Models.Data.EventIntegrations; -namespace Bit.Core.Services; +namespace Bit.Core.Dirt.Services.Implementations; public class WebhookIntegrationHandler( IHttpClientFactory httpClientFactory, diff --git a/src/Core/AdminConsole/Services/NoopImplementations/NoopEventService.cs b/src/Core/Dirt/Services/NoopImplementations/NoopEventService.cs similarity index 100% rename from src/Core/AdminConsole/Services/NoopImplementations/NoopEventService.cs rename to src/Core/Dirt/Services/NoopImplementations/NoopEventService.cs diff --git a/src/Core/AdminConsole/Services/NoopImplementations/NoopEventWriteService.cs b/src/Core/Dirt/Services/NoopImplementations/NoopEventWriteService.cs similarity index 100% rename from src/Core/AdminConsole/Services/NoopImplementations/NoopEventWriteService.cs rename to src/Core/Dirt/Services/NoopImplementations/NoopEventWriteService.cs diff --git a/src/Core/AdminConsole/Services/NoopImplementations/NoopSlackService.cs b/src/Core/Dirt/Services/NoopImplementations/NoopSlackService.cs similarity index 71% rename from src/Core/AdminConsole/Services/NoopImplementations/NoopSlackService.cs rename to src/Core/Dirt/Services/NoopImplementations/NoopSlackService.cs index d6c8d08c4c..30b68186bc 100644 --- a/src/Core/AdminConsole/Services/NoopImplementations/NoopSlackService.cs +++ b/src/Core/Dirt/Services/NoopImplementations/NoopSlackService.cs @@ -1,6 +1,6 @@ -using Bit.Core.Services; +using Bit.Core.Dirt.Models.Data.Slack; -namespace Bit.Core.AdminConsole.Services.NoopImplementations; +namespace Bit.Core.Dirt.Services.NoopImplementations; public class NoopSlackService : ISlackService { @@ -24,9 +24,10 @@ public class NoopSlackService : ISlackService return string.Empty; } - public Task SendSlackMessageByChannelIdAsync(string token, string message, string channelId) + public Task SendSlackMessageByChannelIdAsync(string token, string message, + string channelId) { - return Task.FromResult(0); + return Task.FromResult(null); } public Task ObtainTokenViaOAuth(string code, string redirectUrl) diff --git a/src/Core/Dirt/Services/NoopImplementations/NoopTeamsService.cs b/src/Core/Dirt/Services/NoopImplementations/NoopTeamsService.cs new file mode 100644 index 0000000000..3ebd58d996 --- /dev/null +++ b/src/Core/Dirt/Services/NoopImplementations/NoopTeamsService.cs @@ -0,0 +1,26 @@ +using Bit.Core.Dirt.Models.Data.Teams; + +namespace Bit.Core.Dirt.Services.NoopImplementations; + +public class NoopTeamsService : ITeamsService +{ + public string GetRedirectUrl(string callbackUrl, string state) + { + return string.Empty; + } + + public Task ObtainTokenViaOAuth(string code, string redirectUrl) + { + return Task.FromResult(string.Empty); + } + + public Task> GetJoinedTeamsAsync(string accessToken) + { + return Task.FromResult>(Array.Empty()); + } + + public Task SendMessageToChannelAsync(Uri serviceUri, string channelId, string message) + { + return Task.CompletedTask; + } +} diff --git a/src/Core/Entities/User.cs b/src/Core/Entities/User.cs index 12c527ed78..669e32bcbe 100644 --- a/src/Core/Entities/User.cs +++ b/src/Core/Entities/User.cs @@ -3,6 +3,7 @@ using System.Text.Json; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models; using Bit.Core.Enums; +using Bit.Core.KeyManagement.Models.Data; using Bit.Core.Utilities; using Microsoft.AspNetCore.Identity; @@ -21,6 +22,9 @@ public class User : ITableObject, IStorableSubscriber, IRevisable, ITwoFac [MaxLength(256)] public string Email { get; set; } = null!; public bool EmailVerified { get; set; } + /// + /// The server-side master-password hash + /// [MaxLength(300)] public string? MasterPassword { get; set; } [MaxLength(50)] @@ -41,9 +45,35 @@ public class User : ITableObject, IStorableSubscriber, IRevisable, ITwoFac /// organization membership. ///
    public DateTime AccountRevisionDate { get; set; } = DateTime.UtcNow; + /// + /// The master-password-sealed user key. + /// public string? Key { get; set; } + /// + /// The raw public key, without a signature from the user's signature key. + /// public string? PublicKey { get; set; } + /// + /// User key wrapped private key. + /// public string? PrivateKey { get; set; } + /// + /// The public key, signed by the user's signature key. + /// + public string? SignedPublicKey { get; set; } + /// + /// The security version is included in the security state, but needs COSE parsing + /// + public int? SecurityVersion { get; set; } + /// + /// The security state is a signed object attesting to the version of the user's account. + /// + public string? SecurityState { get; set; } + /// + /// Indicates whether the user has a personal premium subscription. + /// Does not include premium access from organizations - + /// do not use this to check whether the user can access premium features. + /// public bool Premium { get; set; } public DateTime? PremiumExpirationDate { get; set; } public DateTime? RenewalReminderDate { get; set; } @@ -175,9 +205,10 @@ public class User : ITableObject, IStorableSubscriber, IRevisable, ITwoFac return Id; } - public bool GetPremium() + public int GetSecurityVersion() { - return Premium; + // If no security version is set, it is version 1. The minimum initialized version is 2. + return SecurityVersion ?? 1; } /// @@ -243,4 +274,14 @@ public class User : ITableObject, IStorableSubscriber, IRevisable, ITwoFac { return MasterPassword != null; } + + public PublicKeyEncryptionKeyPairData GetPublicKeyEncryptionKeyPair() + { + if (string.IsNullOrWhiteSpace(PrivateKey) || string.IsNullOrWhiteSpace(PublicKey)) + { + throw new InvalidOperationException("User public key encryption key pair is not fully initialized."); + } + + return new PublicKeyEncryptionKeyPairData(PrivateKey, PublicKey, SignedPublicKey); + } } diff --git a/src/Core/Enums/PushNotificationLogOutReason.cs b/src/Core/Enums/PushNotificationLogOutReason.cs new file mode 100644 index 0000000000..a24f790305 --- /dev/null +++ b/src/Core/Enums/PushNotificationLogOutReason.cs @@ -0,0 +1,6 @@ +namespace Bit.Core.Enums; + +public enum PushNotificationLogOutReason : byte +{ + KdfChange = 0 +} diff --git a/src/Core/Jobs/BaseJobsHostedService.cs b/src/Core/Jobs/BaseJobsHostedService.cs index 3e7bce7e0f..8b74052f8f 100644 --- a/src/Core/Jobs/BaseJobsHostedService.cs +++ b/src/Core/Jobs/BaseJobsHostedService.cs @@ -107,7 +107,7 @@ public abstract class BaseJobsHostedService : IHostedService, IDisposable throw new Exception("Job failed to start after 10 retries."); } - _logger.LogWarning($"Exception while trying to schedule job: {job.FullName}, {e}"); + _logger.LogWarning(e, "Exception while trying to schedule job: {JobName}", job.FullName); var random = new Random(); await Task.Delay(random.Next(50, 250)); } @@ -125,7 +125,7 @@ public abstract class BaseJobsHostedService : IHostedService, IDisposable continue; } - _logger.LogInformation($"Deleting old job with key {key}"); + _logger.LogInformation("Deleting old job with key {Key}", key); await _scheduler.DeleteJob(key); } @@ -138,7 +138,7 @@ public abstract class BaseJobsHostedService : IHostedService, IDisposable continue; } - _logger.LogInformation($"Unscheduling old trigger with key {key}"); + _logger.LogInformation("Unscheduling old trigger with key {Key}", key); await _scheduler.UnscheduleJob(key); } } diff --git a/src/Core/KeyManagement/Authorization/KeyConnectorAuthorizationHandler.cs b/src/Core/KeyManagement/Authorization/KeyConnectorAuthorizationHandler.cs new file mode 100644 index 0000000000..7937390a8c --- /dev/null +++ b/src/Core/KeyManagement/Authorization/KeyConnectorAuthorizationHandler.cs @@ -0,0 +1,52 @@ +using Bit.Core.Context; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Microsoft.AspNetCore.Authorization; + +namespace Bit.Core.KeyManagement.Authorization; + +public class KeyConnectorAuthorizationHandler : AuthorizationHandler +{ + private readonly ICurrentContext _currentContext; + + public KeyConnectorAuthorizationHandler(ICurrentContext currentContext) + { + _currentContext = currentContext; + } + + protected override Task HandleRequirementAsync(AuthorizationHandlerContext context, + KeyConnectorOperationsRequirement requirement, + User user) + { + var authorized = requirement switch + { + not null when requirement == KeyConnectorOperations.Use => CanUse(user), + _ => throw new ArgumentException("Unsupported operation requirement type provided.", nameof(requirement)) + }; + + if (authorized) + { + context.Succeed(requirement); + } + + return Task.CompletedTask; + } + + private bool CanUse(User user) + { + // User cannot use Key Connector if they already use it + if (user.UsesKeyConnector) + { + return false; + } + + // User cannot use Key Connector if they are an owner or admin of any organization + if (_currentContext.Organizations.Any(u => + u.Type is OrganizationUserType.Owner or OrganizationUserType.Admin)) + { + return false; + } + + return true; + } +} diff --git a/src/Core/KeyManagement/Authorization/KeyConnectorOperations.cs b/src/Core/KeyManagement/Authorization/KeyConnectorOperations.cs new file mode 100644 index 0000000000..a8d09a6ac7 --- /dev/null +++ b/src/Core/KeyManagement/Authorization/KeyConnectorOperations.cs @@ -0,0 +1,16 @@ +using Microsoft.AspNetCore.Authorization.Infrastructure; + +namespace Bit.Core.KeyManagement.Authorization; + +public class KeyConnectorOperationsRequirement : OperationAuthorizationRequirement +{ + public KeyConnectorOperationsRequirement(string name) + { + Name = name; + } +} + +public static class KeyConnectorOperations +{ + public static readonly KeyConnectorOperationsRequirement Use = new(nameof(Use)); +} diff --git a/src/Core/KeyManagement/Commands/Interfaces/ISetKeyConnectorKeyCommand.cs b/src/Core/KeyManagement/Commands/Interfaces/ISetKeyConnectorKeyCommand.cs new file mode 100644 index 0000000000..65f6cddeb5 --- /dev/null +++ b/src/Core/KeyManagement/Commands/Interfaces/ISetKeyConnectorKeyCommand.cs @@ -0,0 +1,13 @@ +using Bit.Core.Entities; +using Bit.Core.KeyManagement.Models.Data; + +namespace Bit.Core.KeyManagement.Commands.Interfaces; + +/// +/// Creates the user key and account cryptographic state for a new user registering +/// with Key Connector SSO configuration. +/// +public interface ISetKeyConnectorKeyCommand +{ + Task SetKeyConnectorKeyForUserAsync(User user, KeyConnectorKeysData keyConnectorKeysData); +} diff --git a/src/Core/KeyManagement/Commands/SetKeyConnectorKeyCommand.cs b/src/Core/KeyManagement/Commands/SetKeyConnectorKeyCommand.cs new file mode 100644 index 0000000000..a96042de30 --- /dev/null +++ b/src/Core/KeyManagement/Commands/SetKeyConnectorKeyCommand.cs @@ -0,0 +1,60 @@ +using Bit.Core.Context; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.KeyManagement.Authorization; +using Bit.Core.KeyManagement.Commands.Interfaces; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Microsoft.AspNetCore.Authorization; + +namespace Bit.Core.KeyManagement.Commands; + +public class SetKeyConnectorKeyCommand : ISetKeyConnectorKeyCommand +{ + private readonly IAuthorizationService _authorizationService; + private readonly ICurrentContext _currentContext; + private readonly IEventService _eventService; + private readonly IAcceptOrgUserCommand _acceptOrgUserCommand; + private readonly IUserService _userService; + private readonly IUserRepository _userRepository; + + public SetKeyConnectorKeyCommand( + IAuthorizationService authorizationService, + ICurrentContext currentContext, + IEventService eventService, + IAcceptOrgUserCommand acceptOrgUserCommand, + IUserService userService, + IUserRepository userRepository) + { + _authorizationService = authorizationService; + _currentContext = currentContext; + _eventService = eventService; + _acceptOrgUserCommand = acceptOrgUserCommand; + _userService = userService; + _userRepository = userRepository; + } + + public async Task SetKeyConnectorKeyForUserAsync(User user, KeyConnectorKeysData keyConnectorKeysData) + { + var authorizationResult = await _authorizationService.AuthorizeAsync(_currentContext.HttpContext.User, user, + KeyConnectorOperations.Use); + if (!authorizationResult.Succeeded) + { + throw new BadRequestException("Cannot use Key Connector"); + } + + var setKeyConnectorUserKeyTask = + _userRepository.SetKeyConnectorUserKey(user.Id, keyConnectorKeysData.KeyConnectorKeyWrappedUserKey); + + await _userRepository.SetV2AccountCryptographicStateAsync(user.Id, + keyConnectorKeysData.AccountKeys.ToAccountKeysData(), [setKeyConnectorUserKeyTask]); + + await _eventService.LogUserEventAsync(user.Id, EventType.User_MigratedKeyToKeyConnector); + + await _acceptOrgUserCommand.AcceptOrgUserByOrgSsoIdAsync(keyConnectorKeysData.OrgIdentifier, user, + _userService); + } +} diff --git a/src/Core/KeyManagement/Entities/UserSignatureKeyPair.cs b/src/Core/KeyManagement/Entities/UserSignatureKeyPair.cs new file mode 100644 index 0000000000..dada9e0d7a --- /dev/null +++ b/src/Core/KeyManagement/Entities/UserSignatureKeyPair.cs @@ -0,0 +1,30 @@ +using Bit.Core.Entities; +using Bit.Core.KeyManagement.Enums; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.Utilities; + + +namespace Bit.Core.KeyManagement.Entities; + +public class UserSignatureKeyPair : ITableObject, IRevisable +{ + public Guid Id { get; set; } + public Guid UserId { get; set; } + public SignatureAlgorithm SignatureAlgorithm { get; set; } + + public required string VerifyingKey { get; set; } + public required string SigningKey { get; set; } + + public DateTime CreationDate { get; set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; set; } = DateTime.UtcNow; + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); + } + + public SignatureKeyPairData ToSignatureKeyPairData() + { + return new SignatureKeyPairData(SignatureAlgorithm, SigningKey, VerifyingKey); + } +} diff --git a/src/Core/KeyManagement/Enums/SignatureAlgorithm.cs b/src/Core/KeyManagement/Enums/SignatureAlgorithm.cs new file mode 100644 index 0000000000..9216c3f489 --- /dev/null +++ b/src/Core/KeyManagement/Enums/SignatureAlgorithm.cs @@ -0,0 +1,9 @@ +namespace Bit.Core.KeyManagement.Enums; + +// +// Represents the algorithm / digital signature scheme used for a signature key pair. +// +public enum SignatureAlgorithm : byte +{ + Ed25519 = 0 +} diff --git a/src/Core/KeyManagement/Kdf/Implementations/ChangeKdfCommand.cs b/src/Core/KeyManagement/Kdf/Implementations/ChangeKdfCommand.cs index fe736f9ac6..83e47c4931 100644 --- a/src/Core/KeyManagement/Kdf/Implementations/ChangeKdfCommand.cs +++ b/src/Core/KeyManagement/Kdf/Implementations/ChangeKdfCommand.cs @@ -1,4 +1,5 @@ using Bit.Core.Entities; +using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.KeyManagement.Models.Data; using Bit.Core.Platform.Push; @@ -18,17 +19,22 @@ public class ChangeKdfCommand : IChangeKdfCommand private readonly IUserRepository _userRepository; private readonly IdentityErrorDescriber _identityErrorDescriber; private readonly ILogger _logger; + private readonly IFeatureService _featureService; - public ChangeKdfCommand(IUserService userService, IPushNotificationService pushService, IUserRepository userRepository, IdentityErrorDescriber describer, ILogger logger) + public ChangeKdfCommand(IUserService userService, IPushNotificationService pushService, + IUserRepository userRepository, IdentityErrorDescriber describer, ILogger logger, + IFeatureService featureService) { _userService = userService; _pushService = pushService; _userRepository = userRepository; _identityErrorDescriber = describer; _logger = logger; + _featureService = featureService; } - public async Task ChangeKdfAsync(User user, string masterPasswordAuthenticationHash, MasterPasswordAuthenticationData authenticationData, MasterPasswordUnlockData unlockData) + public async Task ChangeKdfAsync(User user, string masterPasswordAuthenticationHash, + MasterPasswordAuthenticationData authenticationData, MasterPasswordUnlockData unlockData) { ArgumentNullException.ThrowIfNull(user); if (!await _userService.CheckPasswordAsync(user, masterPasswordAuthenticationHash)) @@ -37,8 +43,8 @@ public class ChangeKdfCommand : IChangeKdfCommand } // Validate to prevent user account from becoming un-decryptable from invalid parameters - // - // Prevent a de-synced salt value from creating an un-decryptable unlock method + // + // Prevent a de-synced salt value from creating an un-decryptable unlock method authenticationData.ValidateSaltUnchangedForUser(user); unlockData.ValidateSaltUnchangedForUser(user); @@ -47,12 +53,15 @@ public class ChangeKdfCommand : IChangeKdfCommand { throw new BadRequestException("KDF settings must be equal for authentication and unlock."); } + var validationErrors = KdfSettingsValidator.Validate(unlockData.Kdf); if (validationErrors.Any()) { throw new BadRequestException("KDF settings are invalid."); } + var logoutOnKdfChange = !_featureService.IsEnabled(FeatureFlagKeys.NoLogoutOnKdfChange); + // Update the user with the new KDF settings // This updates the authentication data and unlock data for the user separately. Currently these still // use shared values for KDF settings and salt. @@ -68,7 +77,8 @@ public class ChangeKdfCommand : IChangeKdfCommand // This entire operation MUST be atomic to prevent a user from being locked out of their account. // Salt is ensured to be the same as unlock data, and the value stored in the account and not updated. // KDF is ensured to be the same as unlock data above and updated below. - var result = await _userService.UpdatePasswordHash(user, authenticationData.MasterPasswordAuthenticationHash); + var result = await _userService.UpdatePasswordHash(user, authenticationData.MasterPasswordAuthenticationHash, + refreshStamp: logoutOnKdfChange); if (!result.Succeeded) { _logger.LogWarning("Change KDF failed for user {userId}.", user.Id); @@ -88,7 +98,17 @@ public class ChangeKdfCommand : IChangeKdfCommand user.LastKdfChangeDate = now; await _userRepository.ReplaceAsync(user); - await _pushService.PushLogOutAsync(user.Id); + if (logoutOnKdfChange) + { + await _pushService.PushLogOutAsync(user.Id); + } + else + { + // Clients that support the new feature flag will ignore the logout when it matches the reason and the feature flag is enabled. + await _pushService.PushLogOutAsync(user.Id, reason: PushNotificationLogOutReason.KdfChange); + await _pushService.PushSyncSettingsAsync(user.Id); + } + return IdentityResult.Success; } } diff --git a/src/Core/KeyManagement/KeyManagementServiceCollectionExtensions.cs b/src/Core/KeyManagement/KeyManagementServiceCollectionExtensions.cs index e4ebdb4860..96f990c299 100644 --- a/src/Core/KeyManagement/KeyManagementServiceCollectionExtensions.cs +++ b/src/Core/KeyManagement/KeyManagementServiceCollectionExtensions.cs @@ -1,7 +1,11 @@ -using Bit.Core.KeyManagement.Commands; +using Bit.Core.KeyManagement.Authorization; +using Bit.Core.KeyManagement.Commands; using Bit.Core.KeyManagement.Commands.Interfaces; using Bit.Core.KeyManagement.Kdf; using Bit.Core.KeyManagement.Kdf.Implementations; +using Bit.Core.KeyManagement.Queries; +using Bit.Core.KeyManagement.Queries.Interfaces; +using Microsoft.AspNetCore.Authorization; using Microsoft.Extensions.DependencyInjection; namespace Bit.Core.KeyManagement; @@ -10,13 +14,27 @@ public static class KeyManagementServiceCollectionExtensions { public static void AddKeyManagementServices(this IServiceCollection services) { + services.AddKeyManagementAuthorizationHandlers(); services.AddKeyManagementCommands(); + services.AddKeyManagementQueries(); services.AddSendPasswordServices(); } + private static void AddKeyManagementAuthorizationHandlers(this IServiceCollection services) + { + services.AddScoped(); + } + private static void AddKeyManagementCommands(this IServiceCollection services) { services.AddScoped(); services.AddScoped(); + services.AddScoped(); + } + + private static void AddKeyManagementQueries(this IServiceCollection services) + { + services.AddScoped(); + services.AddScoped(); } } diff --git a/src/Core/KeyManagement/Models/Api/Request/AccountKeysRequestModel.cs b/src/Core/KeyManagement/Models/Api/Request/AccountKeysRequestModel.cs new file mode 100644 index 0000000000..bdf538e6d8 --- /dev/null +++ b/src/Core/KeyManagement/Models/Api/Request/AccountKeysRequestModel.cs @@ -0,0 +1,50 @@ +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.Utilities; + +namespace Bit.Core.KeyManagement.Models.Api.Request; + +public class AccountKeysRequestModel +{ + [EncryptedString] public required string UserKeyEncryptedAccountPrivateKey { get; set; } + public required string AccountPublicKey { get; set; } + + public PublicKeyEncryptionKeyPairRequestModel? PublicKeyEncryptionKeyPair { get; set; } + public SignatureKeyPairRequestModel? SignatureKeyPair { get; set; } + public SecurityStateModel? SecurityState { get; set; } + + public UserAccountKeysData ToAccountKeysData() + { + // This will be cleaned up, after a compatibility period, at which point PublicKeyEncryptionKeyPair and SignatureKeyPair will be required. + // TODO: https://bitwarden.atlassian.net/browse/PM-23751 + if (PublicKeyEncryptionKeyPair == null) + { + return new UserAccountKeysData + { + PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData + ( + UserKeyEncryptedAccountPrivateKey, + AccountPublicKey + ), + }; + } + else + { + if (SignatureKeyPair == null || SecurityState == null) + { + return new UserAccountKeysData + { + PublicKeyEncryptionKeyPairData = PublicKeyEncryptionKeyPair.ToPublicKeyEncryptionKeyPairData(), + }; + } + else + { + return new UserAccountKeysData + { + PublicKeyEncryptionKeyPairData = PublicKeyEncryptionKeyPair.ToPublicKeyEncryptionKeyPairData(), + SignatureKeyPairData = SignatureKeyPair.ToSignatureKeyPairData(), + SecurityStateData = SecurityState.ToSecurityState() + }; + } + } + } +} diff --git a/src/Core/KeyManagement/Models/Api/Request/PublicKeyEncryptionKeyPairRequestModel.cs b/src/Core/KeyManagement/Models/Api/Request/PublicKeyEncryptionKeyPairRequestModel.cs new file mode 100644 index 0000000000..f9b009f7e2 --- /dev/null +++ b/src/Core/KeyManagement/Models/Api/Request/PublicKeyEncryptionKeyPairRequestModel.cs @@ -0,0 +1,20 @@ +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.Utilities; + +namespace Bit.Core.KeyManagement.Models.Api.Request; + +public class PublicKeyEncryptionKeyPairRequestModel +{ + [EncryptedString] public required string WrappedPrivateKey { get; set; } + public required string PublicKey { get; set; } + public string? SignedPublicKey { get; set; } + + public PublicKeyEncryptionKeyPairData ToPublicKeyEncryptionKeyPairData() + { + return new PublicKeyEncryptionKeyPairData( + WrappedPrivateKey, + PublicKey, + SignedPublicKey + ); + } +} diff --git a/src/Core/KeyManagement/Models/Api/Request/SecurityStateModel.cs b/src/Core/KeyManagement/Models/Api/Request/SecurityStateModel.cs new file mode 100644 index 0000000000..1acb52146e --- /dev/null +++ b/src/Core/KeyManagement/Models/Api/Request/SecurityStateModel.cs @@ -0,0 +1,32 @@ +using System.ComponentModel.DataAnnotations; +using System.Text.Json.Serialization; +using Bit.Core.KeyManagement.Models.Data; + +namespace Bit.Core.KeyManagement.Models.Api.Request; + +public class SecurityStateModel +{ + [StringLength(1000)] + [JsonPropertyName("securityState")] + public required string SecurityState { get; set; } + [JsonPropertyName("securityVersion")] + public required int SecurityVersion { get; set; } + + public SecurityStateData ToSecurityState() + { + return new SecurityStateData + { + SecurityState = SecurityState, + SecurityVersion = SecurityVersion + }; + } + + public static SecurityStateModel FromSecurityStateData(SecurityStateData data) + { + return new SecurityStateModel + { + SecurityState = data.SecurityState, + SecurityVersion = data.SecurityVersion + }; + } +} diff --git a/src/Core/KeyManagement/Models/Api/Request/SignatureKeyPairRequestModel.cs b/src/Core/KeyManagement/Models/Api/Request/SignatureKeyPairRequestModel.cs new file mode 100644 index 0000000000..a569bc70ab --- /dev/null +++ b/src/Core/KeyManagement/Models/Api/Request/SignatureKeyPairRequestModel.cs @@ -0,0 +1,28 @@ +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.Utilities; + +namespace Bit.Core.KeyManagement.Models.Api.Request; + +public class SignatureKeyPairRequestModel +{ + public required string SignatureAlgorithm { get; set; } + [EncryptedString] public required string WrappedSigningKey { get; set; } + public required string VerifyingKey { get; set; } + + public SignatureKeyPairData ToSignatureKeyPairData() + { + if (SignatureAlgorithm != "ed25519") + { + throw new ArgumentException( + $"Unsupported signature algorithm: {SignatureAlgorithm}" + ); + } + var algorithm = Core.KeyManagement.Enums.SignatureAlgorithm.Ed25519; + + return new SignatureKeyPairData( + algorithm, + WrappedSigningKey, + VerifyingKey + ); + } +} diff --git a/src/Core/KeyManagement/Models/Response/MasterPasswordUnlockResponseModel.cs b/src/Core/KeyManagement/Models/Api/Response/MasterPasswordUnlockResponseModel.cs similarity index 75% rename from src/Core/KeyManagement/Models/Response/MasterPasswordUnlockResponseModel.cs rename to src/Core/KeyManagement/Models/Api/Response/MasterPasswordUnlockResponseModel.cs index f7d5dee852..eebed83485 100644 --- a/src/Core/KeyManagement/Models/Response/MasterPasswordUnlockResponseModel.cs +++ b/src/Core/KeyManagement/Models/Api/Response/MasterPasswordUnlockResponseModel.cs @@ -2,11 +2,15 @@ using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Core.KeyManagement.Models.Response; +namespace Bit.Core.KeyManagement.Models.Api.Response; public class MasterPasswordUnlockResponseModel { public required MasterPasswordUnlockKdfResponseModel Kdf { get; init; } + /// + /// The user's symmetric key encrypted with their master key. + /// Also known as "MasterKeyWrappedUserKey" + /// [EncryptedString] public required string MasterKeyEncryptedUserKey { get; init; } [StringLength(256)] public required string Salt { get; init; } } diff --git a/src/Core/KeyManagement/Models/Api/Response/PrivateKeysResponseModel.cs b/src/Core/KeyManagement/Models/Api/Response/PrivateKeysResponseModel.cs new file mode 100644 index 0000000000..bcee4c0ada --- /dev/null +++ b/src/Core/KeyManagement/Models/Api/Response/PrivateKeysResponseModel.cs @@ -0,0 +1,48 @@ +using System.Text.Json.Serialization; +using Bit.Core.KeyManagement.Models.Api.Request; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.Models.Api; + +namespace Bit.Core.KeyManagement.Models.Api.Response; + + +/// +/// This response model is used to return the asymmetric encryption keys, +/// and signature keys of an entity. This includes the private keys of the key pairs, +/// (private key, signing key), and the public keys of the key pairs (unsigned public key, +/// signed public key, verification key). +/// +public class PrivateKeysResponseModel : ResponseModel +{ + // Not all accounts have signature keys, but all accounts have public encryption keys. + [JsonPropertyName("signatureKeyPair")] + public SignatureKeyPairResponseModel? SignatureKeyPair { get; set; } + + [JsonPropertyName("publicKeyEncryptionKeyPair")] + public required PublicKeyEncryptionKeyPairResponseModel PublicKeyEncryptionKeyPair { get; set; } + + [JsonPropertyName("securityState")] + public SecurityStateModel? SecurityState { get; set; } + + [System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute] + public PrivateKeysResponseModel(UserAccountKeysData accountKeys) : base("privateKeys") + { + ArgumentNullException.ThrowIfNull(accountKeys); + PublicKeyEncryptionKeyPair = new PublicKeyEncryptionKeyPairResponseModel(accountKeys.PublicKeyEncryptionKeyPairData); + + if (accountKeys.SignatureKeyPairData != null && accountKeys.SecurityStateData != null) + { + SignatureKeyPair = new SignatureKeyPairResponseModel(accountKeys.SignatureKeyPairData); + SecurityState = SecurityStateModel.FromSecurityStateData(accountKeys.SecurityStateData!); + } + } + + [JsonConstructor] + public PrivateKeysResponseModel(SignatureKeyPairResponseModel? signatureKeyPair, PublicKeyEncryptionKeyPairResponseModel publicKeyEncryptionKeyPair, SecurityStateModel? securityState) + : base("privateKeys") + { + SignatureKeyPair = signatureKeyPair; + PublicKeyEncryptionKeyPair = publicKeyEncryptionKeyPair ?? throw new ArgumentNullException(nameof(publicKeyEncryptionKeyPair)); + SecurityState = securityState; + } +} diff --git a/src/Core/KeyManagement/Models/Api/Response/PublicKeyEncryptionKeyPairResponseModel.cs b/src/Core/KeyManagement/Models/Api/Response/PublicKeyEncryptionKeyPairResponseModel.cs new file mode 100644 index 0000000000..e5436b6131 --- /dev/null +++ b/src/Core/KeyManagement/Models/Api/Response/PublicKeyEncryptionKeyPairResponseModel.cs @@ -0,0 +1,34 @@ +using System.Text.Json.Serialization; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.Models.Api; + +namespace Bit.Core.KeyManagement.Models.Api.Response; + + +public class PublicKeyEncryptionKeyPairResponseModel : ResponseModel +{ + [JsonPropertyName("wrappedPrivateKey")] + public required string WrappedPrivateKey { get; set; } + [JsonPropertyName("publicKey")] + public required string PublicKey { get; set; } + [JsonPropertyName("signedPublicKey")] + public string? SignedPublicKey { get; set; } + + [System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute] + public PublicKeyEncryptionKeyPairResponseModel(PublicKeyEncryptionKeyPairData keyPair) + : base("publicKeyEncryptionKeyPair") + { + WrappedPrivateKey = keyPair.WrappedPrivateKey; + PublicKey = keyPair.PublicKey; + SignedPublicKey = keyPair.SignedPublicKey; + } + + [JsonConstructor] + public PublicKeyEncryptionKeyPairResponseModel(string wrappedPrivateKey, string publicKey, string? signedPublicKey) + : base("publicKeyEncryptionKeyPair") + { + WrappedPrivateKey = wrappedPrivateKey ?? throw new ArgumentNullException(nameof(wrappedPrivateKey)); + PublicKey = publicKey ?? throw new ArgumentNullException(nameof(publicKey)); + SignedPublicKey = signedPublicKey; + } +} diff --git a/src/Core/KeyManagement/Models/Api/Response/PublicKeysResponseModel.cs b/src/Core/KeyManagement/Models/Api/Response/PublicKeysResponseModel.cs new file mode 100644 index 0000000000..b341a87e3e --- /dev/null +++ b/src/Core/KeyManagement/Models/Api/Response/PublicKeysResponseModel.cs @@ -0,0 +1,30 @@ +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.Models.Api; + +namespace Bit.Core.KeyManagement.Models.Api.Response; + + +/// +/// This response model is used to return the public keys of a user, to any other registered user or entity on the server. +/// It can contain public keys (signature/encryption), and proofs between the two. It does not contain (encrypted) private keys. +/// +public class PublicKeysResponseModel : ResponseModel +{ + [System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute] + public PublicKeysResponseModel(UserAccountKeysData accountKeys) + : base("publicKeys") + { + ArgumentNullException.ThrowIfNull(accountKeys); + PublicKey = accountKeys.PublicKeyEncryptionKeyPairData.PublicKey; + + if (accountKeys.SignatureKeyPairData != null) + { + SignedPublicKey = accountKeys.PublicKeyEncryptionKeyPairData.SignedPublicKey; + VerifyingKey = accountKeys.SignatureKeyPairData.VerifyingKey; + } + } + + public string? VerifyingKey { get; set; } + public string? SignedPublicKey { get; set; } + public required string PublicKey { get; set; } +} diff --git a/src/Core/KeyManagement/Models/Api/Response/SignatureKeyPairResponseModel.cs b/src/Core/KeyManagement/Models/Api/Response/SignatureKeyPairResponseModel.cs new file mode 100644 index 0000000000..34d51f8bd4 --- /dev/null +++ b/src/Core/KeyManagement/Models/Api/Response/SignatureKeyPairResponseModel.cs @@ -0,0 +1,32 @@ +using System.Text.Json.Serialization; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.Models.Api; + +namespace Bit.Core.KeyManagement.Models.Api.Response; + + +public class SignatureKeyPairResponseModel : ResponseModel +{ + [JsonPropertyName("wrappedSigningKey")] + public required string WrappedSigningKey { get; set; } + [JsonPropertyName("verifyingKey")] + public required string VerifyingKey { get; set; } + + [System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute] + public SignatureKeyPairResponseModel(SignatureKeyPairData signatureKeyPair) + : base("signatureKeyPair") + { + ArgumentNullException.ThrowIfNull(signatureKeyPair); + WrappedSigningKey = signatureKeyPair.WrappedSigningKey; + VerifyingKey = signatureKeyPair.VerifyingKey; + } + + + [JsonConstructor] + public SignatureKeyPairResponseModel(string wrappedSigningKey, string verifyingKey) + : base("signatureKeyPair") + { + WrappedSigningKey = wrappedSigningKey ?? throw new ArgumentNullException(nameof(wrappedSigningKey)); + VerifyingKey = verifyingKey ?? throw new ArgumentNullException(nameof(verifyingKey)); + } +} diff --git a/src/Core/KeyManagement/Models/Response/UserDecryptionResponseModel.cs b/src/Core/KeyManagement/Models/Api/Response/UserDecryptionResponseModel.cs similarity index 82% rename from src/Core/KeyManagement/Models/Response/UserDecryptionResponseModel.cs rename to src/Core/KeyManagement/Models/Api/Response/UserDecryptionResponseModel.cs index a4d259a00a..536347cea9 100644 --- a/src/Core/KeyManagement/Models/Response/UserDecryptionResponseModel.cs +++ b/src/Core/KeyManagement/Models/Api/Response/UserDecryptionResponseModel.cs @@ -1,4 +1,4 @@ -namespace Bit.Core.KeyManagement.Models.Response; +namespace Bit.Core.KeyManagement.Models.Api.Response; public class UserDecryptionResponseModel { diff --git a/src/Core/KeyManagement/Models/Data/KeyConnectorConfirmationDetails.cs b/src/Core/KeyManagement/Models/Data/KeyConnectorConfirmationDetails.cs new file mode 100644 index 0000000000..3821831bad --- /dev/null +++ b/src/Core/KeyManagement/Models/Data/KeyConnectorConfirmationDetails.cs @@ -0,0 +1,6 @@ +namespace Bit.Core.KeyManagement.Models.Data; + +public class KeyConnectorConfirmationDetails +{ + public required string OrganizationName { get; set; } +} diff --git a/src/Core/KeyManagement/Models/Data/KeyConnectorKeysData.cs b/src/Core/KeyManagement/Models/Data/KeyConnectorKeysData.cs new file mode 100644 index 0000000000..5675c6bc96 --- /dev/null +++ b/src/Core/KeyManagement/Models/Data/KeyConnectorKeysData.cs @@ -0,0 +1,12 @@ +using Bit.Core.KeyManagement.Models.Api.Request; + +namespace Bit.Core.KeyManagement.Models.Data; + +public class KeyConnectorKeysData +{ + public required string KeyConnectorKeyWrappedUserKey { get; set; } + + public required AccountKeysRequestModel AccountKeys { get; set; } + + public required string OrgIdentifier { get; init; } +} diff --git a/src/Core/KeyManagement/Models/Data/MasterPasswordAuthenticationData.cs b/src/Core/KeyManagement/Models/Data/MasterPasswordAuthenticationData.cs index c0ae949a3f..1bc7006cef 100644 --- a/src/Core/KeyManagement/Models/Data/MasterPasswordAuthenticationData.cs +++ b/src/Core/KeyManagement/Models/Data/MasterPasswordAuthenticationData.cs @@ -1,4 +1,5 @@ using Bit.Core.Entities; +using Bit.Core.Exceptions; namespace Bit.Core.KeyManagement.Models.Data; @@ -12,7 +13,7 @@ public class MasterPasswordAuthenticationData { if (user.GetMasterPasswordSalt() != Salt) { - throw new ArgumentException("Invalid master password salt."); + throw new BadRequestException("Invalid master password salt."); } } } diff --git a/src/Core/KeyManagement/Models/Data/MasterPasswordUnlockAndAuthenticationData.cs b/src/Core/KeyManagement/Models/Data/MasterPasswordUnlockAndAuthenticationData.cs index e305d92fec..ad3a0b692b 100644 --- a/src/Core/KeyManagement/Models/Data/MasterPasswordUnlockAndAuthenticationData.cs +++ b/src/Core/KeyManagement/Models/Data/MasterPasswordUnlockAndAuthenticationData.cs @@ -13,6 +13,10 @@ public class MasterPasswordUnlockAndAuthenticationData public required string Email { get; set; } public required string MasterKeyAuthenticationHash { get; set; } + /// + /// The user's symmetric key encrypted with their master key. + /// Also known as "MasterKeyWrappedUserKey" + /// public required string MasterKeyEncryptedUserKey { get; set; } public string? MasterPasswordHint { get; set; } diff --git a/src/Core/KeyManagement/Models/Data/MasterPasswordUnlockData.cs b/src/Core/KeyManagement/Models/Data/MasterPasswordUnlockData.cs index d1ab6f645b..cb18ed2a78 100644 --- a/src/Core/KeyManagement/Models/Data/MasterPasswordUnlockData.cs +++ b/src/Core/KeyManagement/Models/Data/MasterPasswordUnlockData.cs @@ -1,6 +1,5 @@ -#nullable enable - -using Bit.Core.Entities; +using Bit.Core.Entities; +using Bit.Core.Exceptions; namespace Bit.Core.KeyManagement.Models.Data; @@ -14,7 +13,7 @@ public class MasterPasswordUnlockData { if (user.GetMasterPasswordSalt() != Salt) { - throw new ArgumentException("Invalid master password salt."); + throw new BadRequestException("Invalid master password salt."); } } } diff --git a/src/Core/KeyManagement/Models/Data/PublicKeyEncryptionKeyPairData.cs b/src/Core/KeyManagement/Models/Data/PublicKeyEncryptionKeyPairData.cs new file mode 100644 index 0000000000..fb8b09d390 --- /dev/null +++ b/src/Core/KeyManagement/Models/Data/PublicKeyEncryptionKeyPairData.cs @@ -0,0 +1,20 @@ +using System.Text.Json.Serialization; + +namespace Bit.Core.KeyManagement.Models.Data; + + +public class PublicKeyEncryptionKeyPairData +{ + public required string WrappedPrivateKey { get; set; } + public string? SignedPublicKey { get; set; } + public required string PublicKey { get; set; } + + [JsonConstructor] + [System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute] + public PublicKeyEncryptionKeyPairData(string wrappedPrivateKey, string publicKey, string? signedPublicKey = null) + { + WrappedPrivateKey = wrappedPrivateKey ?? throw new ArgumentNullException(nameof(wrappedPrivateKey)); + PublicKey = publicKey ?? throw new ArgumentNullException(nameof(publicKey)); + SignedPublicKey = signedPublicKey; + } +} diff --git a/src/Core/KeyManagement/Models/Data/RotateUserAccountKeysData.cs b/src/Core/KeyManagement/Models/Data/RotateUserAccountKeysData.cs index 557fb56ff3..19d14b273f 100644 --- a/src/Core/KeyManagement/Models/Data/RotateUserAccountKeysData.cs +++ b/src/Core/KeyManagement/Models/Data/RotateUserAccountKeysData.cs @@ -1,6 +1,4 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - + using Bit.Core.Auth.Entities; using Bit.Core.Auth.Models.Data; using Bit.Core.Entities; @@ -12,21 +10,19 @@ namespace Bit.Core.KeyManagement.Models.Data; public class RotateUserAccountKeysData { // Authentication for this requests - public string OldMasterKeyAuthenticationHash { get; set; } + public required string OldMasterKeyAuthenticationHash { get; set; } - // Other keys encrypted by the userkey - public string UserKeyEncryptedAccountPrivateKey { get; set; } - public string AccountPublicKey { get; set; } + public required UserAccountKeysData AccountKeys { get; set; } // All methods to get to the userkey - public MasterPasswordUnlockAndAuthenticationData MasterPasswordUnlockData { get; set; } - public IEnumerable EmergencyAccesses { get; set; } - public IReadOnlyList OrganizationUsers { get; set; } - public IEnumerable WebAuthnKeys { get; set; } - public IEnumerable DeviceKeys { get; set; } + public required MasterPasswordUnlockAndAuthenticationData MasterPasswordUnlockData { get; set; } + public required IEnumerable EmergencyAccesses { get; set; } + public required IReadOnlyList OrganizationUsers { get; set; } + public required IEnumerable WebAuthnKeys { get; set; } + public required IEnumerable DeviceKeys { get; set; } // User vault data encrypted by the userkey - public IEnumerable Ciphers { get; set; } - public IEnumerable Folders { get; set; } - public IReadOnlyList Sends { get; set; } + public required IEnumerable Ciphers { get; set; } + public required IEnumerable Folders { get; set; } + public required IReadOnlyList Sends { get; set; } } diff --git a/src/Core/KeyManagement/Models/Data/SecurityStateData.cs b/src/Core/KeyManagement/Models/Data/SecurityStateData.cs new file mode 100644 index 0000000000..c9a4610387 --- /dev/null +++ b/src/Core/KeyManagement/Models/Data/SecurityStateData.cs @@ -0,0 +1,10 @@ + +namespace Bit.Core.KeyManagement.Models.Data; + +public class SecurityStateData +{ + public required string SecurityState { get; set; } + // The security version is included in the security state, but needs COSE parsing, + // so this is a separate copy that can be used directly. + public required int SecurityVersion { get; set; } +} diff --git a/src/Core/KeyManagement/Models/Data/SignatureKeyPairData.cs b/src/Core/KeyManagement/Models/Data/SignatureKeyPairData.cs new file mode 100644 index 0000000000..32ae3eef8f --- /dev/null +++ b/src/Core/KeyManagement/Models/Data/SignatureKeyPairData.cs @@ -0,0 +1,21 @@ + +using System.Text.Json.Serialization; +using Bit.Core.KeyManagement.Enums; + +namespace Bit.Core.KeyManagement.Models.Data; + +public class SignatureKeyPairData +{ + public required SignatureAlgorithm SignatureAlgorithm { get; set; } + public required string WrappedSigningKey { get; set; } + public required string VerifyingKey { get; set; } + + [JsonConstructor] + [System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute] + public SignatureKeyPairData(SignatureAlgorithm signatureAlgorithm, string wrappedSigningKey, string verifyingKey) + { + SignatureAlgorithm = signatureAlgorithm; + WrappedSigningKey = wrappedSigningKey ?? throw new ArgumentNullException(nameof(wrappedSigningKey)); + VerifyingKey = verifyingKey ?? throw new ArgumentNullException(nameof(verifyingKey)); + } +} diff --git a/src/Core/KeyManagement/Models/Data/UserAccountKeysData.cs b/src/Core/KeyManagement/Models/Data/UserAccountKeysData.cs new file mode 100644 index 0000000000..3d552a10de --- /dev/null +++ b/src/Core/KeyManagement/Models/Data/UserAccountKeysData.cs @@ -0,0 +1,34 @@ +namespace Bit.Core.KeyManagement.Models.Data; + +/// +/// Represents an expanded account cryptographic state for a user. Expanded here means +/// that it does not only contain the (wrapped) private / signing key, but also the public +/// key / verifying key. The client side only needs a subset of this data to unlock +/// their vault and the public parts can be derived. +/// +public class UserAccountKeysData +{ + public required PublicKeyEncryptionKeyPairData PublicKeyEncryptionKeyPairData { get; set; } + public SignatureKeyPairData? SignatureKeyPairData { get; set; } + public SecurityStateData? SecurityStateData { get; set; } + + /// + /// Checks whether the account cryptographic state is for a V1 encryption user or a V2 encryption user. + /// Throws if the state is invalid + /// + public bool IsV2Encryption() + { + if (PublicKeyEncryptionKeyPairData.SignedPublicKey != null && SignatureKeyPairData != null && SecurityStateData != null) + { + return true; + } + else if (PublicKeyEncryptionKeyPairData.SignedPublicKey == null && SignatureKeyPairData == null && SecurityStateData == null) + { + return false; + } + else + { + throw new InvalidOperationException("Invalid account cryptographic state: V2 encryption fields must be either all present or all absent."); + } + } +} diff --git a/src/Core/KeyManagement/Queries/Interfaces/IKeyConnectorConfirmationDetailsQuery.cs b/src/Core/KeyManagement/Queries/Interfaces/IKeyConnectorConfirmationDetailsQuery.cs new file mode 100644 index 0000000000..60b78c03f4 --- /dev/null +++ b/src/Core/KeyManagement/Queries/Interfaces/IKeyConnectorConfirmationDetailsQuery.cs @@ -0,0 +1,8 @@ +using Bit.Core.KeyManagement.Models.Data; + +namespace Bit.Core.KeyManagement.Queries.Interfaces; + +public interface IKeyConnectorConfirmationDetailsQuery +{ + public Task Run(string orgSsoIdentifier, Guid userId); +} diff --git a/src/Core/KeyManagement/Queries/Interfaces/IUserAcountKeysQuery.cs b/src/Core/KeyManagement/Queries/Interfaces/IUserAcountKeysQuery.cs new file mode 100644 index 0000000000..4ea9b7582b --- /dev/null +++ b/src/Core/KeyManagement/Queries/Interfaces/IUserAcountKeysQuery.cs @@ -0,0 +1,10 @@ + +using Bit.Core.Entities; +using Bit.Core.KeyManagement.Models.Data; + +namespace Bit.Core.KeyManagement.Queries.Interfaces; + +public interface IUserAccountKeysQuery +{ + Task Run(User user); +} diff --git a/src/Core/KeyManagement/Queries/KeyConnectorConfirmationDetailsQuery.cs b/src/Core/KeyManagement/Queries/KeyConnectorConfirmationDetailsQuery.cs new file mode 100644 index 0000000000..0c210e2fd1 --- /dev/null +++ b/src/Core/KeyManagement/Queries/KeyConnectorConfirmationDetailsQuery.cs @@ -0,0 +1,35 @@ +using Bit.Core.Exceptions; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Queries.Interfaces; +using Bit.Core.Repositories; + +namespace Bit.Core.KeyManagement.Queries; + +public class KeyConnectorConfirmationDetailsQuery : IKeyConnectorConfirmationDetailsQuery +{ + private readonly IOrganizationRepository _organizationRepository; + private readonly IOrganizationUserRepository _organizationUserRepository; + + public KeyConnectorConfirmationDetailsQuery(IOrganizationRepository organizationRepository, IOrganizationUserRepository organizationUserRepository) + { + _organizationRepository = organizationRepository; + _organizationUserRepository = organizationUserRepository; + } + + public async Task Run(string orgSsoIdentifier, Guid userId) + { + var org = await _organizationRepository.GetByIdentifierAsync(orgSsoIdentifier); + if (org is not { UseKeyConnector: true }) + { + throw new NotFoundException(); + } + + var orgUser = await _organizationUserRepository.GetByOrganizationAsync(org.Id, userId); + if (orgUser == null) + { + throw new NotFoundException(); + } + + return new KeyConnectorConfirmationDetails { OrganizationName = org.Name, }; + } +} diff --git a/src/Core/KeyManagement/Queries/UserAccountKeysQuery.cs b/src/Core/KeyManagement/Queries/UserAccountKeysQuery.cs new file mode 100644 index 0000000000..7aafd2cf1e --- /dev/null +++ b/src/Core/KeyManagement/Queries/UserAccountKeysQuery.cs @@ -0,0 +1,35 @@ + +using Bit.Core.Entities; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Queries.Interfaces; +using Bit.Core.KeyManagement.Repositories; + +namespace Bit.Core.KeyManagement.Queries; + + +public class UserAccountKeysQuery(IUserSignatureKeyPairRepository signatureKeyPairRepository) : IUserAccountKeysQuery +{ + public async Task Run(User user) + { + if (user.GetSecurityVersion() < 2) + { + return new UserAccountKeysData + { + PublicKeyEncryptionKeyPairData = user.GetPublicKeyEncryptionKeyPair(), + }; + } + else + { + return new UserAccountKeysData + { + PublicKeyEncryptionKeyPairData = user.GetPublicKeyEncryptionKeyPair(), + SignatureKeyPairData = await signatureKeyPairRepository.GetByUserIdAsync(user.Id), + SecurityStateData = new SecurityStateData + { + SecurityState = user.SecurityState!, + SecurityVersion = user.GetSecurityVersion(), + } + }; + } + } +} diff --git a/src/Core/KeyManagement/Repositories/IUserSignatureKeyPairRepository.cs b/src/Core/KeyManagement/Repositories/IUserSignatureKeyPairRepository.cs new file mode 100644 index 0000000000..ce8979620f --- /dev/null +++ b/src/Core/KeyManagement/Repositories/IUserSignatureKeyPairRepository.cs @@ -0,0 +1,14 @@ + +using Bit.Core.KeyManagement.Entities; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.UserKey; +using Bit.Core.Repositories; + +namespace Bit.Core.KeyManagement.Repositories; + +public interface IUserSignatureKeyPairRepository : IRepository +{ + public Task GetByUserIdAsync(Guid userId); + public UpdateEncryptedDataForKeyRotation UpdateForKeyRotation(Guid grantorId, SignatureKeyPairData signatureKeyPair); + public UpdateEncryptedDataForKeyRotation SetUserSignatureKeyPair(Guid userId, SignatureKeyPairData signatureKeyPair); +} diff --git a/src/Core/KeyManagement/UserKey/Implementations/RotateUserAccountkeysCommand.cs b/src/Core/KeyManagement/UserKey/Implementations/RotateUserAccountkeysCommand.cs index 91363abee8..c1e7905d78 100644 --- a/src/Core/KeyManagement/UserKey/Implementations/RotateUserAccountkeysCommand.cs +++ b/src/Core/KeyManagement/UserKey/Implementations/RotateUserAccountkeysCommand.cs @@ -1,6 +1,11 @@ -using Bit.Core.Auth.Repositories; +// FIXME: Update this file to be null safe and then delete the line below +#nullable disable + +using Bit.Core.Auth.Repositories; using Bit.Core.Entities; +using Bit.Core.Enums; using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Repositories; using Bit.Core.Platform.Push; using Bit.Core.Repositories; using Bit.Core.Services; @@ -25,6 +30,8 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand private readonly IdentityErrorDescriber _identityErrorDescriber; private readonly IWebAuthnCredentialRepository _credentialRepository; private readonly IPasswordHasher _passwordHasher; + private readonly IUserSignatureKeyPairRepository _userSignatureKeyPairRepository; + private readonly IFeatureService _featureService; /// /// Instantiates a new @@ -36,16 +43,19 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand /// Provides a method to update re-encrypted send data /// Provides a method to update re-encrypted emergency access data /// Provides a method to update re-encrypted organization user data + /// Provides a method to update re-encrypted device keys /// Hashes the new master password /// Logs out user from other devices after successful rotation /// Provides a password mismatch error if master password hash validation fails /// Provides a method to update re-encrypted WebAuthn keys + /// Provides a method to update re-encrypted signature keys public RotateUserAccountKeysCommand(IUserService userService, IUserRepository userRepository, ICipherRepository cipherRepository, IFolderRepository folderRepository, ISendRepository sendRepository, IEmergencyAccessRepository emergencyAccessRepository, IOrganizationUserRepository organizationUserRepository, IDeviceRepository deviceRepository, IPasswordHasher passwordHasher, IPushNotificationService pushService, IdentityErrorDescriber errors, IWebAuthnCredentialRepository credentialRepository, + IUserSignatureKeyPairRepository userSignatureKeyPairRepository, IFeatureService featureService) { _userService = userService; @@ -60,6 +70,8 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand _identityErrorDescriber = errors; _credentialRepository = credentialRepository; _passwordHasher = passwordHasher; + _userSignatureKeyPairRepository = userSignatureKeyPairRepository; + _featureService = featureService; } /// @@ -80,50 +92,106 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand user.LastKeyRotationDate = now; user.SecurityStamp = Guid.NewGuid().ToString(); - if ( - !model.MasterPasswordUnlockData.ValidateForUser(user) - ) + List saveEncryptedDataActions = []; + + await UpdateAccountKeysAsync(model, user, saveEncryptedDataActions); + UpdateUnlockMethods(model, user, saveEncryptedDataActions); + UpdateUserData(model, user, saveEncryptedDataActions); + + await _userRepository.UpdateUserKeyAndEncryptedDataV2Async(user, saveEncryptedDataActions); + await _pushService.PushLogOutAsync(user.Id); + return IdentityResult.Success; + } + + public async Task RotateV2AccountKeysAsync(RotateUserAccountKeysData model, User user, List saveEncryptedDataActions) + { + ValidateV2Encryption(model); + await ValidateVerifyingKeyUnchangedAsync(model, user); + + saveEncryptedDataActions.Add(_userSignatureKeyPairRepository.UpdateForKeyRotation(user.Id, model.AccountKeys.SignatureKeyPairData)); + user.SignedPublicKey = model.AccountKeys.PublicKeyEncryptionKeyPairData.SignedPublicKey; + user.SecurityState = model.AccountKeys.SecurityStateData!.SecurityState; + user.SecurityVersion = model.AccountKeys.SecurityStateData.SecurityVersion; + } + + public void UpgradeV1ToV2Keys(RotateUserAccountKeysData model, User user, List saveEncryptedDataActions) + { + ValidateV2Encryption(model); + saveEncryptedDataActions.Add(_userSignatureKeyPairRepository.SetUserSignatureKeyPair(user.Id, model.AccountKeys.SignatureKeyPairData)); + user.SignedPublicKey = model.AccountKeys.PublicKeyEncryptionKeyPairData.SignedPublicKey; + user.SecurityState = model.AccountKeys.SecurityStateData!.SecurityState; + user.SecurityVersion = model.AccountKeys.SecurityStateData.SecurityVersion; + } + + public async Task UpdateAccountKeysAsync(RotateUserAccountKeysData model, User user, List saveEncryptedDataActions) + { + ValidatePublicKeyEncryptionKeyPairUnchanged(model, user); + + if (IsV2EncryptionUserAsync(user)) { - throw new InvalidOperationException("The provided master password unlock data is not valid for this user."); + await RotateV2AccountKeysAsync(model, user, saveEncryptedDataActions); } - if ( - model.AccountPublicKey != user.PublicKey - ) + else if (model.AccountKeys.SignatureKeyPairData != null) { - throw new InvalidOperationException("The provided account public key does not match the user's current public key, and changing the account asymmetric keypair is currently not supported during key rotation."); + UpgradeV1ToV2Keys(model, user, saveEncryptedDataActions); + } + else + { + if (GetEncryptionType(model.AccountKeys.PublicKeyEncryptionKeyPairData.WrappedPrivateKey) != EncryptionType.AesCbc256_HmacSha256_B64) + { + throw new InvalidOperationException("The provided account private key was not wrapped with AES-256-CBC-HMAC"); + } + // V1 user to V1 user rotation needs to further changes, the private key was re-encrypted. } - user.Key = model.MasterPasswordUnlockData.MasterKeyEncryptedUserKey; - user.PrivateKey = model.UserKeyEncryptedAccountPrivateKey; - user.MasterPassword = _passwordHasher.HashPassword(user, model.MasterPasswordUnlockData.MasterKeyAuthenticationHash); - user.MasterPasswordHint = model.MasterPasswordUnlockData.MasterPasswordHint; + // Private key is re-wrapped with new user key by client + user.PrivateKey = model.AccountKeys.PublicKeyEncryptionKeyPairData.WrappedPrivateKey; + } + + public void UpdateUserData(RotateUserAccountKeysData model, User user, List saveEncryptedDataActions) + { + // The revision date has to be updated so that de-synced clients don't accidentally post over the re-encrypted data + // with an old-user key-encrypted copy + var now = DateTime.UtcNow; - List saveEncryptedDataActions = new(); if (model.Ciphers.Any()) { - saveEncryptedDataActions.Add(_cipherRepository.UpdateForKeyRotation(user.Id, model.Ciphers)); + var ciphersWithUpdatedDate = model.Ciphers.ToList().Select(c => { c.RevisionDate = now; return c; }); + saveEncryptedDataActions.Add(_cipherRepository.UpdateForKeyRotation(user.Id, ciphersWithUpdatedDate)); } if (model.Folders.Any()) { - saveEncryptedDataActions.Add(_folderRepository.UpdateForKeyRotation(user.Id, model.Folders)); + var foldersWithUpdatedDate = model.Folders.ToList().Select(f => { f.RevisionDate = now; return f; }); + saveEncryptedDataActions.Add(_folderRepository.UpdateForKeyRotation(user.Id, foldersWithUpdatedDate)); } if (model.Sends.Any()) { - saveEncryptedDataActions.Add(_sendRepository.UpdateForKeyRotation(user.Id, model.Sends)); + var sendsWithUpdatedDate = model.Sends.ToList().Select(s => { s.RevisionDate = now; return s; }); + saveEncryptedDataActions.Add(_sendRepository.UpdateForKeyRotation(user.Id, sendsWithUpdatedDate)); } + } + + void UpdateUnlockMethods(RotateUserAccountKeysData model, User user, List saveEncryptedDataActions) + { + if (!model.MasterPasswordUnlockData.ValidateForUser(user)) + { + throw new InvalidOperationException("The provided master password unlock data is not valid for this user."); + } + // Update master password authentication & unlock + user.Key = model.MasterPasswordUnlockData.MasterKeyEncryptedUserKey; + user.MasterPassword = _passwordHasher.HashPassword(user, model.MasterPasswordUnlockData.MasterKeyAuthenticationHash); + user.MasterPasswordHint = model.MasterPasswordUnlockData.MasterPasswordHint; if (model.EmergencyAccesses.Any()) { - saveEncryptedDataActions.Add( - _emergencyAccessRepository.UpdateForKeyRotation(user.Id, model.EmergencyAccesses)); + saveEncryptedDataActions.Add(_emergencyAccessRepository.UpdateForKeyRotation(user.Id, model.EmergencyAccesses)); } if (model.OrganizationUsers.Any()) { - saveEncryptedDataActions.Add( - _organizationUserRepository.UpdateForKeyRotation(user.Id, model.OrganizationUsers)); + saveEncryptedDataActions.Add(_organizationUserRepository.UpdateForKeyRotation(user.Id, model.OrganizationUsers)); } if (model.WebAuthnKeys.Any()) @@ -135,9 +203,80 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand { saveEncryptedDataActions.Add(_deviceRepository.UpdateKeysForRotationAsync(user.Id, model.DeviceKeys)); } + } - await _userRepository.UpdateUserKeyAndEncryptedDataV2Async(user, saveEncryptedDataActions); - await _pushService.PushLogOutAsync(user.Id); - return IdentityResult.Success; + private bool IsV2EncryptionUserAsync(User user) + { + // Returns whether the user is a V2 user based on the private key's encryption type. + ArgumentNullException.ThrowIfNull(user); + var isPrivateKeyEncryptionV2 = GetEncryptionType(user.PrivateKey) == EncryptionType.XChaCha20Poly1305_B64; + return isPrivateKeyEncryptionV2; + } + + private async Task ValidateVerifyingKeyUnchangedAsync(RotateUserAccountKeysData model, User user) + { + var currentSignatureKeyPair = await _userSignatureKeyPairRepository.GetByUserIdAsync(user.Id) ?? throw new InvalidOperationException("User does not have a signature key pair."); + if (model.AccountKeys.SignatureKeyPairData.VerifyingKey != currentSignatureKeyPair!.VerifyingKey) + { + throw new InvalidOperationException("The provided verifying key does not match the user's current verifying key."); + } + } + + private static void ValidatePublicKeyEncryptionKeyPairUnchanged(RotateUserAccountKeysData model, User user) + { + var publicKey = model.AccountKeys.PublicKeyEncryptionKeyPairData.PublicKey; + if (publicKey != user.PublicKey) + { + throw new InvalidOperationException("The provided account public key does not match the user's current public key, and changing the account asymmetric key pair is currently not supported during key rotation."); + } + } + + private static void ValidateV2Encryption(RotateUserAccountKeysData model) + { + if (model.AccountKeys.SignatureKeyPairData == null) + { + throw new InvalidOperationException("Signature key pair data is required for V2 encryption."); + } + if (GetEncryptionType(model.AccountKeys.SignatureKeyPairData.WrappedSigningKey) != EncryptionType.XChaCha20Poly1305_B64) + { + throw new InvalidOperationException("The provided signing key data is not wrapped with XChaCha20-Poly1305."); + } + if (string.IsNullOrEmpty(model.AccountKeys.SignatureKeyPairData.VerifyingKey)) + { + throw new InvalidOperationException("The provided signature key pair data does not contain a valid verifying key."); + } + + if (GetEncryptionType(model.AccountKeys.PublicKeyEncryptionKeyPairData.WrappedPrivateKey) != EncryptionType.XChaCha20Poly1305_B64) + { + throw new InvalidOperationException("The provided private key encryption key is not wrapped with XChaCha20-Poly1305."); + } + if (string.IsNullOrEmpty(model.AccountKeys.PublicKeyEncryptionKeyPairData.SignedPublicKey)) + { + throw new InvalidOperationException("No signed public key provided, but the user already has a signature key pair."); + } + if (model.AccountKeys.SecurityStateData == null || string.IsNullOrEmpty(model.AccountKeys.SecurityStateData.SecurityState)) + { + throw new InvalidOperationException("No signed security state provider for V2 user"); + } + } + + /// + /// Helper method to convert an encryption type string to an enum value. + /// + private static EncryptionType GetEncryptionType(string encString) + { + var parts = encString.Split('.'); + if (parts.Length == 1) + { + throw new ArgumentException("Invalid encryption type string."); + } + if (byte.TryParse(parts[0], out var encryptionTypeNumber)) + { + if (Enum.IsDefined(typeof(EncryptionType), encryptionTypeNumber)) + { + return (EncryptionType)encryptionTypeNumber; + } + } + throw new ArgumentException("Invalid encryption type string."); } } diff --git a/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmailv2.html.hbs b/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmailv2.html.hbs new file mode 100644 index 0000000000..352bb447c8 --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmailv2.html.hbs @@ -0,0 +1,691 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +

    + Verify your email to access this Bitwarden Send +

    + +
    + +
    + + + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +
    + + +
    + +
    + + + + + +
    + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + +
    + + + + + + + + + + + + + + + + + + + + + +
    + +
    Your verification code is:
    + +
    + +
    {{Token}}
    + +
    + +
    + +
    + +
    This code expires in {{Expiry}} minutes. After that, you'll need + to verify your email again.
    + +
    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + +
    + + + + + + + + + +
    + +

    + Bitwarden Send transmits sensitive, temporary information to + others easily and securely. Learn more about + Bitwarden Send + or + sign up + to try it today. +

    + +
    + +
    + +
    + + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +

    + Learn more about Bitwarden +

    + Find user guides, product documentation, and videos on the + Bitwarden Help Center.
    + +
    + +
    + + + +
    + + + + + + + + + +
    + +
    + + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + +
    + +

    + © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + Barbara, CA, USA +

    +

    + Always confirm you are on a trusted Bitwarden domain before logging + in:
    + bitwarden.com | + Learn why we include this +

    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + \ No newline at end of file diff --git a/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmailv2.text.hbs b/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmailv2.text.hbs new file mode 100644 index 0000000000..7c9c1db527 --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmailv2.text.hbs @@ -0,0 +1,9 @@ +{{#>BasicTextLayout}} +Verify your email to access this Bitwarden Send. + +Your verification code is: {{Token}} + +This code can only be used once and expires in {{Expiry}} minutes. After that you'll need to verify your email again. + +Bitwarden Send transmits sensitive, temporary information to others easily and securely. Learn more about Bitwarden Send or sign up to try it today. +{{/BasicTextLayout}} diff --git a/src/Core/MailTemplates/Handlebars/MJML/AdminConsole/OrganizationConfirmation/organization-confirmation-enterprise-teams.html.hbs b/src/Core/MailTemplates/Handlebars/MJML/AdminConsole/OrganizationConfirmation/organization-confirmation-enterprise-teams.html.hbs new file mode 100644 index 0000000000..be1a3854b5 --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/MJML/AdminConsole/OrganizationConfirmation/organization-confirmation-enterprise-teams.html.hbs @@ -0,0 +1,815 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +

    + You can now share passwords with members of {{OrganizationName}}! +

    + +
    + + + + + + + +
    + + Log in + +
    + +
    + +
    + + + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +
    + + +
    + +
    + + + + + +
    + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +
    As a member of {{OrganizationName}}:
    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Organization Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + +
    + +
    Your account is owned by {{OrganizationName}} and is subject to their security and management policies.
    + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Group Users Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + + + + + +
    + +
    You can easily access and share passwords with your team.
    + +
    + + + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +

    + Learn more about Bitwarden +

    + Find user guides, product documentation, and videos on the + Bitwarden Help Center.
    + +
    + +
    + + + +
    + + + + + + + + + +
    + +
    + + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + +
    + +

    + © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + Barbara, CA, USA +

    +

    + Always confirm you are on a trusted Bitwarden domain before logging + in:
    + bitwarden.com | + Learn why we include this +

    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + \ No newline at end of file diff --git a/src/Core/MailTemplates/Handlebars/MJML/AdminConsole/OrganizationConfirmation/organization-confirmation-enterprise-teams.text.hbs b/src/Core/MailTemplates/Handlebars/MJML/AdminConsole/OrganizationConfirmation/organization-confirmation-enterprise-teams.text.hbs new file mode 100644 index 0000000000..38c45f2dd1 --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/MJML/AdminConsole/OrganizationConfirmation/organization-confirmation-enterprise-teams.text.hbs @@ -0,0 +1,4 @@ +{{#>TitleContactUsTextLayout}} + You may now access logins and other items {{OrganizationName}} has shared with you from your Bitwarden vault. + Tip: Use the Bitwarden mobile app to quickly save logins and auto-fill forms. Download from the App Store or Google Play. +{{/TitleContactUsTextLayout}} diff --git a/src/Core/MailTemplates/Handlebars/MJML/AdminConsole/OrganizationConfirmation/organization-confirmation-family-free.html.hbs b/src/Core/MailTemplates/Handlebars/MJML/AdminConsole/OrganizationConfirmation/organization-confirmation-family-free.html.hbs new file mode 100644 index 0000000000..b9984343d5 --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/MJML/AdminConsole/OrganizationConfirmation/organization-confirmation-family-free.html.hbs @@ -0,0 +1,983 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +

    + You can now share passwords with members of {{OrganizationName}}! +

    + +
    + + + + + + + +
    + + Log in + +
    + +
    + +
    + + + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +
    + + +
    + +
    + + + + + +
    + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +
    As a member of {{OrganizationName}}:
    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Collections Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + +
    + +
    You can access passwords {{OrganizationName}} has shared with you.
    + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Group Users Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + + + + + +
    + +
    You can easily share passwords with friends, family, or coworkers.
    + +
    + + + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + +
    Download Bitwarden on all devices
    + +
    + +
    Already using the browser extension? + Download the Bitwarden mobile app from the + App Store + or Google Play + to quickly save logins and autofill forms on the go.
    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + + + Download on the App Store + + + +
    + +
    + +
    + + + +
    + + + + + + + + + +
    + + + + + + + +
    + + + + Get it on Google Play + + + +
    + +
    + +
    + + +
    + + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +

    + Learn more about Bitwarden +

    + Find user guides, product documentation, and videos on the + Bitwarden Help Center.
    + +
    + +
    + + + +
    + + + + + + + + + +
    + +
    + + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + +
    + +

    + © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + Barbara, CA, USA +

    +

    + Always confirm you are on a trusted Bitwarden domain before logging + in:
    + bitwarden.com | + Learn why we include this +

    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + \ No newline at end of file diff --git a/src/Core/MailTemplates/Handlebars/MJML/AdminConsole/OrganizationConfirmation/organization-confirmation-family-free.text.hbs b/src/Core/MailTemplates/Handlebars/MJML/AdminConsole/OrganizationConfirmation/organization-confirmation-family-free.text.hbs new file mode 100644 index 0000000000..38c45f2dd1 --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/MJML/AdminConsole/OrganizationConfirmation/organization-confirmation-family-free.text.hbs @@ -0,0 +1,4 @@ +{{#>TitleContactUsTextLayout}} + You may now access logins and other items {{OrganizationName}} has shared with you from your Bitwarden vault. + Tip: Use the Bitwarden mobile app to quickly save logins and auto-fill forms. Download from the App Store or Google Play. +{{/TitleContactUsTextLayout}} diff --git a/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-family-user.html.hbs b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-family-user.html.hbs new file mode 100644 index 0000000000..1998cf10ba --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-family-user.html.hbs @@ -0,0 +1,904 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +

    + Welcome to Bitwarden! +

    + +

    + Let’s get you set up to autofill. +

    +
    + +
    + +
    + + + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +
    + + +
    + +
    + + + + + +
    + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +
    An administrator from {{OrganizationName}} will approve you + before you can share passwords. While you wait for approval, get + started with Bitwarden Password Manager:
    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Browser Extension Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + + + + + +
    + + + +
    + +
    With the Bitwarden extension, you can fill passwords with one click.
    + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Install Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + + + + + +
    + + + +
    + +
    Quickly transfer existing passwords to Bitwarden using the importer.
    + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Devices Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + + + + + +
    + + + +
    + +
    Take your passwords with you anywhere.
    + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +

    + Learn more about Bitwarden +

    + Find user guides, product documentation, and videos on the + Bitwarden Help Center.
    + +
    + +
    + + + +
    + + + + + + + + + +
    + +
    + + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + +
    + +

    + © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + Barbara, CA, USA +

    +

    + Always confirm you are on a trusted Bitwarden domain before logging + in:
    + bitwarden.com | + Learn why we include this +

    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + \ No newline at end of file diff --git a/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-family-user.text.hbs b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-family-user.text.hbs new file mode 100644 index 0000000000..38f53e7755 --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-family-user.text.hbs @@ -0,0 +1,19 @@ +{{#>FullTextLayout}} +Welcome to Bitwarden! +Let's get you set up with autofill. + +A {{OrganizationName}} administrator will approve you before you can share passwords. +While you wait for approval, get started with Bitwarden Password Manager: + +Get the browser extension: +With the Bitwarden extension, you can fill passwords with one click. (https://www.bitwarden.com/download) + +Add passwords to your vault: +Quickly transfer existing passwords to Bitwarden using the importer. (https://bitwarden.com/help/import-data/) + +Download Bitwarden on all devices: +Take your passwords with you anywhere. (https://www.bitwarden.com/download) + +Learn more about Bitwarden +Find user guides, product documentation, and videos on the Bitwarden Help Center. (https://bitwarden.com/help/) +{{/FullTextLayout}} diff --git a/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-individual-user.html.hbs b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-individual-user.html.hbs new file mode 100644 index 0000000000..2ad670383b --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-individual-user.html.hbs @@ -0,0 +1,903 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +

    + Welcome to Bitwarden! +

    + +

    + Let’s get you set up to autofill. +

    +
    + +
    + +
    + + + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +
    + + +
    + +
    + + + + + +
    + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +
    Follow these simple steps to get up and running with Bitwarden + Password Manager:
    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Browser Extension Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + + + + + +
    + + + +
    + +
    With the Bitwarden extension, you can fill passwords with one click.
    + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Install Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + + + + + +
    + + + +
    + +
    Quickly transfer existing passwords to Bitwarden using the importer.
    + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Devices Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + + + + + +
    + + + +
    + +
    Take your passwords with you anywhere.
    + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +

    + Learn more about Bitwarden +

    + Find user guides, product documentation, and videos on the + Bitwarden Help Center.
    + +
    + +
    + + + +
    + + + + + + + + + +
    + +
    + + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + +
    + +

    + © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + Barbara, CA, USA +

    +

    + Always confirm you are on a trusted Bitwarden domain before logging + in:
    + bitwarden.com | + Learn why we include this +

    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + \ No newline at end of file diff --git a/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-individual-user.text.hbs b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-individual-user.text.hbs new file mode 100644 index 0000000000..f698e79ca7 --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-individual-user.text.hbs @@ -0,0 +1,18 @@ +{{#>FullTextLayout}} +Welcome to Bitwarden! +Let's get you set up with autofill. + +Follow these simple steps to get up and running with Bitwarden Password Manager: + +Get the browser extension: +With the Bitwarden extension, you can fill passwords with one click. (https://www.bitwarden.com/download) + +Add passwords to your vault: +Quickly transfer existing passwords to Bitwarden using the importer. (https://bitwarden.com/help/import-data/) + +Download Bitwarden on all devices: +Take your passwords with you anywhere. (https://bitwarden.com/help/auto-fill-browser/) + +Learn more about Bitwarden +Find user guides, product documentation, and videos on the Bitwarden Help Center. (https://bitwarden.com/help/) +{{/FullTextLayout}} diff --git a/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-org-user.html.hbs b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-org-user.html.hbs new file mode 100644 index 0000000000..efcacf1866 --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-org-user.html.hbs @@ -0,0 +1,904 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +

    + Welcome to Bitwarden! +

    + +

    + Let’s get you set up to autofill. +

    +
    + +
    + +
    + + + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +
    + + +
    + +
    + + + + + +
    + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +
    An administrator from {{OrganizationName}} will need to confirm + you before you can share passwords. Get started with Bitwarden + Password Manager:
    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Browser Extension Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + + + + + +
    + + + +
    + +
    With the Bitwarden extension, you can fill passwords with one click.
    + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Install Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + + + + + +
    + + + +
    + +
    Quickly transfer existing passwords to Bitwarden using the importer.
    + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + Autofill Icon + +
    + +
    + +
    + + + +
    + + + + + + + + + + + + + +
    + + + +
    + +
    Fill your passwords securely with one click.
    + +
    + +
    + + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +

    + Learn more about Bitwarden +

    + Find user guides, product documentation, and videos on the + Bitwarden Help Center.
    + +
    + +
    + + + +
    + + + + + + + + + +
    + +
    + + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + +
    + +

    + © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + Barbara, CA, USA +

    +

    + Always confirm you are on a trusted Bitwarden domain before logging + in:
    + bitwarden.com | + Learn why we include this +

    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + \ No newline at end of file diff --git a/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-org-user.text.hbs b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-org-user.text.hbs new file mode 100644 index 0000000000..3808cc818d --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-org-user.text.hbs @@ -0,0 +1,20 @@ +{{#>FullTextLayout}} +Welcome to Bitwarden! +Let's get you set up with autofill. + +A {{OrganizationName}} administrator will approve you before you can share passwords. +Get started with Bitwarden Password Manager: + +Get the browser extension: +With the Bitwarden extension, you can fill passwords with one click. (https://www.bitwarden.com/download) + +Add passwords to your vault: +Quickly transfer existing passwords to Bitwarden using the importer. (https://bitwarden.com/help/import-data/) + +Try Bitwarden autofill: +Fill your passwords securely with one click. (https://bitwarden.com/help/auto-fill-browser/) + + +Learn more about Bitwarden +Find user guides, product documentation, and videos on the Bitwarden Help Center. (https://bitwarden.com/help/) +{{/FullTextLayout}} diff --git a/src/Core/MailTemplates/Mjml/.mjmlconfig b/src/Core/MailTemplates/Mjml/.mjmlconfig index 7560e0fb96..a71e3b5ee9 100644 --- a/src/Core/MailTemplates/Mjml/.mjmlconfig +++ b/src/Core/MailTemplates/Mjml/.mjmlconfig @@ -1,5 +1,9 @@ { "packages": [ - "components/hero" + "components/mj-bw-hero", + "components/mj-bw-simple-hero", + "components/mj-bw-icon-row", + "components/mj-bw-learn-more-footer", + "emails/AdminConsole/components/mj-bw-inviter-info" ] } diff --git a/src/Core/MailTemplates/Mjml/README.md b/src/Core/MailTemplates/Mjml/README.md index b60655140a..fabb393ee0 100644 --- a/src/Core/MailTemplates/Mjml/README.md +++ b/src/Core/MailTemplates/Mjml/README.md @@ -1,19 +1,126 @@ -# Email templates +# `MJML` email templating -This directory contains MJML templates for emails sent by the application. MJML is a markup language designed to reduce the pain of coding responsive email templates. +This directory contains `MJML` templates for emails. `MJML` is a markup language designed to reduce the pain of coding responsive email templates. Component-based development features in `MJML` improve code quality and reusability. -## Usage +> [!TIP] +> `MJML` stands for MailJet Markup Language. -```bash +## Implementation considerations + +`MJML` templates are compiled into `HTML`, and those outputs are then consumed by Handlebars to render the final email for delivery. It builds on top of our existing infrastructure and means we can continue to use the double brace (`{{}}`) syntax within `MJML`, since Handlebars will assign values to those `{{variables}}`. + +To do this, there is an added step where we compile `*.mjml` to `*.html.hbs`. `*.html.hbs` is the format we use so the Handlebars service can apply the variables. This build pipeline process is in progress and may need to be manually done at times. + +### `*.txt.hbs` + +There is no change to how we create the `txt.hbs`. MJML does not impact how we create these artifacts. + +## Building `MJML` files + +```shell npm ci -# Build once +# Build *.html to ./out directory npm run build -# To build on changes -npm run watch +# To build on changes to *.mjml and *.js files, new files will not be tracked, you will need to run again +npm run build:watch + +# Build *.html.hbs to ./out directory +npm run build:hbs + +# Build minified *.html.hbs to ./out directory +npm run build:minify + +# apply prettier formatting +npm run prettier ``` -## Development +## Development process -MJML supports components and you can create your own components by adding them to `.mjmlconfig`. +`MJML` supports components and you can create your own components by adding them to `.mjmlconfig`. Components are simple JavaScript that return `MJML` markup based on the attributes assigned, see components/mj-bw-hero.js. The markup is not a proper object, but contained in a string. + +When using `MJML` templating you can use the above [commands](#building-mjml-files) to compile the template and view it in a web browser. + +Not all `MJML` tags have the same attributes, it is highly recommended to review the documentation on the official MJML website to understand the usages of each of the tags. + +### Developing the mail template + +1. Create `cool-email.mjml` in appropriate team directory. +2. Run `npm run build:watch`. +3. View compiled `HTML` output in a web browser. +4. Iterate through your development. While running `build:watch` you should be able to refresh the browser page after the `mjml/js` recompile to see the changes. + +### Testing the mail template with `IMailer` + +After the email is developed in the [initial step](#developing-the-mail-template), we need to make sure that the email `{{variables}}` are populated properly by Handlebars. We can do this by running it through an `IMailer` implementation. The `IMailer`, documented [here](../../Platform/Mail/README.md#step-3-create-handlebars-templates), requires that the ViewModel, the `.html.hbs` `MJML` build artifact, and `.text.hbs` files be in the same directory. + +1. Run `npm run build:hbs`. +2. Copy built `*.html.hbs` files from the build directory to the directory that the `IMailer` expects. All files in the `Core/MailTemplates/Mjml/out` directory should be copied to the `/src/Core/MailTemplates/Mjml` directory, ensuring that the files are in the same directory as the corresponding ViewModels. If a shared component is modified it is important to copy and overwrite all files in that directory to capture changes in the `*.html.hbs` files. +3. Run code that will send the email. + +The minified `html.hbs` artifacts are deliverables and must be placed into the correct `/src/Core/MailTemplates/Mjml` directories in order to be used by `IMailer` implementations, see step 2 above. + +### Testing the mail template with `IMailService` + +> [!WARNING] +> The `IMailService` has been deprecated. The [IMailer](#testing-the-mail-template-with-imailer) should be used instead. + +After the email is developed from the [initial step](#developing-the-mail-template), make sure the email `{{variables}}` are populated properly by running it through an `IMailService` implementation. + +1. Run `npm run build:hbs` +2. Copy built `*.html.hbs` files from the build directory to a location the mail service can consume them. + 1. All files in the `Core/MailTemplates/Mjml/out` directory should be copied to the `src/Core/MailTemplates/Handlebars/MJML` directory. If a shared component is modified it is important to copy and overwrite all files in that directory to capture changes in the `*.html.hbs`. +3. Run code that will send the email. + +The minified `html.hbs` artifacts are deliverables and must be placed into the correct `src/Core/MailTemplates/Handlebars/` directories in order to be used by `IMailService` implementations, see 2.1 above. + +### Custom tags + +There is currently a `mj-bw-hero` tag you can use within your `*.mjml` templates. This is a good example of how to create a component that takes in attribute values allowing us to be more DRY in our development of emails. Since the attribute's input is a string we are able to define whatever we need into the component, in this case `mj-bw-hero`. + +In order to view the custom component you have written you will need to include it in the `.mjmlconfig` and reference it in a `.mjml` template file. +```html + + +``` + +Attributes in custom components are defined by the developer. They can be required or optional depending on implementation. See the official `MJML` [documentation](https://documentation.mjml.io/#components) for more information. +```js +static allowedAttributes = { + "img-src": "string", // REQUIRED: Source for the image displayed in the right-hand side of the blue header area + title: "string", // REQUIRED: large text stating primary purpose of the email + "button-text": "string", // OPTIONAL: text to display in the button + "button-url": "string", // OPTIONAL: URL to navigate to when the button is clicked + "sub-title": "string", // OPTIONAL: smaller text providing additional context for the title +}; + +static defaultAttributes = {}; +``` + +Custom components, such as `mj-bw-hero`, must be defined in the `.mjmlconfig` in order for them to be compiled and rendered properly in the templates. + +```json +{ + "packages": ["components/mj-bw-hero"] +} +``` + +### `mj-include` + +You are also able to reference other more static `MJML` templates in your `MJML` file simply by referencing the file within the `MJML` template. + +```html + + + + +``` + +#### `head.mjml` +Currently we include the `head.mjml` file in all `MJML` templates as it contains shared styling and formatting that ensures consistency across all email implementations. + +In the future we may deviate from this practice to support different layouts. At that time we will modify the docs with direction. diff --git a/src/Core/MailTemplates/Mjml/build.js b/src/Core/MailTemplates/Mjml/build.js new file mode 100644 index 0000000000..4e3eaef449 --- /dev/null +++ b/src/Core/MailTemplates/Mjml/build.js @@ -0,0 +1,130 @@ +const mjml2html = require("mjml"); +const { registerComponent } = require("mjml-core"); +const fs = require("fs"); +const path = require("path"); +const glob = require("glob"); + +// Parse command line arguments +const args = process.argv.slice(2); // Remove 'node' and script path + +// Parse flags +const flags = { + minify: args.includes("--minify") || args.includes("-m"), + watch: args.includes("--watch") || args.includes("-w"), + hbs: args.includes("--hbs") || args.includes("-h"), + trace: args.includes("--trace") || args.includes("-t"), + clean: args.includes("--clean") || args.includes("-c"), + help: args.includes("--help"), +}; + +// Use __dirname to get absolute paths relative to the script location +const config = { + inputDir: path.join(__dirname, "emails"), + outputDir: path.join(__dirname, "out"), + minify: flags.minify, + validationLevel: "strict", + hbsOutput: flags.hbs, +}; + +// Debug output +if (flags.trace) { + console.log("[DEBUG] Script location:", __dirname); + console.log("[DEBUG] Input directory:", config.inputDir); + console.log("[DEBUG] Output directory:", config.outputDir); +} + +// Ensure output directory exists +if (!fs.existsSync(config.outputDir)) { + fs.mkdirSync(config.outputDir, { recursive: true }); + if (flags.trace) { + console.log("[INFO] Created output directory:", config.outputDir); + } +} + +// Find all MJML files with absolute paths, excluding components directories +const mjmlFiles = glob.sync(`${config.inputDir}/**/*.mjml`, { + ignore: ['**/components/**'] +}); + +console.log(`\n[INFO] Found ${mjmlFiles.length} MJML file(s) to compile...`); + +if (mjmlFiles.length === 0) { + console.error("[ERROR] No MJML files found!"); + console.error("[ERROR] Looked in:", config.inputDir); + console.error( + "[ERROR] Does this directory exist?", + fs.existsSync(config.inputDir), + ); + process.exit(1); +} + +// Compile each MJML file +let successCount = 0; +let errorCount = 0; + +mjmlFiles.forEach((filePath) => { + try { + const mjmlContent = fs.readFileSync(filePath, "utf8"); + const fileName = path.basename(filePath, ".mjml"); + const relativePath = path.relative(config.inputDir, filePath); + + console.log(`\n[BUILD] Compiling: ${relativePath}`); + + // Compile MJML to HTML + const result = mjml2html(mjmlContent, { + minify: config.minify, + validationLevel: config.validationLevel, + filePath: filePath, // Important: tells MJML where the file is for resolving includes + mjmlConfigPath: __dirname, // Point to the directory with .mjmlconfig + }); + + // Check for errors + if (result.errors.length > 0) { + console.error(`[ERROR] Failed to compile ${fileName}.mjml:`); + result.errors.forEach((err) => + console.error(` ${err.formattedMessage}`), + ); + errorCount++; + return; + } + + // Calculate output path preserving directory structure + const relativeDir = path.dirname(relativePath); + const outputDir = path.join(config.outputDir, relativeDir); + + // Ensure subdirectory exists + if (!fs.existsSync(outputDir)) { + fs.mkdirSync(outputDir, { recursive: true }); + } + + const outputExtension = config.hbsOutput ? ".html.hbs" : ".html"; + const outputPath = path.join(outputDir, `${fileName}${outputExtension}`); + fs.writeFileSync(outputPath, result.html); + + console.log( + `[OK] Built: ${fileName}.mjml → ${path.relative(__dirname, outputPath)}`, + ); + successCount++; + + // Log warnings if any + if (result.warnings && result.warnings.length > 0) { + console.warn(`[WARN] Warnings for ${fileName}.mjml:`); + result.warnings.forEach((warn) => + console.warn(` ${warn.formattedMessage}`), + ); + } + } catch (error) { + console.error(`[ERROR] Exception processing ${path.basename(filePath)}:`); + console.error(` ${error.message}`); + errorCount++; + } +}); + +console.log(`\n[SUMMARY] Compilation complete!`); +console.log(` Success: ${successCount}`); +console.log(` Failed: ${errorCount}`); +console.log(` Output: ${config.outputDir}`); + +if (errorCount > 0) { + process.exit(1); +} diff --git a/src/Core/MailTemplates/Mjml/build.sh b/src/Core/MailTemplates/Mjml/build.sh deleted file mode 100755 index c76bdd8f61..0000000000 --- a/src/Core/MailTemplates/Mjml/build.sh +++ /dev/null @@ -1,4 +0,0 @@ -# TODO: This should probably be replaced with a node script building every file in `emails/` - -npx mjml emails/invite.mjml -o out/invite.html -npx mjml emails/two-factor.mjml -o out/two-factor.html diff --git a/src/Core/MailTemplates/Mjml/components/footer.mjml b/src/Core/MailTemplates/Mjml/components/footer.mjml index 0634033618..4037d6c9ba 100644 --- a/src/Core/MailTemplates/Mjml/components/footer.mjml +++ b/src/Core/MailTemplates/Mjml/components/footer.mjml @@ -1,52 +1,52 @@ - + - + -

    +

    © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa Barbara, CA, USA

    Always confirm you are on a trusted Bitwarden domain before logging in:
    - bitwarden.com | - Learn why we include this + bitwarden.com | + Learn why we include this

    diff --git a/src/Core/MailTemplates/Mjml/components/head.mjml b/src/Core/MailTemplates/Mjml/components/head.mjml index 929057fb70..4cb27889eb 100644 --- a/src/Core/MailTemplates/Mjml/components/head.mjml +++ b/src/Core/MailTemplates/Mjml/components/head.mjml @@ -4,13 +4,21 @@ font-size="16px" /> - + - .link { text-decoration: none; color: #175ddc; font-weight: 600 } + .link { + text-decoration: none; + color: #175ddc; + font-weight: 600; + } - .border-fix > table { border-collapse:separate !important; } .border-fix > - table > tbody > tr > td { border-radius: 3px; } + .border-fix > table { + border-collapse: separate !important; + } + .border-fix > table > tbody > tr > td { + border-radius: 3px; + } diff --git a/src/Core/MailTemplates/Mjml/components/hero.js b/src/Core/MailTemplates/Mjml/components/mj-bw-hero.js similarity index 51% rename from src/Core/MailTemplates/Mjml/components/hero.js rename to src/Core/MailTemplates/Mjml/components/mj-bw-hero.js index 6c5bd9bc99..c7a3b7e7ff 100644 --- a/src/Core/MailTemplates/Mjml/components/hero.js +++ b/src/Core/MailTemplates/Mjml/components/mj-bw-hero.js @@ -9,20 +9,50 @@ class MjBwHero extends BodyComponent { }; static allowedAttributes = { - "img-src": "string", - title: "string", - "button-text": "string", - "button-url": "string", + "img-src": "string", // REQUIRED: Source for the image displayed in the right-hand side of the blue header area + title: "string", // REQUIRED: large text stating primary purpose of the email + "button-text": "string", // OPTIONAL: text to display in the button + "button-url": "string", // OPTIONAL: URL to navigate to when the button is clicked + "sub-title": "string", // OPTIONAL: smaller text providing additional context for the title }; static defaultAttributes = {}; + componentHeadStyle = breakpoint => { + return ` + @media only screen and (max-width:${breakpoint}) { + .mj-bw-hero-responsive-img { + display: none !important; + } + } + ` + } + render() { - return this.renderMJML(` + const buttonElement = this.getAttribute("button-text") && this.getAttribute("button-url") ? + ` + ${this.getAttribute("button-text")} + ` : ""; + const subTitleElement = this.getAttribute("sub-title") ? + ` +

    + ${this.getAttribute("sub-title")} +

    +
    ` : ""; + + return this.renderMJML( + ` ${this.getAttribute("title")} -
    - - ${this.getAttribute("button-text")} - + ` + + subTitleElement + + ` +
    ` + + buttonElement + + ` + width="155px" + padding="0px" + css-class="mj-bw-hero-responsive-img" + />
    - `); + `, + ); } } diff --git a/src/Core/MailTemplates/Mjml/components/mj-bw-icon-row.js b/src/Core/MailTemplates/Mjml/components/mj-bw-icon-row.js new file mode 100644 index 0000000000..d0ccde5513 --- /dev/null +++ b/src/Core/MailTemplates/Mjml/components/mj-bw-icon-row.js @@ -0,0 +1,105 @@ +const { BodyComponent } = require("mjml-core"); + +const BODY_TEXT_STYLES = ` + font-family="Roboto, 'Helvetica Neue', Helvetica, Arial, sans-serif" + font-size="16px" + font-weight="400" + line-height="24px" +`; + +class MjBwIconRow extends BodyComponent { + static dependencies = { + "mj-column": ["mj-bw-icon-row"], + "mj-wrapper": ["mj-bw-icon-row"], + "mj-bw-icon-row": [], + }; + + static allowedAttributes = { + "icon-src": "string", + "icon-alt": "string", + "head-url-text": "string", + "head-url": "string", + text: "string", + "foot-url-text": "string", + "foot-url": "string", + }; + + static defaultAttributes = {}; + + headStyle = (breakpoint) => { + return ` + @media only screen and (max-width:${breakpoint}) { + .mj-bw-icon-row-text { + padding-left: 5px !important; + line-height: 20px; + } + .mj-bw-icon-row { + padding: 10px 15px; + width: fit-content !important; + } + } + `; + }; + + render() { + const headAnchorElement = + this.getAttribute("head-url-text") && this.getAttribute("head-url") + ? ` + + + ${this.getAttribute("head-url-text")} + + External Link Icon + + + ` + : ""; + + const footAnchorElement = + this.getAttribute("foot-url-text") && this.getAttribute("foot-url") + ? ` + + ${this.getAttribute("foot-url-text")} + + External Link Icon + + + ` + : ""; + + return this.renderMJML( + ` + + + + + + + ${headAnchorElement} + + ${this.getAttribute("text")} + + ${footAnchorElement} + + + + `, + ); + } +} + +module.exports = MjBwIconRow; diff --git a/src/Core/MailTemplates/Mjml/components/mj-bw-learn-more-footer.js b/src/Core/MailTemplates/Mjml/components/mj-bw-learn-more-footer.js new file mode 100644 index 0000000000..fb8b5b69dd --- /dev/null +++ b/src/Core/MailTemplates/Mjml/components/mj-bw-learn-more-footer.js @@ -0,0 +1,51 @@ +const { BodyComponent } = require("mjml-core"); +class MjBwLearnMoreFooter extends BodyComponent { + static dependencies = { + // Tell the validator which tags are allowed as our component's parent + "mj-column": ["mj-bw-learn-more-footer"], + "mj-wrapper": ["mj-bw-learn-more-footer"], + // Tell the validator which tags are allowed as our component's children + "mj-bw-learn-more-footer": [], + }; + + static allowedAttributes = {}; + + static defaultAttributes = {}; + + componentHeadStyle = (breakpoint) => { + return ` + @media only screen and (max-width:${breakpoint}) { + .mj-bw-learn-more-footer-responsive-img { + display: none !important; + } + } + `; + }; + + render() { + return this.renderMJML( + ` + + + +

    + Learn more about Bitwarden +

    + Find user guides, product documentation, and videos on the + Bitwarden Help Center. +
    +
    + + + +
    + `, + ); + } +} + +module.exports = MjBwLearnMoreFooter; diff --git a/src/Core/MailTemplates/Mjml/components/mj-bw-simple-hero.js b/src/Core/MailTemplates/Mjml/components/mj-bw-simple-hero.js new file mode 100644 index 0000000000..e7364e34b0 --- /dev/null +++ b/src/Core/MailTemplates/Mjml/components/mj-bw-simple-hero.js @@ -0,0 +1,40 @@ +const { BodyComponent } = require("mjml-core"); + +class MjBwSimpleHero extends BodyComponent { + static dependencies = { + // Tell the validator which tags are allowed as our component's parent + "mj-column": ["mj-bw-simple-hero"], + "mj-wrapper": ["mj-bw-simple-hero"], + // Tell the validator which tags are allowed as our component's children + "mj-bw-simple-hero": [], + }; + + static allowedAttributes = {}; + + static defaultAttributes = {}; + + render() { + return this.renderMJML( + ` + + + + + + `, + ); + } +} + +module.exports = MjBwSimpleHero; diff --git a/src/Core/MailTemplates/Mjml/emails/AdminConsole/OrganizationConfirmation/organization-confirmation-enterprise-teams.mjml b/src/Core/MailTemplates/Mjml/emails/AdminConsole/OrganizationConfirmation/organization-confirmation-enterprise-teams.mjml new file mode 100644 index 0000000000..24f85af31c --- /dev/null +++ b/src/Core/MailTemplates/Mjml/emails/AdminConsole/OrganizationConfirmation/organization-confirmation-enterprise-teams.mjml @@ -0,0 +1,50 @@ + + + + + + + + + + + + + + + + + As a member of {{OrganizationName}}: + + + + + + + + + + + + + + + + + + diff --git a/src/Core/MailTemplates/Mjml/emails/AdminConsole/OrganizationConfirmation/organization-confirmation-family-free.mjml b/src/Core/MailTemplates/Mjml/emails/AdminConsole/OrganizationConfirmation/organization-confirmation-family-free.mjml new file mode 100644 index 0000000000..2e48e82f84 --- /dev/null +++ b/src/Core/MailTemplates/Mjml/emails/AdminConsole/OrganizationConfirmation/organization-confirmation-family-free.mjml @@ -0,0 +1,55 @@ + + + + + + + + + + + + + + + + + As a member of {{OrganizationName}}: + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/Core/MailTemplates/Mjml/emails/AdminConsole/components/mj-bw-inviter-info.js b/src/Core/MailTemplates/Mjml/emails/AdminConsole/components/mj-bw-inviter-info.js new file mode 100644 index 0000000000..e9d392f570 --- /dev/null +++ b/src/Core/MailTemplates/Mjml/emails/AdminConsole/components/mj-bw-inviter-info.js @@ -0,0 +1,35 @@ +const { BodyComponent } = require("mjml-core"); + +class MjBwInviterInfo extends BodyComponent { + + static dependencies = { + "mj-column": ["mj-bw-inviter-info"], + "mj-wrapper": ["mj-bw-inviter-info"], + "mj-bw-inviter-info": [], + }; + + static allowedAttributes = { + "expiration-date": "string", // REQUIRED: Date to display + "email-address": "string", // Optional: Email address to display + }; + + render() { + const emailAddressText = this.getAttribute("email-address") + ? `This invitation was sent by ${this.getAttribute("email-address")} and expires ` + : "This invitation expires "; + + return this.renderMJML( + ` + + + + ${emailAddressText + this.getAttribute("expiration-date")} + + + + ` + ); + } +} + +module.exports = MjBwInviterInfo; diff --git a/src/Core/MailTemplates/Mjml/emails/AdminConsole/components/mobile-app-download.mjml b/src/Core/MailTemplates/Mjml/emails/AdminConsole/components/mobile-app-download.mjml new file mode 100644 index 0000000000..8e990dc924 --- /dev/null +++ b/src/Core/MailTemplates/Mjml/emails/AdminConsole/components/mobile-app-download.mjml @@ -0,0 +1,38 @@ + + + + + Download Bitwarden on all devices + + + Already using the browser extension? + Download the Bitwarden mobile app from the + App Store + or Google Play + to quickly save logins and autofill forms on the go. + + + + + + + + + + + + + + diff --git a/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-family-user.mjml b/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-family-user.mjml new file mode 100644 index 0000000000..7c81a700f2 --- /dev/null +++ b/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-family-user.mjml @@ -0,0 +1,60 @@ + + + + + + + + + + + + + + + + + An administrator from {{OrganizationName}} will approve you + before you can share passwords. While you wait for approval, get + started with Bitwarden Password Manager: + + + + + + + + + + + + + + + + + + + diff --git a/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-individual-user.mjml b/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-individual-user.mjml new file mode 100644 index 0000000000..4fc9bc466a --- /dev/null +++ b/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-individual-user.mjml @@ -0,0 +1,59 @@ + + + + + + + + + + + + + + + + + Follow these simple steps to get up and running with Bitwarden + Password Manager: + + + + + + + + + + + + + + + + + + + diff --git a/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-org-user.mjml b/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-org-user.mjml new file mode 100644 index 0000000000..7b8a03dc7e --- /dev/null +++ b/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-org-user.mjml @@ -0,0 +1,60 @@ + + + + + + + + + + + + + + + + + An administrator from {{OrganizationName}} will need to confirm + you before you can share passwords. Get started with Bitwarden + Password Manager: + + + + + + + + + + + + + + + + + + + diff --git a/src/Core/MailTemplates/Mjml/emails/Auth/send-email-otp.mjml b/src/Core/MailTemplates/Mjml/emails/Auth/send-email-otp.mjml new file mode 100644 index 0000000000..660bbf0b45 --- /dev/null +++ b/src/Core/MailTemplates/Mjml/emails/Auth/send-email-otp.mjml @@ -0,0 +1,69 @@ + + + + + .send-bubble { + padding-left: 20px; + padding-right: 20px; + width: 90% !important; + } + + + + + + + + + + + + + + Your verification code is: + + {{Token}} + + + + This code expires in {{Expiry}} minutes. After that, you'll need + to verify your email again. + + + + + + +

    + Bitwarden Send transmits sensitive, temporary information to + others easily and securely. Learn more about + Bitwarden Send + or + sign up + to try it today. +

    +
    +
    +
    +
    + + + + + + + + +
    +
    diff --git a/src/Core/MailTemplates/Mjml/emails/two-factor.mjml b/src/Core/MailTemplates/Mjml/emails/Auth/two-factor.mjml similarity index 61% rename from src/Core/MailTemplates/Mjml/emails/two-factor.mjml rename to src/Core/MailTemplates/Mjml/emails/Auth/two-factor.mjml index b959ec1c8a..73d205ba57 100644 --- a/src/Core/MailTemplates/Mjml/emails/two-factor.mjml +++ b/src/Core/MailTemplates/Mjml/emails/Auth/two-factor.mjml @@ -1,10 +1,10 @@ - + - + -

    Your two-step verification code is: {{Token}}

    +

    + Your two-step verification code is: {{ Token }} +

    Use this code to complete logging in with Bitwarden.

    - + + + diff --git a/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2019-renewal.mjml b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2019-renewal.mjml new file mode 100644 index 0000000000..092ae303de --- /dev/null +++ b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2019-renewal.mjml @@ -0,0 +1,42 @@ + + + + + + + + + + + + + + + + + Your Bitwarden Families subscription renews in 15 days. The price is updating to {{BaseMonthlyRenewalPrice}}/month, billed annually + at {{BaseAnnualRenewalPrice}} + tax. + + + As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. + This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax. + + + Questions? Contact + support@bitwarden.com + + + + + + + + + + + + + + + + diff --git a/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2020-renewal.mjml b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2020-renewal.mjml new file mode 100644 index 0000000000..dcf193875a --- /dev/null +++ b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2020-renewal.mjml @@ -0,0 +1,36 @@ + + + + + + + + + + + + + + + + + Your Bitwarden Families subscription renews in 15 days. The price is updating to {{MonthlyRenewalPrice}}/month, billed annually. + + + Questions? Contact support@bitwarden.com + + + + + + + + + + + + + + + + diff --git a/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/premium-renewal.mjml b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/premium-renewal.mjml new file mode 100644 index 0000000000..a460442a7c --- /dev/null +++ b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/premium-renewal.mjml @@ -0,0 +1,41 @@ + + + + + + + + + + + + + + + + + Your Bitwarden Premium subscription renews in 15 days. The price is updating to {{BaseMonthlyRenewalPrice}}/month, billed annually. + + + As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. + This renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually. + + + Questions? Contact + support@bitwarden.com + + + + + + + + + + + + + + + + diff --git a/src/Core/MailTemplates/Mjml/emails/invite.mjml b/src/Core/MailTemplates/Mjml/emails/invite.mjml index 4eae12d0dc..8e08a6753a 100644 --- a/src/Core/MailTemplates/Mjml/emails/invite.mjml +++ b/src/Core/MailTemplates/Mjml/emails/invite.mjml @@ -22,26 +22,7 @@ - - - -

    - We’re here for you! -

    - If you have any questions, search the Bitwarden - Help - site or - contact us. -
    -
    - - - -
    + diff --git a/src/Core/MailTemplates/Mjml/package.json b/src/Core/MailTemplates/Mjml/package.json index 8a8f81e845..f74279da7b 100644 --- a/src/Core/MailTemplates/Mjml/package.json +++ b/src/Core/MailTemplates/Mjml/package.json @@ -15,8 +15,10 @@ }, "homepage": "https://bitwarden.com", "scripts": { - "build": "./build.sh", - "watch": "nodemon --exec ./build.sh --watch ./components --watch ./emails --ext js,mjml", + "build": "node ./build.js", + "build:hbs": "node ./build.js --hbs", + "build:minify": "node ./build.js --hbs --minify", + "build:watch": "nodemon ./build.js --watch emails --watch components --ext mjml,js", "prettier": "prettier --cache --write ." }, "dependencies": { diff --git a/src/Core/MailTemplates/README.md b/src/Core/MailTemplates/README.md new file mode 100644 index 0000000000..312821afd3 --- /dev/null +++ b/src/Core/MailTemplates/README.md @@ -0,0 +1,88 @@ +Email templating +================ + +We use MJML to generate the HTML that our mail services use to send emails to users. To accomplish this, we use different file types depending on which part of the email generation process we're working with. + +# File Types + +## `*.html.hbs` +These are the compiled HTML email templates that serve as the foundation for all HTML emails sent by the Bitwarden platform. They are generated from MJML source files and enhanced with Handlebars templating capabilities. + +### Generation Process +- **Source**: Built from `*.mjml` files in the `./mjml` directory. + - The MJML source acts as a toolkit for developers to generate HTML. It is the developers responsibility to generate the HTML and then ensure it is accessible to `IMailService` implementations. +- **Build Tool**: Generated via node build scripts: `npm run build`. + - The build script definitions can be viewed in the `Mjml/package.json` as well as in `Mjml/build.js`. +- **Output**: Cross-client compatible HTML with embedded CSS for maximum email client support +- **Template Engine**: Enhanced with Handlebars syntax for dynamic content injection + +### Handlebars Integration +The templates use Handlebars templating syntax for dynamic content replacement: + +```html + +

    Welcome {{userName}}!

    +

    Your organization {{organizationName}} has invited you to join.

    +Accept Invitation +``` + +**Variable Types:** +- **Simple Variables**: `{{userName}}`, `{{email}}`, `{{organizationName}}` + +### Email Service Integration +The `IMailService` consumes these templates through the following process: + +1. **Template Selection**: Service selects appropriate `.html.hbs` template based on email type +2. **Model Binding**: View model properties are mapped to Handlebars variables +3. **Compilation**: Handlebars engine processes variables and generates final HTML + +### Development Guidelines + +**Variable Naming:** +- Use camelCase for consistency: `{{userName}}`, `{{organizationName}}` +- Prefix URLs with descriptive names: `{{actionUrl}}`, `{{logoUrl}}` + +**Testing Considerations:** +- Verify Handlebars variable replacement with actual view model data +- Ensure graceful degradation when variables are missing or null, if necessary +- Validate HTML structure and accessibility compliance + +## `*.txt.hbs` +These files provide plain text versions of emails and are essential for email accessibility and deliverability. They serve several important purposes: + +### Purpose and Usage +- **Accessibility**: Screen readers and assistive technologies often work better with plain text versions +- **Email Client Compatibility**: Some email clients prefer or only display plain text versions +- **Fallback Content**: When HTML rendering fails, the plain text version ensures the message is still readable + +### Structure +Plain text email templates use the same Handlebars syntax (`{{variable}}`) as HTML templates for dynamic content replacement. They should: + +- Contain the core message content without HTML formatting +- Use line breaks and spacing for readability +- Include all important links as full URLs +- Maintain logical content hierarchy using spacing and simple text formatting + +### Email Service Integration +The `IMailService` automatically uses both versions when sending emails: +- The HTML version (from `*.html.hbs`) provides rich formatting and styling +- The plain text version (from `*.txt.hbs`) serves as the text alternative +- Email clients can choose which version to display based on user preferences and capabilities + +### Development Guidelines +- Always create a corresponding `*.txt.hbs` file for each `*.html.hbs` template +- Keep the content concise but complete - include all essential information from the HTML version +- Test plain text templates to ensure they're readable and convey the same message + +## `*.mjml` +This is a templating language we use to increase efficiency when creating email content. See the `MJML` [documentation](./Mjml/README.md) for more details. + +# Managing email assets + +We host assets that are included in emails at `assets.bitwarden.com`, at the `/email/v1` path. This corresponds to a static file storage container that is managed by our SRE team. For example: https://assets.bitwarden.com/email/v1/mail-github.png. This is the URL for all assets for emails sent from any environment. + +## Adding an asset + +When you are creating an email that needs a new asset, you should first check to see if that asset already exists. The easiest way to do this is check at the corresponding `https://assets.bitwarden.com/email/v1/` URL (e.g. https://assets.bitwarden.com/email/v1/my_new_image.png). + +If the asset you are adding is not there, enter a ticket for the SRE team to add the asset to the email asset container. The preferred format for assets is a `.png` file, and the file(s) should be attached to the ticket. \ No newline at end of file diff --git a/src/Core/Models/Business/CompleteSubscriptionUpdate.cs b/src/Core/Models/Business/CompleteSubscriptionUpdate.cs index 7473738ffc..aa49c25d36 100644 --- a/src/Core/Models/Business/CompleteSubscriptionUpdate.cs +++ b/src/Core/Models/Business/CompleteSubscriptionUpdate.cs @@ -299,7 +299,7 @@ public class CompleteSubscriptionUpdate : SubscriptionUpdate ? organization.SmServiceAccounts - plan.SecretsManager.BaseServiceAccount : 0, PurchasedAdditionalStorage = organization.MaxStorageGb.HasValue - ? organization.MaxStorageGb.Value - (plan.PasswordManager.BaseStorageGb ?? 0) : + ? organization.MaxStorageGb.Value - plan.PasswordManager.BaseStorageGb : 0 }; } diff --git a/src/Core/Models/Business/OrganizationUpgrade.cs b/src/Core/Models/Business/OrganizationUpgrade.cs index 89b9a5e6f2..d165a96d0a 100644 --- a/src/Core/Models/Business/OrganizationUpgrade.cs +++ b/src/Core/Models/Business/OrganizationUpgrade.cs @@ -2,6 +2,7 @@ #nullable disable using Bit.Core.Billing.Enums; +using Bit.Core.KeyManagement.Models.Data; namespace Bit.Core.Models.Business; @@ -13,8 +14,7 @@ public class OrganizationUpgrade public short AdditionalStorageGb { get; set; } public bool PremiumAccessAddon { get; set; } public TaxInfo TaxInfo { get; set; } - public string PublicKey { get; set; } - public string PrivateKey { get; set; } + public PublicKeyEncryptionKeyPairData Keys { get; set; } public int? AdditionalSmSeats { get; set; } public int? AdditionalServiceAccounts { get; set; } public bool UseSecretsManager { get; set; } diff --git a/src/Core/Models/Business/SubscriptionInfo.cs b/src/Core/Models/Business/SubscriptionInfo.cs index a016ac54f3..68a060b4a8 100644 --- a/src/Core/Models/Business/SubscriptionInfo.cs +++ b/src/Core/Models/Business/SubscriptionInfo.cs @@ -1,52 +1,119 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - +using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Models; using Stripe; +#nullable enable + namespace Bit.Core.Models.Business; public class SubscriptionInfo { - public BillingCustomerDiscount CustomerDiscount { get; set; } - public BillingSubscription Subscription { get; set; } - public BillingUpcomingInvoice UpcomingInvoice { get; set; } + /// + /// Converts Stripe's minor currency units (cents) to major currency units (dollars). + /// IMPORTANT: Only supports USD. All Bitwarden subscriptions are USD-only. + /// + private const decimal StripeMinorUnitDivisor = 100M; + /// + /// Converts Stripe's minor currency units (cents) to major currency units (dollars). + /// Preserves null semantics to distinguish between "no amount" (null) and "zero amount" (0.00m). + /// + /// The amount in Stripe's minor currency units (e.g., cents for USD). + /// The amount in major currency units (e.g., dollars for USD), or null if the input is null. + private static decimal? ConvertFromStripeMinorUnits(long? amountInCents) + { + return amountInCents.HasValue ? amountInCents.Value / StripeMinorUnitDivisor : null; + } + + public BillingCustomerDiscount? CustomerDiscount { get; set; } + public BillingSubscription? Subscription { get; set; } + public BillingUpcomingInvoice? UpcomingInvoice { get; set; } + + /// + /// Represents customer discount information from Stripe billing. + /// public class BillingCustomerDiscount { public BillingCustomerDiscount() { } + /// + /// Creates a BillingCustomerDiscount from a Stripe Discount object. + /// + /// The Stripe discount containing coupon and expiration information. public BillingCustomerDiscount(Discount discount) { Id = discount.Coupon?.Id; + // Active = true only for perpetual/recurring discounts (no end date) + // This is intentional for Milestone 2 - only perpetual discounts are shown in UI Active = discount.End == null; PercentOff = discount.Coupon?.PercentOff; - AppliesTo = discount.Coupon?.AppliesTo?.Products ?? []; + AmountOff = ConvertFromStripeMinorUnits(discount.Coupon?.AmountOff); + // Stripe's CouponAppliesTo.Products is already IReadOnlyList, so no conversion needed + AppliesTo = discount.Coupon?.AppliesTo?.Products; } - public string Id { get; set; } + /// + /// The Stripe coupon ID (e.g., "cm3nHfO1"). + /// Note: Only specific coupon IDs are displayed in the UI based on feature flag configuration, + /// though Stripe may apply additional discounts that are not shown. + /// + public string? Id { get; set; } + + /// + /// True only for perpetual/recurring discounts (End == null). + /// False for any discount with an expiration date, even if not yet expired. + /// Product decision for Milestone 2: only show perpetual discounts in UI. + /// public bool Active { get; set; } + + /// + /// Percentage discount applied to the subscription (e.g., 20.0 for 20% off). + /// Null if this is an amount-based discount. + /// public decimal? PercentOff { get; set; } - public List AppliesTo { get; set; } + + /// + /// Fixed amount discount in USD (e.g., 14.00 for $14 off). + /// Converted from Stripe's cent-based values (1400 cents → $14.00). + /// Null if this is a percentage-based discount. + /// + public decimal? AmountOff { get; set; } + + /// + /// List of Stripe product IDs that this discount applies to (e.g., ["prod_premium", "prod_families"]). + /// + /// Null: discount applies to all products with no restrictions (AppliesTo not specified in Stripe). + /// Empty list: discount restricted to zero products (edge case - AppliesTo.Products = [] in Stripe). + /// Non-empty list: discount applies only to the specified product IDs. + /// + /// + public IReadOnlyList? AppliesTo { get; set; } } public class BillingSubscription { public BillingSubscription(Subscription sub) { - Status = sub.Status; - TrialStartDate = sub.TrialStart; - TrialEndDate = sub.TrialEnd; - PeriodStartDate = sub.CurrentPeriodStart; - PeriodEndDate = sub.CurrentPeriodEnd; - CancelledDate = sub.CanceledAt; - CancelAtEndDate = sub.CancelAtPeriodEnd; - Cancelled = sub.Status == "canceled" || sub.Status == "unpaid" || sub.Status == "incomplete_expired"; - if (sub.Items?.Data != null) + Status = sub?.Status; + TrialStartDate = sub?.TrialStart; + TrialEndDate = sub?.TrialEnd; + var currentPeriod = sub?.GetCurrentPeriod(); + if (currentPeriod != null) + { + var (start, end) = currentPeriod.Value; + PeriodStartDate = start; + PeriodEndDate = end; + } + CancelledDate = sub?.CanceledAt; + CancelAtEndDate = sub?.CancelAtPeriodEnd ?? false; + var status = sub?.Status; + Cancelled = status == "canceled" || status == "unpaid" || status == "incomplete_expired"; + if (sub?.Items?.Data != null) { Items = sub.Items.Data.Select(i => new BillingSubscriptionItem(i)); } - CollectionMethod = sub.CollectionMethod; - GracePeriod = sub.CollectionMethod == "charge_automatically" + CollectionMethod = sub?.CollectionMethod; + GracePeriod = sub?.CollectionMethod == "charge_automatically" ? 14 : 30; } @@ -58,10 +125,10 @@ public class SubscriptionInfo public TimeSpan? PeriodDuration => PeriodEndDate - PeriodStartDate; public DateTime? CancelledDate { get; set; } public bool CancelAtEndDate { get; set; } - public string Status { get; set; } + public string? Status { get; set; } public bool Cancelled { get; set; } public IEnumerable Items { get; set; } = new List(); - public string CollectionMethod { get; set; } + public string? CollectionMethod { get; set; } public DateTime? SuspensionDate { get; set; } public DateTime? UnpaidPeriodEndDate { get; set; } public int GracePeriod { get; set; } @@ -74,7 +141,7 @@ public class SubscriptionInfo { ProductId = item.Plan.ProductId; Name = item.Plan.Nickname; - Amount = item.Plan.Amount.GetValueOrDefault() / 100M; + Amount = ConvertFromStripeMinorUnits(item.Plan.Amount) ?? 0; Interval = item.Plan.Interval; if (item.Metadata != null) @@ -84,15 +151,15 @@ public class SubscriptionInfo } Quantity = (int)item.Quantity; - SponsoredSubscriptionItem = Utilities.StaticStore.SponsoredPlans.Any(p => p.StripePlanId == item.Plan.Id); + SponsoredSubscriptionItem = item.Plan != null && SponsoredPlans.All.Any(p => p.StripePlanId == item.Plan.Id); } public bool AddonSubscriptionItem { get; set; } - public string ProductId { get; set; } - public string Name { get; set; } + public string? ProductId { get; set; } + public string? Name { get; set; } public decimal Amount { get; set; } public int Quantity { get; set; } - public string Interval { get; set; } + public string? Interval { get; set; } public bool SponsoredSubscriptionItem { get; set; } } } @@ -103,7 +170,7 @@ public class SubscriptionInfo public BillingUpcomingInvoice(Invoice inv) { - Amount = inv.AmountDue / 100M; + Amount = ConvertFromStripeMinorUnits(inv.AmountDue) ?? 0; Date = inv.Created; } diff --git a/src/Core/Models/Business/SubscriptionUpdate.cs b/src/Core/Models/Business/SubscriptionUpdate.cs index 028fcad80b..7c23c9b73c 100644 --- a/src/Core/Models/Business/SubscriptionUpdate.cs +++ b/src/Core/Models/Business/SubscriptionUpdate.cs @@ -50,6 +50,7 @@ public abstract class SubscriptionUpdate protected static bool IsNonSeatBasedPlan(StaticStore.Plan plan) => plan.Type is >= PlanType.FamiliesAnnually2019 and <= PlanType.EnterpriseAnnually2019 + or PlanType.FamiliesAnnually2025 or PlanType.FamiliesAnnually or PlanType.TeamsStarter2023 or PlanType.TeamsStarter; diff --git a/src/Core/Models/Data/UserKdfInformation.cs b/src/Core/Models/Data/UserKdfInformation.cs index 14f525bb82..0e5696e581 100644 --- a/src/Core/Models/Data/UserKdfInformation.cs +++ b/src/Core/Models/Data/UserKdfInformation.cs @@ -4,8 +4,8 @@ namespace Bit.Core.Models.Data; public class UserKdfInformation { - public KdfType Kdf { get; set; } - public int KdfIterations { get; set; } + public required KdfType Kdf { get; set; } + public required int KdfIterations { get; set; } public int? KdfMemory { get; set; } public int? KdfParallelism { get; set; } } diff --git a/src/Core/Models/Mail/Auth/DefaultEmailOtpViewModel.cs b/src/Core/Models/Mail/Auth/DefaultEmailOtpViewModel.cs index 5faf550e60..5eabd5ba2c 100644 --- a/src/Core/Models/Mail/Auth/DefaultEmailOtpViewModel.cs +++ b/src/Core/Models/Mail/Auth/DefaultEmailOtpViewModel.cs @@ -9,4 +9,5 @@ public class DefaultEmailOtpViewModel : BaseMailModel public string? TheDate { get; set; } public string? TheTime { get; set; } public string? TimeZone { get; set; } + public string? Expiry { get; set; } } diff --git a/src/Core/Models/Mail/Auth/OrganizationWelcomeEmailViewModel.cs b/src/Core/Models/Mail/Auth/OrganizationWelcomeEmailViewModel.cs new file mode 100644 index 0000000000..b852d24ec9 --- /dev/null +++ b/src/Core/Models/Mail/Auth/OrganizationWelcomeEmailViewModel.cs @@ -0,0 +1,6 @@ +namespace Bit.Core.Models.Mail.Auth; + +public class OrganizationWelcomeEmailViewModel : BaseMailModel +{ + public required string OrganizationName { get; set; } +} diff --git a/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.cs b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.cs new file mode 100644 index 0000000000..e3aff02f5d --- /dev/null +++ b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.cs @@ -0,0 +1,16 @@ +using Bit.Core.Platform.Mail.Mailer; + +namespace Bit.Core.Models.Mail.Billing.Renewal.Families2019Renewal; + +public class Families2019RenewalMailView : BaseMailView +{ + public required string BaseMonthlyRenewalPrice { get; set; } + public required string BaseAnnualRenewalPrice { get; set; } + public required string DiscountedAnnualRenewalPrice { get; set; } + public required string DiscountAmount { get; set; } +} + +public class Families2019RenewalMail : BaseMail +{ + public override string Subject { get => "Your Bitwarden Families renewal is updating"; } +} diff --git a/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.html.hbs b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.html.hbs new file mode 100644 index 0000000000..227613999b --- /dev/null +++ b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.html.hbs @@ -0,0 +1,584 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +
    + + +
    + +
    + + + + + +
    + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + + + + + +
    + +
    Your Bitwarden Families subscription renews in 15 days. The price is updating to {{BaseMonthlyRenewalPrice}}/month, billed annually + at {{BaseAnnualRenewalPrice}} + tax.
    + +
    + +
    As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. + This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax.
    + +
    + +
    Questions? Contact + support@bitwarden.com
    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +

    + Learn more about Bitwarden +

    + Find user guides, product documentation, and videos on the + Bitwarden Help Center.
    + +
    + +
    + + + +
    + + + + + + + + + +
    + +
    + + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + +
    + +

    + © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + Barbara, CA, USA +

    +

    + Always confirm you are on a trusted Bitwarden domain before logging + in:
    + bitwarden.com | + Learn why we include this +

    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + \ No newline at end of file diff --git a/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.text.hbs b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.text.hbs new file mode 100644 index 0000000000..88d64f9acf --- /dev/null +++ b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.text.hbs @@ -0,0 +1,7 @@ +Your Bitwarden Families subscription renews in 15 days. The price is updating to {{BaseMonthlyRenewalPrice}}/month, billed annually +at {{BaseAnnualRenewalPrice}} + tax. + +As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. +This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax. + +Questions? Contact support@bitwarden.com diff --git a/src/Core/Models/Mail/Billing/Renewal/Families2020Renewal/Families2020RenewalMailView.cs b/src/Core/Models/Mail/Billing/Renewal/Families2020Renewal/Families2020RenewalMailView.cs new file mode 100644 index 0000000000..eb7bef4322 --- /dev/null +++ b/src/Core/Models/Mail/Billing/Renewal/Families2020Renewal/Families2020RenewalMailView.cs @@ -0,0 +1,13 @@ +using Bit.Core.Platform.Mail.Mailer; + +namespace Bit.Core.Models.Mail.Billing.Renewal.Families2020Renewal; + +public class Families2020RenewalMailView : BaseMailView +{ + public required string MonthlyRenewalPrice { get; set; } +} + +public class Families2020RenewalMail : BaseMail +{ + public override string Subject { get => "Your Bitwarden Families renewal is updating"; } +} diff --git a/src/Core/Models/Mail/Billing/Renewal/Families2020Renewal/Families2020RenewalMailView.html.hbs b/src/Core/Models/Mail/Billing/Renewal/Families2020Renewal/Families2020RenewalMailView.html.hbs new file mode 100644 index 0000000000..ac6b80993c --- /dev/null +++ b/src/Core/Models/Mail/Billing/Renewal/Families2020Renewal/Families2020RenewalMailView.html.hbs @@ -0,0 +1,619 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +

    + Your Bitwarden Families renewal is updating +

    + +
    + +
    + + + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +
    + + +
    + +
    + + + + + +
    + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + +
    Your Bitwarden Families subscription renews in 15 days. The price is updating to {{MonthlyRenewalPrice}}/month, billed annually.
    + +
    + +
    Questions? Contact support@bitwarden.com
    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +

    + Learn more about Bitwarden +

    + Find user guides, product documentation, and videos on the + Bitwarden Help Center.
    + +
    + +
    + + + +
    + + + + + + + + + +
    + +
    + + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + +
    + +

    + © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + Barbara, CA, USA +

    +

    + Always confirm you are on a trusted Bitwarden domain before logging + in:
    + bitwarden.com | + Learn why we include this +

    + +
    + +
    + + +
    + +
    + + + + + +
    + + + diff --git a/src/Core/Models/Mail/Billing/Renewal/Families2020Renewal/Families2020RenewalMailView.text.hbs b/src/Core/Models/Mail/Billing/Renewal/Families2020Renewal/Families2020RenewalMailView.text.hbs new file mode 100644 index 0000000000..002a48cf10 --- /dev/null +++ b/src/Core/Models/Mail/Billing/Renewal/Families2020Renewal/Families2020RenewalMailView.text.hbs @@ -0,0 +1,3 @@ +Your Bitwarden Families subscription renews in 15 days. The price is updating to {{MonthlyRenewalPrice}}/month, billed annually. + +Questions? Contact support@bitwarden.com diff --git a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.cs b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.cs new file mode 100644 index 0000000000..e231a44467 --- /dev/null +++ b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.cs @@ -0,0 +1,15 @@ +using Bit.Core.Platform.Mail.Mailer; + +namespace Bit.Core.Models.Mail.Billing.Renewal.Premium; + +public class PremiumRenewalMailView : BaseMailView +{ + public required string BaseMonthlyRenewalPrice { get; set; } + public required string DiscountedMonthlyRenewalPrice { get; set; } + public required string DiscountAmount { get; set; } +} + +public class PremiumRenewalMail : BaseMail +{ + public override string Subject { get => "Your Bitwarden Premium renewal is updating"; } +} diff --git a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.html.hbs b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.html.hbs new file mode 100644 index 0000000000..a6b2fda0f7 --- /dev/null +++ b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.html.hbs @@ -0,0 +1,583 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + + + + + + + +
    + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + +
    + +
    + + +
    + +
    + + + + + +
    + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + + + + + +
    + +
    Your Bitwarden Premium subscription renews in 15 days. The price is updating to {{BaseMonthlyRenewalPrice}}/month, billed annually.
    + +
    + +
    As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. + This renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually.
    + +
    + +
    Questions? Contact + support@bitwarden.com
    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + + + + +
    + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + +
    + +

    + Learn more about Bitwarden +

    + Find user guides, product documentation, and videos on the + Bitwarden Help Center.
    + +
    + +
    + + + +
    + + + + + + + + + +
    + +
    + + +
    + +
    + + + +
    + +
    + + + + + + + + + +
    + + + + + + + +
    + + +
    + + + + + + + + + + + + + +
    + + + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + + + + + + + + +
    + + + + + + +
    + + + +
    +
    + + + +
    + +

    + © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + Barbara, CA, USA +

    +

    + Always confirm you are on a trusted Bitwarden domain before logging + in:
    + bitwarden.com | + Learn why we include this +

    + +
    + +
    + + +
    + +
    + + + + + +
    + + + + \ No newline at end of file diff --git a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.text.hbs b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.text.hbs new file mode 100644 index 0000000000..41300d0f96 --- /dev/null +++ b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.text.hbs @@ -0,0 +1,6 @@ +Your Bitwarden Premium subscription renews in 15 days. The price is updating to {{BaseMonthlyRenewalPrice}}/month, billed annually. + +As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. +This renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually. + +Questions? Contact support@bitwarden.com diff --git a/src/Core/Models/PushNotification.cs b/src/Core/Models/PushNotification.cs index c4ae1e2858..ec39c495aa 100644 --- a/src/Core/Models/PushNotification.cs +++ b/src/Core/Models/PushNotification.cs @@ -1,4 +1,5 @@ -using Bit.Core.Enums; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Enums; using Bit.Core.NotificationCenter.Enums; namespace Bit.Core.Models; @@ -97,3 +98,15 @@ public class ProviderBankAccountVerifiedPushNotification public Guid ProviderId { get; set; } public Guid AdminId { get; set; } } + +public class LogOutPushNotification +{ + public Guid UserId { get; set; } + public PushNotificationLogOutReason? Reason { get; set; } +} + +public class SyncPolicyPushNotification +{ + public Guid OrganizationId { get; set; } + public required Policy Policy { get; set; } +} diff --git a/src/Core/Models/Stripe/StripeSubscriptionListOptions.cs b/src/Core/Models/Stripe/StripeSubscriptionListOptions.cs deleted file mode 100644 index 34662ecdbb..0000000000 --- a/src/Core/Models/Stripe/StripeSubscriptionListOptions.cs +++ /dev/null @@ -1,51 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -namespace Bit.Core.Models.BitStripe; - -// Stripe's SubscriptionListOptions model has a complex input for date filters. -// It expects a dictionary, and has lots of validation rules around what can have a value and what can't. -// To simplify this a bit we are extending Stripe's model and using our own date inputs, and building the dictionary they expect JiT. -// ___ -// Our model also facilitates selecting all elements in a list, which is unsupported by Stripe's model. -public class StripeSubscriptionListOptions : Stripe.SubscriptionListOptions -{ - public DateTime? CurrentPeriodEndDate { get; set; } - public string CurrentPeriodEndRange { get; set; } = "lt"; - public bool SelectAll { get; set; } - public new Stripe.DateRangeOptions CurrentPeriodEnd - { - get - { - return CurrentPeriodEndDate.HasValue ? - new Stripe.DateRangeOptions() - { - LessThan = CurrentPeriodEndRange == "lt" ? CurrentPeriodEndDate : null, - GreaterThan = CurrentPeriodEndRange == "gt" ? CurrentPeriodEndDate : null - } : - null; - } - } - - public Stripe.SubscriptionListOptions ToStripeApiOptions() - { - var stripeApiOptions = (Stripe.SubscriptionListOptions)this; - - if (SelectAll) - { - stripeApiOptions.EndingBefore = null; - stripeApiOptions.StartingAfter = null; - } - - if (CurrentPeriodEndDate.HasValue) - { - stripeApiOptions.CurrentPeriodEnd = new Stripe.DateRangeOptions() - { - LessThan = CurrentPeriodEndRange == "lt" ? CurrentPeriodEndDate : null, - GreaterThan = CurrentPeriodEndRange == "gt" ? CurrentPeriodEndDate : null - }; - } - - return stripeApiOptions; - } -} diff --git a/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs b/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs index da05bc929c..b502cc6e4e 100644 --- a/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs +++ b/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs @@ -1,5 +1,6 @@ using Bit.Core.AdminConsole.OrganizationAuth; using Bit.Core.AdminConsole.OrganizationAuth.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery; using Bit.Core.AdminConsole.OrganizationFeatures.Groups; using Bit.Core.AdminConsole.OrganizationFeatures.Groups.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.Import; @@ -11,8 +12,10 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Update; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Authorization; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; @@ -42,6 +45,9 @@ using Microsoft.AspNetCore.DataProtection; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; +using V1_RevokeUsersCommand = Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v1; +using V2_RevokeUsersCommand = Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; + namespace Bit.Core.OrganizationFeatures; public static class OrganizationServiceCollectionExtensions @@ -85,6 +91,7 @@ public static class OrganizationServiceCollectionExtensions private static void AddOrganizationUpdateCommands(this IServiceCollection services) { services.AddScoped(); + services.AddScoped(); } private static void AddOrganizationEnableCommands(this IServiceCollection services) => @@ -129,13 +136,20 @@ public static class OrganizationServiceCollectionExtensions { services.AddScoped(); services.AddScoped(); - services.AddScoped(); services.AddScoped(); services.AddScoped(); services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); services.AddScoped(); services.AddScoped(); + + services.AddScoped(); + + services.AddScoped(); + services.AddScoped(); } private static void AddOrganizationApiKeyCommandsQueries(this IServiceCollection services) @@ -190,6 +204,7 @@ public static class OrganizationServiceCollectionExtensions services.AddScoped(); services.AddScoped(); services.AddScoped(); + services.AddScoped(); services.AddScoped(); services.AddScoped(); diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudSyncSponsorshipsCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudSyncSponsorshipsCommand.cs index 2756f8930b..566c723692 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudSyncSponsorshipsCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/CloudSyncSponsorshipsCommand.cs @@ -1,5 +1,6 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Models; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -7,7 +8,6 @@ using Bit.Core.Models.Data.Organizations.OrganizationSponsorships; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; -using Bit.Core.Utilities; namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; @@ -54,10 +54,9 @@ public class CloudSyncSponsorshipsCommand : ICloudSyncSponsorshipsCommand foreach (var selfHostedSponsorship in sponsorshipsData) { - var requiredSponsoringProductType = StaticStore.GetSponsoredPlan(selfHostedSponsorship.PlanSponsorshipType)?.SponsoringProductTierType; + var requiredSponsoringProductType = SponsoredPlans.Get(selfHostedSponsorship.PlanSponsorshipType).SponsoringProductTierType; var sponsoringOrgProductTier = sponsoringOrg.PlanType.GetProductTier(); - if (requiredSponsoringProductType == null - || sponsoringOrgProductTier != requiredSponsoringProductType.Value) + if (sponsoringOrgProductTier != requiredSponsoringProductType) { continue; // prevent unsupported sponsorships } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommand.cs index a54106481c..6d60f05b2a 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommand.cs @@ -1,11 +1,11 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Models; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Exceptions; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; -using Bit.Core.Services; -using Bit.Core.Utilities; namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; @@ -13,9 +13,9 @@ public class SetUpSponsorshipCommand : ISetUpSponsorshipCommand { private readonly IOrganizationSponsorshipRepository _organizationSponsorshipRepository; private readonly IOrganizationRepository _organizationRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; - public SetUpSponsorshipCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository, IOrganizationRepository organizationRepository, IPaymentService paymentService) + public SetUpSponsorshipCommand(IOrganizationSponsorshipRepository organizationSponsorshipRepository, IOrganizationRepository organizationRepository, IStripePaymentService paymentService) { _organizationSponsorshipRepository = organizationSponsorshipRepository; _organizationRepository = organizationRepository; @@ -50,11 +50,10 @@ public class SetUpSponsorshipCommand : ISetUpSponsorshipCommand } // Check org to sponsor's product type - var requiredSponsoredProductType = StaticStore.GetSponsoredPlan(sponsorship.PlanSponsorshipType.Value)?.SponsoredProductTierType; + var requiredSponsoredProductType = SponsoredPlans.Get(sponsorship.PlanSponsorshipType.Value).SponsoredProductTierType; var sponsoredOrganizationProductTier = sponsoredOrganization.PlanType.GetProductTier(); - if (requiredSponsoredProductType == null || - sponsoredOrganizationProductTier != requiredSponsoredProductType.Value) + if (sponsoredOrganizationProductTier != requiredSponsoredProductType) { throw new BadRequestException("Can only redeem sponsorship offer on families organizations."); } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommand.cs index dcda77acea..4b983317c9 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/ValidateSponsorshipCommand.cs @@ -3,6 +3,8 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Models; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; @@ -13,14 +15,14 @@ namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnte public class ValidateSponsorshipCommand : CancelSponsorshipCommand, IValidateSponsorshipCommand { - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IMailService _mailService; private readonly ILogger _logger; public ValidateSponsorshipCommand( IOrganizationSponsorshipRepository organizationSponsorshipRepository, IOrganizationRepository organizationRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, IMailService mailService, ILogger logger) : base(organizationSponsorshipRepository, organizationRepository) { @@ -95,7 +97,7 @@ public class ValidateSponsorshipCommand : CancelSponsorshipCommand, IValidateSpo return false; } - var sponsoredPlan = Utilities.StaticStore.GetSponsoredPlan(existingSponsorship.PlanSponsorshipType.Value); + var sponsoredPlan = SponsoredPlans.Get(existingSponsorship.PlanSponsorshipType.Value); var sponsoringOrganization = await _organizationRepository .GetByIdAsync(existingSponsorship.SponsoringOrganizationId.Value); diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs index a729937fad..ab4b17d215 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CreateSponsorshipCommand.cs @@ -1,5 +1,6 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Models; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -7,7 +8,6 @@ using Bit.Core.Exceptions; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; -using Bit.Core.Utilities; namespace Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise; @@ -34,11 +34,10 @@ public class CreateSponsorshipCommand( throw new BadRequestException("Cannot offer a Families Organization Sponsorship to yourself. Choose a different email."); } - var requiredSponsoringProductType = StaticStore.GetSponsoredPlan(sponsorshipType)?.SponsoringProductTierType; + var requiredSponsoringProductType = SponsoredPlans.Get(sponsorshipType).SponsoringProductTierType; var sponsoringOrgProductTier = sponsoringOrganization.PlanType.GetProductTier(); - if (requiredSponsoringProductType == null || - sponsoringOrgProductTier != requiredSponsoringProductType.Value) + if (sponsoringOrgProductTier != requiredSponsoringProductType) { throw new BadRequestException("Specified Organization cannot sponsor other organizations."); } diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedSyncSponsorshipsCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedSyncSponsorshipsCommand.cs index 9a995a9cf0..965e0cf2a9 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedSyncSponsorshipsCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedSyncSponsorshipsCommand.cs @@ -62,7 +62,7 @@ public class SelfHostedSyncSponsorshipsCommand : BaseIdentityClientService, ISel .ToDictionary(i => i.SponsoringOrganizationUserId); if (!organizationSponsorshipsDict.Any()) { - _logger.LogInformation($"No existing sponsorships to sync for organization {organizationId}"); + _logger.LogInformation("No existing sponsorships to sync for organization {organizationId}", organizationId); return; } var syncedSponsorships = new List(); diff --git a/src/Core/OrganizationFeatures/OrganizationSubscriptions/AddSecretsManagerSubscriptionCommand.cs b/src/Core/OrganizationFeatures/OrganizationSubscriptions/AddSecretsManagerSubscriptionCommand.cs index a0ce7c03b9..25b84fe989 100644 --- a/src/Core/OrganizationFeatures/OrganizationSubscriptions/AddSecretsManagerSubscriptionCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSubscriptions/AddSecretsManagerSubscriptionCommand.cs @@ -3,6 +3,7 @@ using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Exceptions; using Bit.Core.Models.Business; using Bit.Core.OrganizationFeatures.OrganizationSubscriptions.Interface; @@ -12,13 +13,13 @@ namespace Bit.Core.OrganizationFeatures.OrganizationSubscriptions; public class AddSecretsManagerSubscriptionCommand : IAddSecretsManagerSubscriptionCommand { - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IOrganizationService _organizationService; private readonly IProviderRepository _providerRepository; private readonly IPricingClient _pricingClient; public AddSecretsManagerSubscriptionCommand( - IPaymentService paymentService, + IStripePaymentService paymentService, IOrganizationService organizationService, IProviderRepository providerRepository, IPricingClient pricingClient) diff --git a/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpdateSecretsManagerSubscriptionCommand.cs b/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpdateSecretsManagerSubscriptionCommand.cs index 739dca5228..baf2616a53 100644 --- a/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpdateSecretsManagerSubscriptionCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpdateSecretsManagerSubscriptionCommand.cs @@ -3,6 +3,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Business; @@ -18,7 +19,7 @@ namespace Bit.Core.OrganizationFeatures.OrganizationSubscriptions; public class UpdateSecretsManagerSubscriptionCommand : IUpdateSecretsManagerSubscriptionCommand { private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IMailService _mailService; private readonly ILogger _logger; private readonly IServiceAccountRepository _serviceAccountRepository; @@ -29,7 +30,7 @@ public class UpdateSecretsManagerSubscriptionCommand : IUpdateSecretsManagerSubs public UpdateSecretsManagerSubscriptionCommand( IOrganizationUserRepository organizationUserRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, IMailService mailService, ILogger logger, IServiceAccountRepository serviceAccountRepository, @@ -226,7 +227,11 @@ public class UpdateSecretsManagerSubscriptionCommand : IUpdateSecretsManagerSubs // Check minimum seats currently in use by the organization if (organization.SmSeats.Value > update.SmSeats.Value) { + // Retrieve the number of currently occupied Secrets Manager seats for the organization. var occupiedSeats = await _organizationUserRepository.GetOccupiedSmSeatCountByOrganizationIdAsync(organization.Id); + + // Check if the occupied number of seats exceeds the updated seat count. + // If so, throw an exception indicating that the subscription cannot be decreased below the current usage. if (occupiedSeats > update.SmSeats.Value) { throw new BadRequestException($"{occupiedSeats} users are currently occupying Secrets Manager seats. " + @@ -412,7 +417,7 @@ public class UpdateSecretsManagerSubscriptionCommand : IUpdateSecretsManagerSubs } /// - /// Requests the number of Secret Manager seats and service accounts are currently used by the organization + /// Requests the number of Secret Manager seats and service accounts currently used by the organization /// /// The id of the organization /// A tuple containing the occupied seats and the occupied service account counts diff --git a/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpgradeOrganizationPlanCommand.cs b/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpgradeOrganizationPlanCommand.cs index 2b39e6cca6..4ad63bd8d7 100644 --- a/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpgradeOrganizationPlanCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpgradeOrganizationPlanCommand.cs @@ -4,6 +4,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.OrganizationConnectionConfigs; +using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Repositories; @@ -11,6 +12,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Organizations.Services; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Business; @@ -26,7 +28,7 @@ public class UpgradeOrganizationPlanCommand : IUpgradeOrganizationPlanCommand private readonly IOrganizationUserRepository _organizationUserRepository; private readonly ICollectionRepository _collectionRepository; private readonly IGroupRepository _groupRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IPolicyRepository _policyRepository; private readonly ISsoConfigRepository _ssoConfigRepository; private readonly IOrganizationConnectionRepository _organizationConnectionRepository; @@ -41,7 +43,7 @@ public class UpgradeOrganizationPlanCommand : IUpgradeOrganizationPlanCommand IOrganizationUserRepository organizationUserRepository, ICollectionRepository collectionRepository, IGroupRepository groupRepository, - IPaymentService paymentService, + IStripePaymentService paymentService, IPolicyRepository policyRepository, ISsoConfigRepository ssoConfigRepository, IOrganizationConnectionRepository organizationConnectionRepository, @@ -254,30 +256,21 @@ public class UpgradeOrganizationPlanCommand : IUpgradeOrganizationPlanCommand organization.UseApi = newPlan.HasApi; organization.SelfHost = newPlan.HasSelfHost; organization.UsePolicies = newPlan.HasPolicies; - organization.MaxStorageGb = !newPlan.PasswordManager.BaseStorageGb.HasValue - ? (short?)null - : (short)(newPlan.PasswordManager.BaseStorageGb.Value + upgrade.AdditionalStorageGb); - organization.UseGroups = newPlan.HasGroups; - organization.UseDirectory = newPlan.HasDirectory; - organization.UseEvents = newPlan.HasEvents; - organization.UseTotp = newPlan.HasTotp; - organization.Use2fa = newPlan.Has2fa; - organization.UseApi = newPlan.HasApi; + organization.MaxStorageGb = (short)(newPlan.PasswordManager.BaseStorageGb + upgrade.AdditionalStorageGb); organization.UseSso = newPlan.HasSso; organization.UseOrganizationDomains = newPlan.HasOrganizationDomains; organization.UseKeyConnector = newPlan.HasKeyConnector ? organization.UseKeyConnector : false; organization.UseScim = newPlan.HasScim; organization.UseResetPassword = newPlan.HasResetPassword; - organization.SelfHost = newPlan.HasSelfHost; organization.UsersGetPremium = newPlan.UsersGetPremium || upgrade.PremiumAccessAddon; organization.UseCustomPermissions = newPlan.HasCustomPermissions; organization.Plan = newPlan.Name; organization.Enabled = success; - organization.PublicKey = upgrade.PublicKey; - organization.PrivateKey = upgrade.PrivateKey; organization.UsePasswordManager = true; organization.UseSecretsManager = upgrade.UseSecretsManager; + organization.BackfillPublicPrivateKeys(upgrade.Keys); + if (upgrade.UseSecretsManager) { organization.SmSeats = newPlan.SecretsManager.BaseSeats + upgrade.AdditionalSmSeats.GetValueOrDefault(); diff --git a/src/Core/PhishingDomainFeatures/AzurePhishingDomainStorageService.cs b/src/Core/PhishingDomainFeatures/AzurePhishingDomainStorageService.cs deleted file mode 100644 index 6b76bc35f0..0000000000 --- a/src/Core/PhishingDomainFeatures/AzurePhishingDomainStorageService.cs +++ /dev/null @@ -1,95 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Text; -using Azure.Storage.Blobs; -using Azure.Storage.Blobs.Models; -using Bit.Core.Settings; -using Microsoft.Extensions.Logging; - -namespace Bit.Core.PhishingDomainFeatures; - -public class AzurePhishingDomainStorageService -{ - private const string _containerName = "phishingdomains"; - private const string _domainsFileName = "domains.txt"; - private const string _checksumFileName = "checksum.txt"; - - private readonly BlobServiceClient _blobServiceClient; - private readonly ILogger _logger; - private BlobContainerClient _containerClient; - - public AzurePhishingDomainStorageService( - GlobalSettings globalSettings, - ILogger logger) - { - _blobServiceClient = new BlobServiceClient(globalSettings.Storage.ConnectionString); - _logger = logger; - } - - public async Task> GetDomainsAsync() - { - await InitAsync(); - - var blobClient = _containerClient.GetBlobClient(_domainsFileName); - if (!await blobClient.ExistsAsync()) - { - return []; - } - - var response = await blobClient.DownloadAsync(); - using var streamReader = new StreamReader(response.Value.Content); - var content = await streamReader.ReadToEndAsync(); - - return [.. content - .Split(new[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries) - .Select(line => line.Trim()) - .Where(line => !string.IsNullOrWhiteSpace(line) && !line.StartsWith('#'))]; - } - - public async Task GetChecksumAsync() - { - await InitAsync(); - - var blobClient = _containerClient.GetBlobClient(_checksumFileName); - if (!await blobClient.ExistsAsync()) - { - return string.Empty; - } - - var response = await blobClient.DownloadAsync(); - using var streamReader = new StreamReader(response.Value.Content); - return (await streamReader.ReadToEndAsync()).Trim(); - } - - public async Task UpdateDomainsAsync(IEnumerable domains, string checksum) - { - await InitAsync(); - - var domainsContent = string.Join(Environment.NewLine, domains); - var domainsStream = new MemoryStream(Encoding.UTF8.GetBytes(domainsContent)); - var domainsBlobClient = _containerClient.GetBlobClient(_domainsFileName); - - await domainsBlobClient.UploadAsync(domainsStream, new BlobUploadOptions - { - HttpHeaders = new BlobHttpHeaders { ContentType = "text/plain" } - }, CancellationToken.None); - - var checksumStream = new MemoryStream(Encoding.UTF8.GetBytes(checksum)); - var checksumBlobClient = _containerClient.GetBlobClient(_checksumFileName); - - await checksumBlobClient.UploadAsync(checksumStream, new BlobUploadOptions - { - HttpHeaders = new BlobHttpHeaders { ContentType = "text/plain" } - }, CancellationToken.None); - } - - private async Task InitAsync() - { - if (_containerClient is null) - { - _containerClient = _blobServiceClient.GetBlobContainerClient(_containerName); - await _containerClient.CreateIfNotExistsAsync(); - } - } -} diff --git a/src/Core/PhishingDomainFeatures/CloudPhishingDomainDirectQuery.cs b/src/Core/PhishingDomainFeatures/CloudPhishingDomainDirectQuery.cs deleted file mode 100644 index 420948e310..0000000000 --- a/src/Core/PhishingDomainFeatures/CloudPhishingDomainDirectQuery.cs +++ /dev/null @@ -1,100 +0,0 @@ -using Bit.Core.PhishingDomainFeatures.Interfaces; -using Bit.Core.Settings; -using Microsoft.Extensions.Logging; - -namespace Bit.Core.PhishingDomainFeatures; - -/// -/// Implementation of ICloudPhishingDomainQuery for cloud environments -/// that directly calls the external phishing domain source -/// -public class CloudPhishingDomainDirectQuery : ICloudPhishingDomainQuery -{ - private readonly IGlobalSettings _globalSettings; - private readonly IHttpClientFactory _httpClientFactory; - private readonly ILogger _logger; - - public CloudPhishingDomainDirectQuery( - IGlobalSettings globalSettings, - IHttpClientFactory httpClientFactory, - ILogger logger) - { - _globalSettings = globalSettings; - _httpClientFactory = httpClientFactory; - _logger = logger; - } - - public async Task> GetPhishingDomainsAsync() - { - if (string.IsNullOrWhiteSpace(_globalSettings.PhishingDomain?.UpdateUrl)) - { - throw new InvalidOperationException("Phishing domain update URL is not configured."); - } - - var httpClient = _httpClientFactory.CreateClient("PhishingDomains"); - var response = await httpClient.GetAsync(_globalSettings.PhishingDomain.UpdateUrl); - response.EnsureSuccessStatusCode(); - - var content = await response.Content.ReadAsStringAsync(); - return ParseDomains(content); - } - - /// - /// Gets the SHA256 checksum of the remote phishing domains list - /// - /// The SHA256 checksum as a lowercase hex string - public async Task GetRemoteChecksumAsync() - { - if (string.IsNullOrWhiteSpace(_globalSettings.PhishingDomain?.ChecksumUrl)) - { - _logger.LogWarning("Phishing domain checksum URL is not configured."); - return string.Empty; - } - - try - { - var httpClient = _httpClientFactory.CreateClient("PhishingDomains"); - var response = await httpClient.GetAsync(_globalSettings.PhishingDomain.ChecksumUrl); - response.EnsureSuccessStatusCode(); - - var content = await response.Content.ReadAsStringAsync(); - return ParseChecksumResponse(content); - } - catch (Exception ex) - { - _logger.LogError(ex, "Error retrieving phishing domain checksum from {Url}", - _globalSettings.PhishingDomain.ChecksumUrl); - return string.Empty; - } - } - - /// - /// Parses a checksum response in the format "hash *filename" - /// - private static string ParseChecksumResponse(string checksumContent) - { - if (string.IsNullOrWhiteSpace(checksumContent)) - { - return string.Empty; - } - - // Format is typically "hash *filename" - var parts = checksumContent.Split(' ', 2); - - return parts.Length > 0 ? parts[0].Trim() : string.Empty; - } - - private static List ParseDomains(string content) - { - if (string.IsNullOrWhiteSpace(content)) - { - return []; - } - - return content - .Split(new[] { '\r', '\n' }, StringSplitOptions.RemoveEmptyEntries) - .Select(line => line.Trim()) - .Where(line => !string.IsNullOrWhiteSpace(line) && !line.StartsWith("#")) - .ToList(); - } -} diff --git a/src/Core/PhishingDomainFeatures/CloudPhishingDomainRelayQuery.cs b/src/Core/PhishingDomainFeatures/CloudPhishingDomainRelayQuery.cs deleted file mode 100644 index 6b0027062c..0000000000 --- a/src/Core/PhishingDomainFeatures/CloudPhishingDomainRelayQuery.cs +++ /dev/null @@ -1,69 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.PhishingDomainFeatures.Interfaces; -using Bit.Core.Services; -using Bit.Core.Settings; -using Microsoft.Extensions.Logging; - -namespace Bit.Core.PhishingDomainFeatures; - -/// -/// Implementation of ICloudPhishingDomainQuery for self-hosted environments -/// that relays the request to the Bitwarden cloud API -/// -public class CloudPhishingDomainRelayQuery : BaseIdentityClientService, ICloudPhishingDomainQuery -{ - private readonly IGlobalSettings _globalSettings; - - public CloudPhishingDomainRelayQuery( - IHttpClientFactory httpFactory, - IGlobalSettings globalSettings, - ILogger logger) - : base( - httpFactory, - globalSettings.Installation.ApiUri, - globalSettings.Installation.IdentityUri, - "api.licensing", - $"installation.{globalSettings.Installation.Id}", - globalSettings.Installation.Key, - logger) - { - _globalSettings = globalSettings; - } - - public async Task> GetPhishingDomainsAsync() - { - if (!_globalSettings.SelfHosted || !_globalSettings.EnableCloudCommunication) - { - throw new InvalidOperationException("This query is only for self-hosted installations with cloud communication enabled."); - } - - var result = await SendAsync(HttpMethod.Get, "phishing-domains", null, true); - return result?.ToList() ?? new List(); - } - - /// - /// Gets the SHA256 checksum of the remote phishing domains list - /// - /// The SHA256 checksum as a lowercase hex string - public async Task GetRemoteChecksumAsync() - { - if (!_globalSettings.SelfHosted || !_globalSettings.EnableCloudCommunication) - { - throw new InvalidOperationException("This query is only for self-hosted installations with cloud communication enabled."); - } - - try - { - // For self-hosted environments, we get the checksum from the Bitwarden cloud API - var result = await SendAsync(HttpMethod.Get, "phishing-domains/checksum", null, true); - return result ?? string.Empty; - } - catch (Exception ex) - { - _logger.LogError(ex, "Error retrieving phishing domain checksum from Bitwarden cloud API"); - return string.Empty; - } - } -} diff --git a/src/Core/PhishingDomainFeatures/Interfaces/ICloudPhishingDomainQuery.cs b/src/Core/PhishingDomainFeatures/Interfaces/ICloudPhishingDomainQuery.cs deleted file mode 100644 index dac91747f7..0000000000 --- a/src/Core/PhishingDomainFeatures/Interfaces/ICloudPhishingDomainQuery.cs +++ /dev/null @@ -1,7 +0,0 @@ -namespace Bit.Core.PhishingDomainFeatures.Interfaces; - -public interface ICloudPhishingDomainQuery -{ - Task> GetPhishingDomainsAsync(); - Task GetRemoteChecksumAsync(); -} diff --git a/src/Core/Services/Implementations/AmazonSesMailDeliveryService.cs b/src/Core/Platform/Mail/Delivery/AmazonSesMailDeliveryService.cs similarity index 99% rename from src/Core/Services/Implementations/AmazonSesMailDeliveryService.cs rename to src/Core/Platform/Mail/Delivery/AmazonSesMailDeliveryService.cs index 344c2e712d..ade289be8f 100644 --- a/src/Core/Services/Implementations/AmazonSesMailDeliveryService.cs +++ b/src/Core/Platform/Mail/Delivery/AmazonSesMailDeliveryService.cs @@ -9,7 +9,7 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; +namespace Bit.Core.Platform.Mail.Delivery; public class AmazonSesMailDeliveryService : IMailDeliveryService, IDisposable { diff --git a/src/Core/Services/IMailDeliveryService.cs b/src/Core/Platform/Mail/Delivery/IMailDeliveryService.cs similarity index 73% rename from src/Core/Services/IMailDeliveryService.cs rename to src/Core/Platform/Mail/Delivery/IMailDeliveryService.cs index 9247367221..1f2a024c34 100644 --- a/src/Core/Services/IMailDeliveryService.cs +++ b/src/Core/Platform/Mail/Delivery/IMailDeliveryService.cs @@ -1,6 +1,6 @@ using Bit.Core.Models.Mail; -namespace Bit.Core.Services; +namespace Bit.Core.Platform.Mail.Delivery; public interface IMailDeliveryService { diff --git a/src/Core/Services/Implementations/MailKitSmtpMailDeliveryService.cs b/src/Core/Platform/Mail/Delivery/MailKitSmtpMailDeliveryService.cs similarity index 99% rename from src/Core/Services/Implementations/MailKitSmtpMailDeliveryService.cs rename to src/Core/Platform/Mail/Delivery/MailKitSmtpMailDeliveryService.cs index 04eda42d22..c78b107084 100644 --- a/src/Core/Services/Implementations/MailKitSmtpMailDeliveryService.cs +++ b/src/Core/Platform/Mail/Delivery/MailKitSmtpMailDeliveryService.cs @@ -7,7 +7,7 @@ using MailKit.Net.Smtp; using Microsoft.Extensions.Logging; using MimeKit; -namespace Bit.Core.Services; +namespace Bit.Core.Platform.Mail.Delivery; public class MailKitSmtpMailDeliveryService : IMailDeliveryService { diff --git a/src/Core/Services/Implementations/MultiServiceMailDeliveryService.cs b/src/Core/Platform/Mail/Delivery/MultiServiceMailDeliveryService.cs similarity index 96% rename from src/Core/Services/Implementations/MultiServiceMailDeliveryService.cs rename to src/Core/Platform/Mail/Delivery/MultiServiceMailDeliveryService.cs index e088410967..1e34e1f842 100644 --- a/src/Core/Services/Implementations/MultiServiceMailDeliveryService.cs +++ b/src/Core/Platform/Mail/Delivery/MultiServiceMailDeliveryService.cs @@ -3,7 +3,7 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; +namespace Bit.Core.Platform.Mail.Delivery; public class MultiServiceMailDeliveryService : IMailDeliveryService { diff --git a/src/Core/Services/NoopImplementations/NoopMailDeliveryService.cs b/src/Core/Platform/Mail/Delivery/NoopMailDeliveryService.cs similarity index 82% rename from src/Core/Services/NoopImplementations/NoopMailDeliveryService.cs rename to src/Core/Platform/Mail/Delivery/NoopMailDeliveryService.cs index 96b97b14f5..d8194ffb18 100644 --- a/src/Core/Services/NoopImplementations/NoopMailDeliveryService.cs +++ b/src/Core/Platform/Mail/Delivery/NoopMailDeliveryService.cs @@ -1,6 +1,6 @@ using Bit.Core.Models.Mail; -namespace Bit.Core.Services; +namespace Bit.Core.Platform.Mail.Delivery; public class NoopMailDeliveryService : IMailDeliveryService { diff --git a/src/Core/Services/Implementations/SendGridMailDeliveryService.cs b/src/Core/Platform/Mail/Delivery/SendGridMailDeliveryService.cs similarity index 98% rename from src/Core/Services/Implementations/SendGridMailDeliveryService.cs rename to src/Core/Platform/Mail/Delivery/SendGridMailDeliveryService.cs index 773f87931d..10afcc539a 100644 --- a/src/Core/Services/Implementations/SendGridMailDeliveryService.cs +++ b/src/Core/Platform/Mail/Delivery/SendGridMailDeliveryService.cs @@ -6,7 +6,7 @@ using Microsoft.Extensions.Logging; using SendGrid; using SendGrid.Helpers.Mail; -namespace Bit.Core.Services; +namespace Bit.Core.Platform.Mail.Delivery; public class SendGridMailDeliveryService : IMailDeliveryService, IDisposable { diff --git a/src/Core/Services/Implementations/AzureQueueMailService.cs b/src/Core/Platform/Mail/Enqueuing/AzureQueueMailService.cs similarity index 91% rename from src/Core/Services/Implementations/AzureQueueMailService.cs rename to src/Core/Platform/Mail/Enqueuing/AzureQueueMailService.cs index 92d6fd17bb..c88090a954 100644 --- a/src/Core/Services/Implementations/AzureQueueMailService.cs +++ b/src/Core/Platform/Mail/Enqueuing/AzureQueueMailService.cs @@ -1,10 +1,10 @@ using Azure.Storage.Queues; using Bit.Core.Models.Mail; +using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Utilities; -namespace Bit.Core.Services; - +namespace Bit.Core.Platform.Mail.Enqueuing; public class AzureQueueMailService : AzureQueueService, IMailEnqueuingService { public AzureQueueMailService(GlobalSettings globalSettings) : base( diff --git a/src/Core/Services/Implementations/BlockingMailQueueService.cs b/src/Core/Platform/Mail/Enqueuing/BlockingMailQueueService.cs similarity index 91% rename from src/Core/Services/Implementations/BlockingMailQueueService.cs rename to src/Core/Platform/Mail/Enqueuing/BlockingMailQueueService.cs index 0323b09af7..e75874af16 100644 --- a/src/Core/Services/Implementations/BlockingMailQueueService.cs +++ b/src/Core/Platform/Mail/Enqueuing/BlockingMailQueueService.cs @@ -1,7 +1,6 @@ using Bit.Core.Models.Mail; -namespace Bit.Core.Services; - +namespace Bit.Core.Platform.Mail.Enqueuing; public class BlockingMailEnqueuingService : IMailEnqueuingService { public async Task EnqueueAsync(IMailQueueMessage message, Func fallback) diff --git a/src/Core/Services/IMailEnqueuingService.cs b/src/Core/Platform/Mail/Enqueuing/IMailEnqueuingService.cs similarity index 86% rename from src/Core/Services/IMailEnqueuingService.cs rename to src/Core/Platform/Mail/Enqueuing/IMailEnqueuingService.cs index 19dc33f19e..d74f9160e4 100644 --- a/src/Core/Services/IMailEnqueuingService.cs +++ b/src/Core/Platform/Mail/Enqueuing/IMailEnqueuingService.cs @@ -1,6 +1,6 @@ using Bit.Core.Models.Mail; -namespace Bit.Core.Services; +namespace Bit.Core.Platform.Mail.Enqueuing; public interface IMailEnqueuingService { diff --git a/src/Core/Services/Implementations/HandlebarsMailService.cs b/src/Core/Platform/Mail/HandlebarsMailService.cs similarity index 94% rename from src/Core/Services/Implementations/HandlebarsMailService.cs rename to src/Core/Platform/Mail/HandlebarsMailService.cs index 75e0c78702..d57ca400fd 100644 --- a/src/Core/Services/Implementations/HandlebarsMailService.cs +++ b/src/Core/Platform/Mail/HandlebarsMailService.cs @@ -19,6 +19,8 @@ using Bit.Core.Models.Mail.Auth; using Bit.Core.Models.Mail.Billing; using Bit.Core.Models.Mail.FamiliesForEnterprise; using Bit.Core.Models.Mail.Provider; +using Bit.Core.Platform.Mail.Delivery; +using Bit.Core.Platform.Mail.Enqueuing; using Bit.Core.SecretsManager.Models.Mail; using Bit.Core.Settings; using Bit.Core.Utilities; @@ -28,8 +30,9 @@ using HandlebarsDotNet; using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; +namespace Bit.Core.Services.Mail; +[Obsolete("The IMailService has been deprecated in favor of the IMailer. All new emails should be sent with an IMailer implementation.")] public class HandlebarsMailService : IMailService { private const string Namespace = "Bit.Core.MailTemplates.Handlebars"; @@ -75,7 +78,7 @@ public class HandlebarsMailService : IMailService await _mailDeliveryService.SendEmailAsync(message); } - public async Task SendRegistrationVerificationEmailAsync(string email, string token) + public async Task SendRegistrationVerificationEmailAsync(string email, string token, string? fromMarketing) { var message = CreateDefaultMessage("Verify Your Email", email); var model = new RegisterVerifyEmail @@ -83,7 +86,8 @@ public class HandlebarsMailService : IMailService Token = WebUtility.UrlEncode(token), Email = WebUtility.UrlEncode(email), WebVaultUrl = _globalSettings.BaseServiceUri.Vault, - SiteName = _globalSettings.SiteName + SiteName = _globalSettings.SiteName, + FromMarketing = WebUtility.UrlEncode(fromMarketing), }; await AddMessageContentAsync(message, "Auth.RegistrationVerifyEmail", model); message.MetaData.Add("SendGridBypassListManagement", true); @@ -224,6 +228,27 @@ public class HandlebarsMailService : IMailService await _mailDeliveryService.SendEmailAsync(message); } + public async Task SendSendEmailOtpEmailv2Async(string email, string token, string subject) + { + var message = CreateDefaultMessage(subject, email); + var requestDateTime = DateTime.UtcNow; + var model = new DefaultEmailOtpViewModel + { + Token = token, + Expiry = "5", // This should be configured through the OTPDefaultTokenProviderOptions but for now we will hardcode it to 5 minutes. + TheDate = requestDateTime.ToLongDateString(), + TheTime = requestDateTime.ToShortTimeString(), + TimeZone = _utcTimeZoneDisplay, + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + }; + await AddMessageContentAsync(message, "Auth.SendAccessEmailOtpEmailv2", model); + message.MetaData.Add("SendGridBypassListManagement", true); + // TODO - PM-25380 change to string constant + message.Category = "SendEmailOtp"; + await _mailDeliveryService.SendEmailAsync(message); + } + public async Task SendFailedTwoFactorAttemptEmailAsync(string email, TwoFactorProviderType failedType, DateTime utcNow, string ip) { // Check if we've sent this email within the last hour @@ -400,6 +425,8 @@ public class HandlebarsMailService : IMailService await _mailDeliveryService.SendEmailAsync(message); } + // TODO: DO NOT move to IMailer implementation: PM-27852 + [Obsolete("Use SendIndividualUserWelcomeEmailAsync instead")] public async Task SendWelcomeEmailAsync(User user) { var message = CreateDefaultMessage("Welcome to Bitwarden!", user.Email); @@ -413,6 +440,50 @@ public class HandlebarsMailService : IMailService await _mailDeliveryService.SendEmailAsync(message); } + // TODO: Move to IMailer implementation: PM-27852 + public async Task SendIndividualUserWelcomeEmailAsync(User user) + { + var message = CreateDefaultMessage("Welcome to Bitwarden!", user.Email); + var model = new BaseMailModel + { + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "MJML.Auth.Onboarding.welcome-individual-user", model); + message.Category = "Welcome"; + await _mailDeliveryService.SendEmailAsync(message); + } + + // TODO: Move to IMailer implementation: PM-27852 + public async Task SendOrganizationUserWelcomeEmailAsync(User user, string organizationName) + { + var message = CreateDefaultMessage("Welcome to Bitwarden!", user.Email); + var model = new OrganizationWelcomeEmailViewModel + { + OrganizationName = CoreHelpers.SanitizeForEmail(organizationName, false), + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "MJML.Auth.Onboarding.welcome-org-user", model); + message.Category = "Welcome"; + await _mailDeliveryService.SendEmailAsync(message); + } + + // TODO: Move to IMailer implementation: PM-27852 + public async Task SendFreeOrgOrFamilyOrgUserWelcomeEmailAsync(User user, string familyOrganizationName) + { + var message = CreateDefaultMessage("Welcome to Bitwarden!", user.Email); + var model = new OrganizationWelcomeEmailViewModel + { + OrganizationName = CoreHelpers.SanitizeForEmail(familyOrganizationName, false), + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "MJML.Auth.Onboarding.welcome-family-user", model); + message.Category = "Welcome"; + await _mailDeliveryService.SendEmailAsync(message); + } + public async Task SendTrialInitiationEmailAsync(string userEmail) { var message = CreateDefaultMessage("Welcome to Bitwarden; 3 steps to get started!", userEmail); @@ -653,7 +724,7 @@ public class HandlebarsMailService : IMailService await _mailDeliveryService.SendEmailAsync(message); } - public async Task SendAdminResetPasswordEmailAsync(string email, string userName, string orgName) + public async Task SendAdminResetPasswordEmailAsync(string email, string? userName, string orgName) { var message = CreateDefaultMessage("Your admin has initiated account recovery", email); var model = new AdminResetPasswordViewModel() diff --git a/src/Core/Services/IMailService.cs b/src/Core/Platform/Mail/IMailService.cs similarity index 78% rename from src/Core/Services/IMailService.cs rename to src/Core/Platform/Mail/IMailService.cs index 6e61c4f8dd..e21e1a010f 100644 --- a/src/Core/Services/IMailService.cs +++ b/src/Core/Platform/Mail/IMailService.cs @@ -1,9 +1,8 @@ -#nullable enable - -using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Auth.Entities; using Bit.Core.Auth.Enums; +using Bit.Core.Auth.Identity.TokenProviders; using Bit.Core.Billing.Enums; using Bit.Core.Entities; using Bit.Core.Models.Data.Organizations; @@ -13,11 +12,33 @@ using Core.Auth.Enums; namespace Bit.Core.Services; +[Obsolete("The IMailService has been deprecated in favor of the IMailer. All new emails should be sent with an IMailer implementation.")] public interface IMailService { + [Obsolete("Use SendIndividualUserWelcomeEmailAsync instead")] Task SendWelcomeEmailAsync(User user); + /// + /// Email sent to users who have created a new account as an individual user. + /// + /// The new User + /// Task + Task SendIndividualUserWelcomeEmailAsync(User user); + /// + /// Email sent to users who have been confirmed to an organization. + /// + /// The User + /// The Organization user is being added to + /// Task + Task SendOrganizationUserWelcomeEmailAsync(User user, string organizationName); + /// + /// Email sent to users who have been confirmed to a free or families organization. + /// + /// The User + /// The Families Organization user is being added to + /// Task + Task SendFreeOrgOrFamilyOrgUserWelcomeEmailAsync(User user, string familyOrganizationName); Task SendVerifyEmailEmailAsync(string email, Guid userId, string token); - Task SendRegistrationVerificationEmailAsync(string email, string token); + Task SendRegistrationVerificationEmailAsync(string email, string token, string? fromMarketing); Task SendTrialInitiationSignupEmailAsync( bool isExistingUser, string email, @@ -31,6 +52,16 @@ public interface IMailService Task SendChangeEmailEmailAsync(string newEmailAddress, string token); Task SendTwoFactorEmailAsync(string email, string accountEmail, string token, string deviceIp, string deviceType, TwoFactorEmailPurpose purpose); Task SendSendEmailOtpEmailAsync(string email, string token, string subject); + /// + /// has a default expiry of 5 minutes so we set the expiry to that value int he view model. + /// Sends OTP code token to the specified email address. + /// will replace when MJML templates are fully accepted. + /// + /// Email address to send the OTP to + /// Otp code token + /// subject line of the email + /// Task + Task SendSendEmailOtpEmailv2Async(string email, string token, string subject); Task SendFailedTwoFactorAttemptEmailAsync(string email, TwoFactorProviderType type, DateTime utcNow, string ip); Task SendNoMasterPasswordHintEmailAsync(string email); Task SendMasterPasswordHintEmailAsync(string email, string hint); @@ -81,7 +112,7 @@ public interface IMailService Task SendEmergencyAccessRecoveryReminder(EmergencyAccess emergencyAccess, string initiatingName, string email); Task SendEmergencyAccessRecoveryTimedOut(EmergencyAccess ea, string initiatingName, string email); Task SendEnqueuedMailMessageAsync(IMailQueueMessage queueMessage); - Task SendAdminResetPasswordEmailAsync(string email, string userName, string orgName); + Task SendAdminResetPasswordEmailAsync(string email, string? userName, string orgName); Task SendProviderSetupInviteEmailAsync(Provider provider, string token, string email); Task SendBusinessUnitConversionInviteAsync(Organization organization, string token, string email); Task SendProviderInviteEmailAsync(string providerName, ProviderUser providerUser, string token, string email); diff --git a/src/Core/Platform/Mail/Mailer/BaseMail.cs b/src/Core/Platform/Mail/Mailer/BaseMail.cs new file mode 100644 index 0000000000..0fd6b79aba --- /dev/null +++ b/src/Core/Platform/Mail/Mailer/BaseMail.cs @@ -0,0 +1,54 @@ +namespace Bit.Core.Platform.Mail.Mailer; + +#nullable enable + +/// +/// BaseMail describes a model for emails. It contains metadata about the email such as recipients, +/// subject, and an optional category for processing at the upstream email delivery service. +/// +/// Each BaseMail must have a view model that inherits from BaseMailView. The view model is used to +/// generate the text part and HTML body. +/// +public abstract class BaseMail where TView : BaseMailView +{ + /// + /// Email recipients. + /// + public required IEnumerable ToEmails { get; set; } + + /// + /// The subject of the email. + /// + public abstract string Subject { get; } + + /// + /// An optional category for processing at the upstream email delivery service. + /// + public string? Category { get; } + + /// + /// Allows you to override and ignore the suppression list for this email. + /// + /// Warning: This should be used with caution, valid reasons are primarily account recovery, email OTP. + /// + public virtual bool IgnoreSuppressList { get; } = false; + + /// + /// View model for the email body. + /// + public required TView View { get; set; } +} + +/// +/// Each MailView consists of two body parts: a text part and an HTML part and the filename must be +/// relative to the viewmodel and match the following pattern: +/// - `{ClassName}.html.hbs` for the HTML part +/// - `{ClassName}.text.hbs` for the text part +/// +public abstract class BaseMailView +{ + /// + /// Current year. + /// + public string CurrentYear => DateTime.UtcNow.Year.ToString(); +} diff --git a/src/Core/Platform/Mail/Mailer/HandlebarMailRenderer.cs b/src/Core/Platform/Mail/Mailer/HandlebarMailRenderer.cs new file mode 100644 index 0000000000..baba5b8015 --- /dev/null +++ b/src/Core/Platform/Mail/Mailer/HandlebarMailRenderer.cs @@ -0,0 +1,132 @@ +#nullable enable +using System.Collections.Concurrent; +using System.Reflection; +using Bit.Core.Settings; +using HandlebarsDotNet; +using Microsoft.Extensions.Logging; + +namespace Bit.Core.Platform.Mail.Mailer; +public class HandlebarMailRenderer : IMailRenderer +{ + /// + /// Lazy-initialized Handlebars instance. Thread-safe and ensures initialization occurs only once. + /// + private readonly Lazy> _handlebarsTask; + + /// + /// Helper function that returns the handlebar instance. + /// + private Task GetHandlebars() => _handlebarsTask.Value; + + /// + /// This dictionary is used to cache compiled templates in a thread-safe manner. + /// + private readonly ConcurrentDictionary>>> _templateCache = new(); + + private readonly ILogger _logger; + private readonly GlobalSettings _globalSettings; + + public HandlebarMailRenderer(ILogger logger, GlobalSettings globalSettings) + { + _logger = logger; + _globalSettings = globalSettings; + + _handlebarsTask = new Lazy>(InitializeHandlebarsAsync, LazyThreadSafetyMode.ExecutionAndPublication); + } + + public async Task<(string html, string txt)> RenderAsync(BaseMailView model) + { + var html = await CompileTemplateAsync(model, "html"); + var txt = await CompileTemplateAsync(model, "text"); + + return (html, txt); + } + + private async Task CompileTemplateAsync(BaseMailView model, string type) + { + var templateName = $"{model.GetType().FullName}.{type}.hbs"; + var assembly = model.GetType().Assembly; + + // GetOrAdd is atomic - only one Lazy will be stored per templateName. + // The Lazy with ExecutionAndPublication ensures the compilation happens exactly once. + var lazyTemplate = _templateCache.GetOrAdd( + templateName, + key => new Lazy>>( + () => CompileTemplateInternalAsync(assembly, key), + LazyThreadSafetyMode.ExecutionAndPublication)); + + var template = await lazyTemplate.Value; + return template(model); + } + + private async Task> CompileTemplateInternalAsync(Assembly assembly, string templateName) + { + var source = await ReadSourceAsync(assembly, templateName); + var handlebars = await GetHandlebars(); + return handlebars.Compile(source); + } + + private async Task ReadSourceAsync(Assembly assembly, string template) + { + if (assembly.GetManifestResourceNames().All(f => f != template)) + { + throw new FileNotFoundException("Template not found: " + template); + } + + var diskSource = await ReadSourceFromDiskAsync(template); + if (!string.IsNullOrWhiteSpace(diskSource)) + { + return diskSource; + } + + await using var s = assembly.GetManifestResourceStream(template)!; + using var sr = new StreamReader(s); + return await sr.ReadToEndAsync(); + } + + private async Task ReadSourceFromDiskAsync(string template) + { + if (!_globalSettings.SelfHosted) + { + return null; + } + + try + { + var diskPath = Path.GetFullPath(Path.Combine(_globalSettings.MailTemplateDirectory, template)); + var baseDirectory = Path.GetFullPath(_globalSettings.MailTemplateDirectory); + + // Ensure the resolved path is within the configured directory + if (!diskPath.StartsWith(baseDirectory + Path.DirectorySeparatorChar, StringComparison.OrdinalIgnoreCase) && + !diskPath.Equals(baseDirectory, StringComparison.OrdinalIgnoreCase)) + { + _logger.LogWarning("Template path traversal attempt detected: {Template}", template); + return null; + } + + if (File.Exists(diskPath)) + { + var fileContents = await File.ReadAllTextAsync(diskPath); + return fileContents; + } + } + catch (Exception e) + { + _logger.LogError(e, "Failed to read mail template from disk: {TemplateName}", template); + } + + return null; + } + + private async Task InitializeHandlebarsAsync() + { + var handlebars = Handlebars.Create(); + + // TODO: Do we still need layouts with MJML? + var assembly = typeof(HandlebarMailRenderer).Assembly; + var layoutSource = await ReadSourceAsync(assembly, "Bit.Core.MailTemplates.Handlebars.Layouts.Full.html.hbs"); + handlebars.RegisterTemplate("FullHtmlLayout", layoutSource); + + return handlebars; + } +} diff --git a/src/Core/Platform/Mail/Mailer/IMailRenderer.cs b/src/Core/Platform/Mail/Mailer/IMailRenderer.cs new file mode 100644 index 0000000000..7f392df479 --- /dev/null +++ b/src/Core/Platform/Mail/Mailer/IMailRenderer.cs @@ -0,0 +1,7 @@ +#nullable enable +namespace Bit.Core.Platform.Mail.Mailer; + +public interface IMailRenderer +{ + Task<(string html, string txt)> RenderAsync(BaseMailView model); +} diff --git a/src/Core/Platform/Mail/Mailer/IMailer.cs b/src/Core/Platform/Mail/Mailer/IMailer.cs new file mode 100644 index 0000000000..6dc3eec46f --- /dev/null +++ b/src/Core/Platform/Mail/Mailer/IMailer.cs @@ -0,0 +1,15 @@ +namespace Bit.Core.Platform.Mail.Mailer; + +#nullable enable + +/// +/// Generic mailer interface for sending email messages. +/// +public interface IMailer +{ + /// + /// Sends an email message. + /// + /// + public Task SendEmail(BaseMail message) where T : BaseMailView; +} diff --git a/src/Core/Platform/Mail/Mailer/Mailer.cs b/src/Core/Platform/Mail/Mailer/Mailer.cs new file mode 100644 index 0000000000..f5e8d35d58 --- /dev/null +++ b/src/Core/Platform/Mail/Mailer/Mailer.cs @@ -0,0 +1,32 @@ +using Bit.Core.Models.Mail; +using Bit.Core.Platform.Mail.Delivery; + +namespace Bit.Core.Platform.Mail.Mailer; + +#nullable enable + +public class Mailer(IMailRenderer renderer, IMailDeliveryService mailDeliveryService) : IMailer +{ + public async Task SendEmail(BaseMail message) where T : BaseMailView + { + var content = await renderer.RenderAsync(message.View); + + var metadata = new Dictionary(); + if (message.IgnoreSuppressList) + { + metadata.Add("SendGridBypassListManagement", true); + } + + var mailMessage = new MailMessage + { + ToEmails = message.ToEmails, + Subject = message.Subject, + MetaData = metadata, + HtmlContent = content.html, + TextContent = content.txt, + Category = message.Category, + }; + + await mailDeliveryService.SendEmailAsync(mailMessage); + } +} diff --git a/src/Core/Platform/Mail/Mailer/MailerServiceCollectionExtensions.cs b/src/Core/Platform/Mail/Mailer/MailerServiceCollectionExtensions.cs new file mode 100644 index 0000000000..cc56b3ec5a --- /dev/null +++ b/src/Core/Platform/Mail/Mailer/MailerServiceCollectionExtensions.cs @@ -0,0 +1,27 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; + +namespace Bit.Core.Platform.Mail.Mailer; + +#nullable enable + +/// +/// Extension methods for adding the Mailer feature to the service collection. +/// +public static class MailerServiceCollectionExtensions +{ + /// + /// Adds the Mailer services to the . + /// This includes the mail renderer and mailer for sending templated emails. + /// This method is safe to be run multiple times. + /// + /// The to add services to. + /// The for additional chaining. + public static IServiceCollection AddMailer(this IServiceCollection services) + { + services.TryAddSingleton(); + services.TryAddSingleton(); + + return services; + } +} diff --git a/src/Core/Services/NoopImplementations/NoopMailService.cs b/src/Core/Platform/Mail/NoopMailService.cs similarity index 93% rename from src/Core/Services/NoopImplementations/NoopMailService.cs rename to src/Core/Platform/Mail/NoopMailService.cs index 7ec05bb1f9..7de48e4619 100644 --- a/src/Core/Services/NoopImplementations/NoopMailService.cs +++ b/src/Core/Platform/Mail/NoopMailService.cs @@ -13,6 +13,7 @@ using Core.Auth.Enums; namespace Bit.Core.Services; +[Obsolete("The IMailService has been deprecated in favor of the IMailer. All new emails should be sent with an IMailer implementation.")] public class NoopMailService : IMailService { public Task SendChangeEmailAlreadyExistsEmailAsync(string fromEmail, string toEmail) @@ -25,7 +26,7 @@ public class NoopMailService : IMailService return Task.FromResult(0); } - public Task SendRegistrationVerificationEmailAsync(string email, string hint) + public Task SendRegistrationVerificationEmailAsync(string email, string hint, string? fromMarketing) { return Task.FromResult(0); } @@ -98,6 +99,11 @@ public class NoopMailService : IMailService return Task.FromResult(0); } + public Task SendSendEmailOtpEmailv2Async(string email, string token, string subject) + { + return Task.FromResult(0); + } + public Task SendFailedTwoFactorAttemptEmailAsync(string email, TwoFactorProviderType failedType, DateTime utcNow, string ip) { return Task.FromResult(0); @@ -108,6 +114,20 @@ public class NoopMailService : IMailService return Task.FromResult(0); } + public Task SendIndividualUserWelcomeEmailAsync(User user) + { + return Task.FromResult(0); + } + + public Task SendOrganizationUserWelcomeEmailAsync(User user, string organizationName) + { + return Task.FromResult(0); + } + + public Task SendFreeOrgOrFamilyOrgUserWelcomeEmailAsync(User user, string familyOrganizationName) + { + return Task.FromResult(0); + } public Task SendVerifyDeleteEmailAsync(string email, Guid userId, string token) { return Task.FromResult(0); @@ -216,7 +236,7 @@ public class NoopMailService : IMailService return Task.FromResult(0); } - public Task SendAdminResetPasswordEmailAsync(string email, string userName, string orgName) + public Task SendAdminResetPasswordEmailAsync(string email, string? userName, string orgName) { return Task.FromResult(0); } diff --git a/src/Core/Platform/Mail/README.md b/src/Core/Platform/Mail/README.md new file mode 100644 index 0000000000..7a3b6b87c5 --- /dev/null +++ b/src/Core/Platform/Mail/README.md @@ -0,0 +1,226 @@ +# Mail Services +## `MailService` + +> [!WARNING] +> The `MailService` and its implementation in `HandlebarsMailService` has been deprecated in favor of the `Mailer` implementation. + +The `MailService` class manages **all** emails, and has multiple responsibilities, including formatting, email building (instantiation of ViewModels from variables), and deciding if a mail request should be enqueued or sent directly. + +The resulting implementation cannot be owned by a single team (since all emails are in a single class), and as a result, anyone can edit any template without the appropriate team being informed. + +To alleviate these issues, all new emails should be implemented using [MJML](../../MailTemplates/README.md) and the `Mailer`. + +## `Mailer` + +The Mailer feature provides a structured, type-safe approach to sending emails in the Bitwarden server application. It +uses Handlebars templates to render both HTML and plain text email content. + +### Architecture + +The Mailer system consists of four main components: + +1. **IMailer** - Service interface for sending emails +2. **BaseMail** - Abstract base class defining email metadata (recipients, subject, category) +3. **BaseMailView** - Abstract base class for email template ViewModels +4. **IMailRenderer** - Internal interface for rendering templates (implemented by `HandlebarMailRenderer`) + +### How To Use + +1. Define a ViewModel that inherits from `BaseMailView` with properties for template data. +2. Define an email class that inherits from `BaseMail` with metadata like `Subject`. +3. Create Handlebars templates (`.html.hbs` and `.text.hbs`) as embedded resources, preferably using the `MJML` [pipeline](../../MailTemplates/Mjml/README.md#development-process), in + a directory in `/src/Core/MailTemplates/Mjml`. +4. Use `IMailer.SendEmail()` to render and send the email. + +### Creating a New Email + +#### Step 1: Define the ViewModel + +Create a class that inherits from `BaseMailView`: + +```csharp +using Bit.Core.Platform.Mailer; + +namespace MyApp.Emails; + +public class WelcomeEmailView : BaseMailView +{ + public required string UserName { get; init; } + public required string ActivationUrl { get; init; } +} +``` + +#### Step 2: Define the email class + +Create a class that inherits from `BaseMail`: + +```csharp +public class WelcomeEmail : BaseMail +{ + public override string Subject => "Welcome to Bitwarden"; +} +``` + +#### Step 3: Create Handlebars templates + +Create two template files as embedded resources next to your ViewModel. + +> [!IMPORTANT] +> The files must be located directly next to the `ViewClass` and match the name of the view. + +**WelcomeEmailView.html.hbs** (HTML version): + +```handlebars +

    Welcome, {{ UserName }}!

    +

    Thank you for joining Bitwarden.

    +

    + Activate your account +

    +

    © {{ CurrentYear }} Bitwarden Inc.

    +``` + +**WelcomeEmailView.text.hbs** (plain text version): + +```handlebars +Welcome, {{ UserName }}! + +Thank you for joining Bitwarden. + +Activate your account: {{ ActivationUrl }} + +� {{ CurrentYear }} Bitwarden Inc. +``` + +**Important**: Template files must be configured as embedded resources in your `.csproj`: + +```xml + + + + +``` + +#### Step 4: Send the email + +Inject `IMailer` and send the email, this may be done in a service, command or some other application layer. + +```csharp +public class SomeService +{ + private readonly IMailer _mailer; + + public SomeService(IMailer mailer) + { + _mailer = mailer; + } + + public async Task SendWelcomeEmailAsync(string email, string userName, string activationUrl) + { + var mail = new WelcomeEmail + { + ToEmails = [email], + View = new WelcomeEmailView + { + UserName = userName, + ActivationUrl = activationUrl + } + }; + + await _mailer.SendEmail(mail); + } +} +``` + +### Advanced Features + +#### Multiple Recipients + +Send to multiple recipients by providing multiple email addresses: + +```csharp +var mail = new WelcomeEmail +{ + ToEmails = ["user1@example.com", "user2@example.com"], + View = new WelcomeEmailView { /* ... */ } +}; +``` + +#### Bypass Suppression List + +For critical emails like account recovery or email OTP, you can bypass the suppression list: + +```csharp +public class PasswordResetEmail : BaseMail +{ + public override string Subject => "Reset Your Password"; + public override bool IgnoreSuppressList => true; // Use with caution +} +``` + +**Warning**: Only use `IgnoreSuppressList = true` for critical account recovery or authentication emails. + +#### Email Categories + +Optionally categorize emails for processing at the upstream email delivery service: + +```csharp +public class MarketingEmail : BaseMail +{ + public override string Subject => "Latest Updates"; + public string? Category => "marketing"; +} +``` + +### Built-in View Properties + +All ViewModels inherit from `BaseMailView`, which provides: + +- **CurrentYear** - The current UTC year (useful for copyright notices) + +```handlebars + +
    © {{ CurrentYear }} Bitwarden Inc.
    +``` + +### Template Naming Convention + +Templates must follow this naming convention: + +- HTML template: `{ViewModelFullName}.html.hbs` +- Text template: `{ViewModelFullName}.text.hbs` + +For example, if your ViewModel is `Bit.Core.Auth.Models.Mail.VerifyEmailView`, the templates must be: + +- `Bit.Core.Auth.Models.Mail.VerifyEmailView.html.hbs` +- `Bit.Core.Auth.Models.Mail.VerifyEmailView.text.hbs` + +## Dependency Injection + +Register the Mailer services in your DI container using the extension method: + +```csharp +using Bit.Core.Platform.Mailer; + +services.AddMailer(); +``` + +Or manually register the services: + +```csharp +using Microsoft.Extensions.DependencyInjection.Extensions; + +services.TryAddSingleton(); +services.TryAddSingleton(); +``` + +### Performance Notes + +- **Template caching** - `HandlebarMailRenderer` automatically caches compiled templates +- **Lazy initialization** - Handlebars is initialized only when first needed +- **Thread-safe** - The renderer is thread-safe for concurrent email rendering + +# Overriding email templates from disk + +The mail services support loading the mail template from disk. This is intended to be used by self-hosted customers who want to modify their email appearance. These overrides are not intended to be used during local development, as any changes there would not be reflected in the templates used in a normal deployment configuration. + +Any customer using this override has worked with Bitwarden support on an approved implementation and has acknowledged that they are responsible for reacting to any changes made to the templates as a part of the Bitwarden development process. This includes, but is not limited to, changes in Handlebars property names, removal of properties from the ViewModel classes, and changes in template names. **Bitwarden is not responsible for maintaining backward compatibility between releases in order to support any overridden emails.** \ No newline at end of file diff --git a/src/Core/Platform/Push/IPushNotificationService.cs b/src/Core/Platform/Push/IPushNotificationService.cs index 32a488b827..b6d7d4d416 100644 --- a/src/Core/Platform/Push/IPushNotificationService.cs +++ b/src/Core/Platform/Push/IPushNotificationService.cs @@ -167,18 +167,17 @@ public interface IPushNotificationService ExcludeCurrentContext = false, }); - Task PushLogOutAsync(Guid userId, bool excludeCurrentContextFromPush = false) - => PushAsync(new PushNotification + Task PushLogOutAsync(Guid userId, bool excludeCurrentContextFromPush = false, + PushNotificationLogOutReason? reason = null) + => PushAsync(new PushNotification { Type = PushType.LogOut, Target = NotificationTarget.User, TargetId = userId, - Payload = new UserPushNotification + Payload = new LogOutPushNotification { UserId = userId, -#pragma warning disable BWP0001 // Type or member is obsolete - Date = TimeProvider.GetUtcNow().UtcDateTime, -#pragma warning restore BWP0001 // Type or member is obsolete + Reason = reason }, ExcludeCurrentContext = excludeCurrentContextFromPush, }); diff --git a/src/Core/Platform/Push/PushType.cs b/src/Core/Platform/Push/PushType.cs index 7765c1aa66..9a601ab0d3 100644 --- a/src/Core/Platform/Push/PushType.cs +++ b/src/Core/Platform/Push/PushType.cs @@ -55,7 +55,7 @@ public enum PushType : byte [NotificationInfo("not-specified", typeof(Models.UserPushNotification))] SyncSettings = 10, - [NotificationInfo("not-specified", typeof(Models.UserPushNotification))] + [NotificationInfo("not-specified", typeof(Models.LogOutPushNotification))] LogOut = 11, [NotificationInfo("@bitwarden/team-tools-dev", typeof(Models.SyncSendPushNotification))] @@ -95,5 +95,8 @@ public enum PushType : byte OrganizationBankAccountVerified = 23, [NotificationInfo("@bitwarden/team-billing-dev", typeof(Models.ProviderBankAccountVerifiedPushNotification))] - ProviderBankAccountVerified = 24 + ProviderBankAccountVerified = 24, + + [NotificationInfo("@bitwarden/team-admin-console-dev", typeof(Models.SyncPolicyPushNotification))] + PolicyChanged = 25, } diff --git a/src/Core/Repositories/IOrganizationDomainRepository.cs b/src/Core/Repositories/IOrganizationDomainRepository.cs index d802fe65df..b993cd42fa 100644 --- a/src/Core/Repositories/IOrganizationDomainRepository.cs +++ b/src/Core/Repositories/IOrganizationDomainRepository.cs @@ -17,4 +17,5 @@ public interface IOrganizationDomainRepository : IRepository GetDomainByOrgIdAndDomainNameAsync(Guid orgId, string domainName); Task> GetExpiredOrganizationDomainsAsync(); Task DeleteExpiredAsync(int expirationPeriod); + Task HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(string domainName, Guid? excludeOrganizationId = null); } diff --git a/src/Core/Repositories/IPhishingDomainRepository.cs b/src/Core/Repositories/IPhishingDomainRepository.cs deleted file mode 100644 index 2d653b0a43..0000000000 --- a/src/Core/Repositories/IPhishingDomainRepository.cs +++ /dev/null @@ -1,8 +0,0 @@ -namespace Bit.Core.Repositories; - -public interface IPhishingDomainRepository -{ - Task> GetActivePhishingDomainsAsync(); - Task UpdatePhishingDomainsAsync(IEnumerable domains, string checksum); - Task GetCurrentChecksumAsync(); -} diff --git a/src/Core/Repositories/IUserRepository.cs b/src/Core/Repositories/IUserRepository.cs index 22effb4329..93316d78bd 100644 --- a/src/Core/Repositories/IUserRepository.cs +++ b/src/Core/Repositories/IUserRepository.cs @@ -1,4 +1,6 @@ -using Bit.Core.Entities; +using Bit.Core.Billing.Premium.Models; +using Bit.Core.Entities; +using Bit.Core.KeyManagement.Models.Data; using Bit.Core.KeyManagement.UserKey; using Bit.Core.Models.Data; @@ -23,6 +25,7 @@ public interface IUserRepository : IRepository /// Retrieves the data for the requested user IDs and includes an additional property indicating /// whether the user has premium access directly or through an organization. ///
    + [Obsolete("Use GetPremiumAccessByIdsAsync instead. This method will be removed in a future version.")] Task> GetManyWithCalculatedPremiumAsync(IEnumerable ids); /// /// Retrieves the data for the requested user ID and includes additional property indicating @@ -33,8 +36,23 @@ public interface IUserRepository : IRepository /// /// The user ID to retrieve data for. /// User data with calculated premium access; null if nothing is found + [Obsolete("Use GetPremiumAccessAsync instead. This method will be removed in a future version.")] Task GetCalculatedPremiumAsync(Guid userId); /// + /// Retrieves premium access status for multiple users. + /// For internal use - consumers should use IHasPremiumAccessQuery instead. + /// + /// The user IDs to check + /// Collection of UserPremiumAccess objects containing premium status information + Task> GetPremiumAccessByIdsAsync(IEnumerable ids); + /// + /// Retrieves premium access status for a single user. + /// For internal use - consumers should use IHasPremiumAccessQuery instead. + /// + /// The user ID to check + /// UserPremiumAccess object containing premium status information, or null if user not found + Task GetPremiumAccessAsync(Guid userId); + /// /// Sets a new user key and updates all encrypted data. /// Warning: Any user key encrypted data not included will be lost. /// @@ -44,5 +62,19 @@ public interface IUserRepository : IRepository IEnumerable updateDataActions); Task UpdateUserKeyAndEncryptedDataV2Async(User user, IEnumerable updateDataActions); + /// + /// Sets the account cryptographic state to a user in a single transaction. The provided + /// MUST be a V2 encryption state. Passing in a V1 encryption state will throw. + /// Extra actions can be passed in case other user data needs to be updated in the same transaction. + /// + Task SetV2AccountCryptographicStateAsync( + Guid userId, + UserAccountKeysData accountKeysData, + IEnumerable? updateUserDataActions = null); Task DeleteManyAsync(IEnumerable users); + + UpdateUserData SetKeyConnectorUserKey(Guid userId, string keyConnectorWrappedUserKey); } + +public delegate Task UpdateUserData(Microsoft.Data.SqlClient.SqlConnection? connection = null, + Microsoft.Data.SqlClient.SqlTransaction? transaction = null); diff --git a/src/Core/Repositories/Implementations/AzurePhishingDomainRepository.cs b/src/Core/Repositories/Implementations/AzurePhishingDomainRepository.cs deleted file mode 100644 index 2d4ea15b7e..0000000000 --- a/src/Core/Repositories/Implementations/AzurePhishingDomainRepository.cs +++ /dev/null @@ -1,126 +0,0 @@ -using System.Text.Json; -using Bit.Core.PhishingDomainFeatures; -using Microsoft.Extensions.Caching.Distributed; -using Microsoft.Extensions.Logging; - -namespace Bit.Core.Repositories.Implementations; - -public class AzurePhishingDomainRepository : IPhishingDomainRepository -{ - private readonly AzurePhishingDomainStorageService _storageService; - private readonly IDistributedCache _cache; - private readonly ILogger _logger; - private const string _domainsCacheKey = "PhishingDomains_v1"; - private const string _checksumCacheKey = "PhishingDomains_Checksum_v1"; - private static readonly DistributedCacheEntryOptions _cacheOptions = new() - { - AbsoluteExpirationRelativeToNow = TimeSpan.FromHours(24), - SlidingExpiration = TimeSpan.FromHours(1) - }; - - public AzurePhishingDomainRepository( - AzurePhishingDomainStorageService storageService, - IDistributedCache cache, - ILogger logger) - { - _storageService = storageService; - _cache = cache; - _logger = logger; - } - - public async Task> GetActivePhishingDomainsAsync() - { - try - { - var cachedDomains = await _cache.GetStringAsync(_domainsCacheKey); - if (!string.IsNullOrEmpty(cachedDomains)) - { - _logger.LogDebug("Retrieved phishing domains from cache"); - return JsonSerializer.Deserialize>(cachedDomains) ?? []; - } - } - catch (Exception ex) - { - _logger.LogWarning(ex, "Failed to retrieve phishing domains from cache"); - } - - var domains = await _storageService.GetDomainsAsync(); - - try - { - await _cache.SetStringAsync( - _domainsCacheKey, - JsonSerializer.Serialize(domains), - _cacheOptions); - _logger.LogDebug("Stored {Count} phishing domains in cache", domains.Count); - } - catch (Exception ex) - { - _logger.LogWarning(ex, "Failed to store phishing domains in cache"); - } - - return domains; - } - - public async Task GetCurrentChecksumAsync() - { - try - { - var cachedChecksum = await _cache.GetStringAsync(_checksumCacheKey); - if (!string.IsNullOrEmpty(cachedChecksum)) - { - _logger.LogDebug("Retrieved phishing domain checksum from cache"); - return cachedChecksum; - } - } - catch (Exception ex) - { - _logger.LogWarning(ex, "Failed to retrieve phishing domain checksum from cache"); - } - - var checksum = await _storageService.GetChecksumAsync(); - - try - { - if (!string.IsNullOrEmpty(checksum)) - { - await _cache.SetStringAsync( - _checksumCacheKey, - checksum, - _cacheOptions); - _logger.LogDebug("Stored phishing domain checksum in cache"); - } - } - catch (Exception ex) - { - _logger.LogWarning(ex, "Failed to store phishing domain checksum in cache"); - } - - return checksum; - } - - public async Task UpdatePhishingDomainsAsync(IEnumerable domains, string checksum) - { - var domainsList = domains.ToList(); - await _storageService.UpdateDomainsAsync(domainsList, checksum); - - try - { - await _cache.SetStringAsync( - _domainsCacheKey, - JsonSerializer.Serialize(domainsList), - _cacheOptions); - - await _cache.SetStringAsync( - _checksumCacheKey, - checksum, - _cacheOptions); - - _logger.LogDebug("Updated phishing domains cache after update operation"); - } - catch (Exception ex) - { - _logger.LogWarning(ex, "Failed to update phishing domains in cache"); - } - } -} diff --git a/src/Core/Resources/SharedResources.en.resx b/src/Core/Resources/SharedResources.en.resx index 28ae70ca96..ca150f2106 100644 --- a/src/Core/Resources/SharedResources.en.resx +++ b/src/Core/Resources/SharedResources.en.resx @@ -508,9 +508,15 @@ Supplied userId and token did not match. + + User should have been defined by this point. + Could not find organization for '{0}' + + Could not find organization user for user '{0}' organization '{1}' + No seats available for organization, '{0}' diff --git a/src/Core/SecretsManager/Entities/SecretVersion.cs b/src/Core/SecretsManager/Entities/SecretVersion.cs new file mode 100644 index 0000000000..cee447bd2a --- /dev/null +++ b/src/Core/SecretsManager/Entities/SecretVersion.cs @@ -0,0 +1,28 @@ +#nullable enable +using Bit.Core.Entities; +using Bit.Core.Utilities; + +namespace Bit.Core.SecretsManager.Entities; + +public class SecretVersion : ITableObject +{ + public Guid Id { get; set; } + + public Guid SecretId { get; set; } + + public string Value { get; set; } = string.Empty; + + public DateTime VersionDate { get; set; } + + public Guid? EditorServiceAccountId { get; set; } + + public Guid? EditorOrganizationUserId { get; set; } + + public void SetNewId() + { + if (Id == default(Guid)) + { + Id = CoreHelpers.GenerateComb(); + } + } +} diff --git a/src/Core/SecretsManager/Repositories/ISecretVersionRepository.cs b/src/Core/SecretsManager/Repositories/ISecretVersionRepository.cs new file mode 100644 index 0000000000..b6dd1d778d --- /dev/null +++ b/src/Core/SecretsManager/Repositories/ISecretVersionRepository.cs @@ -0,0 +1,12 @@ +using Bit.Core.SecretsManager.Entities; + +namespace Bit.Core.SecretsManager.Repositories; + +public interface ISecretVersionRepository +{ + Task GetByIdAsync(Guid id); + Task> GetManyBySecretIdAsync(Guid secretId); + Task> GetManyByIdsAsync(IEnumerable ids); + Task CreateAsync(SecretVersion secretVersion); + Task DeleteManyByIdAsync(IEnumerable ids); +} diff --git a/src/Core/SecretsManager/Repositories/Noop/NoopSecretVersionRepository.cs b/src/Core/SecretsManager/Repositories/Noop/NoopSecretVersionRepository.cs new file mode 100644 index 0000000000..caa5d96a7c --- /dev/null +++ b/src/Core/SecretsManager/Repositories/Noop/NoopSecretVersionRepository.cs @@ -0,0 +1,31 @@ +using Bit.Core.SecretsManager.Entities; + +namespace Bit.Core.SecretsManager.Repositories.Noop; + +public class NoopSecretVersionRepository : ISecretVersionRepository +{ + public Task GetByIdAsync(Guid id) + { + return Task.FromResult(null as SecretVersion); + } + + public Task> GetManyBySecretIdAsync(Guid secretId) + { + return Task.FromResult(Enumerable.Empty()); + } + + public Task CreateAsync(SecretVersion secretVersion) + { + return Task.FromResult(secretVersion); + } + + public Task DeleteManyByIdAsync(IEnumerable ids) + { + return Task.CompletedTask; + } + + public Task> GetManyByIdsAsync(IEnumerable ids) + { + return Task.FromResult(Enumerable.Empty()); + } +} diff --git a/src/Core/Services/IStripeAdapter.cs b/src/Core/Services/IStripeAdapter.cs deleted file mode 100644 index 8a41263956..0000000000 --- a/src/Core/Services/IStripeAdapter.cs +++ /dev/null @@ -1,65 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.Models.BitStripe; -using Stripe; - -namespace Bit.Core.Services; - -public interface IStripeAdapter -{ - Task CustomerCreateAsync(Stripe.CustomerCreateOptions customerCreateOptions); - Task CustomerGetAsync(string id, Stripe.CustomerGetOptions options = null); - Task CustomerUpdateAsync(string id, Stripe.CustomerUpdateOptions options = null); - Task CustomerDeleteAsync(string id); - Task> CustomerListPaymentMethods(string id, CustomerListPaymentMethodsOptions options = null); - Task CustomerBalanceTransactionCreate(string customerId, - CustomerBalanceTransactionCreateOptions options); - Task SubscriptionCreateAsync(Stripe.SubscriptionCreateOptions subscriptionCreateOptions); - Task SubscriptionGetAsync(string id, Stripe.SubscriptionGetOptions options = null); - - /// - /// Retrieves a subscription object for a provider. - /// - /// The subscription ID. - /// The provider ID. - /// Additional options. - /// The subscription object. - /// Thrown when the subscription doesn't belong to the provider. - Task ProviderSubscriptionGetAsync(string id, Guid providerId, Stripe.SubscriptionGetOptions options = null); - - Task> SubscriptionListAsync(StripeSubscriptionListOptions subscriptionSearchOptions); - Task SubscriptionUpdateAsync(string id, Stripe.SubscriptionUpdateOptions options = null); - Task SubscriptionCancelAsync(string Id, Stripe.SubscriptionCancelOptions options = null); - Task InvoiceUpcomingAsync(Stripe.UpcomingInvoiceOptions options); - Task InvoiceGetAsync(string id, Stripe.InvoiceGetOptions options); - Task> InvoiceListAsync(StripeInvoiceListOptions options); - Task InvoiceCreatePreviewAsync(InvoiceCreatePreviewOptions options); - Task> InvoiceSearchAsync(InvoiceSearchOptions options); - Task InvoiceUpdateAsync(string id, Stripe.InvoiceUpdateOptions options); - Task InvoiceFinalizeInvoiceAsync(string id, Stripe.InvoiceFinalizeOptions options); - Task InvoiceSendInvoiceAsync(string id, Stripe.InvoiceSendOptions options); - Task InvoicePayAsync(string id, Stripe.InvoicePayOptions options = null); - Task InvoiceDeleteAsync(string id, Stripe.InvoiceDeleteOptions options = null); - Task InvoiceVoidInvoiceAsync(string id, Stripe.InvoiceVoidOptions options = null); - IEnumerable PaymentMethodListAutoPaging(Stripe.PaymentMethodListOptions options); - IAsyncEnumerable PaymentMethodListAutoPagingAsync(Stripe.PaymentMethodListOptions options); - Task PaymentMethodAttachAsync(string id, Stripe.PaymentMethodAttachOptions options = null); - Task PaymentMethodDetachAsync(string id, Stripe.PaymentMethodDetachOptions options = null); - Task TaxIdCreateAsync(string id, Stripe.TaxIdCreateOptions options); - Task TaxIdDeleteAsync(string customerId, string taxIdId, Stripe.TaxIdDeleteOptions options = null); - Task> TaxRegistrationsListAsync(Stripe.Tax.RegistrationListOptions options = null); - Task> ChargeListAsync(Stripe.ChargeListOptions options); - Task RefundCreateAsync(Stripe.RefundCreateOptions options); - Task CardDeleteAsync(string customerId, string cardId, Stripe.CardDeleteOptions options = null); - Task BankAccountCreateAsync(string customerId, Stripe.BankAccountCreateOptions options = null); - Task BankAccountDeleteAsync(string customerId, string bankAccount, Stripe.BankAccountDeleteOptions options = null); - Task> PriceListAsync(Stripe.PriceListOptions options = null); - Task SetupIntentCreate(SetupIntentCreateOptions options); - Task> SetupIntentList(SetupIntentListOptions options); - Task SetupIntentCancel(string id, SetupIntentCancelOptions options = null); - Task SetupIntentGet(string id, SetupIntentGetOptions options = null); - Task SetupIntentVerifyMicroDeposit(string id, SetupIntentVerifyMicrodepositsOptions options); - Task> TestClockListAsync(); - Task PriceGetAsync(string id, PriceGetOptions options = null); -} diff --git a/src/Core/Services/IStripeSyncService.cs b/src/Core/Services/IStripeSyncService.cs deleted file mode 100644 index 655998805e..0000000000 --- a/src/Core/Services/IStripeSyncService.cs +++ /dev/null @@ -1,6 +0,0 @@ -namespace Bit.Core.Services; - -public interface IStripeSyncService -{ - Task UpdateCustomerEmailAddress(string gatewayCustomerId, string emailAddress); -} diff --git a/src/Core/Services/IUserService.cs b/src/Core/Services/IUserService.cs index 412f9db36e..a531883db1 100644 --- a/src/Core/Services/IUserService.cs +++ b/src/Core/Services/IUserService.cs @@ -4,7 +4,6 @@ using System.Security.Claims; using Bit.Core.AdminConsole.Entities; using Bit.Core.Auth.Enums; -using Bit.Core.Auth.Models; using Bit.Core.Billing.Models.Business; using Bit.Core.Entities; using Bit.Core.Enums; @@ -34,6 +33,8 @@ public interface IUserService Task ChangeEmailAsync(User user, string masterPassword, string newEmail, string newMasterPassword, string token, string key); Task ChangePasswordAsync(User user, string masterPassword, string newMasterPassword, string passwordHint, string key); + // TODO removed with https://bitwarden.atlassian.net/browse/PM-27328 + [Obsolete("Use ISetKeyConnectorKeyCommand instead. This method will be removed in a future version.")] Task SetKeyConnectorKeyAsync(User user, string key, string orgIdentifier); Task ConvertToKeyConnectorAsync(User user); Task AdminResetPasswordAsync(OrganizationUserType type, Guid orgId, Guid id, string newMasterPassword, string key); @@ -60,11 +61,23 @@ public interface IUserService Task CheckPasswordAsync(User user, string password); /// /// Checks if the user has access to premium features, either through a personal subscription or through an organization. + /// + /// This is the preferred way to definitively know if a user has access to premium features when you already have the User object. /// /// user being acted on /// true if they can access premium; false otherwise. - Task CanAccessPremium(ITwoFactorProvidersUser user); - Task HasPremiumFromOrganization(ITwoFactorProvidersUser user); + Task CanAccessPremium(User user); + + /// + /// Checks if the user has inherited access to premium features through an organization. + /// + /// This primarily serves as a means to communicate to the client when a user has inherited their premium status + /// through an organization. Feature gating logic probably should not be behind this check. + /// + /// user being acted on + /// true if they can access premium because of organization membership; false otherwise. + [Obsolete("Use IHasPremiumAccessQuery.HasPremiumFromOrganizationAsync instead. This method will be removed in a future version.")] + Task HasPremiumFromOrganization(User user); Task GenerateSignInTokenAsync(User user, string purpose); Task UpdatePasswordHash(User user, string newPassword, diff --git a/src/Core/Services/Implementations/StripeAdapter.cs b/src/Core/Services/Implementations/StripeAdapter.cs deleted file mode 100644 index 4863baf73e..0000000000 --- a/src/Core/Services/Implementations/StripeAdapter.cs +++ /dev/null @@ -1,302 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.Models.BitStripe; -using Stripe; -using Stripe.Tax; - -namespace Bit.Core.Services; - -public class StripeAdapter : IStripeAdapter -{ - private readonly Stripe.CustomerService _customerService; - private readonly Stripe.SubscriptionService _subscriptionService; - private readonly Stripe.InvoiceService _invoiceService; - private readonly Stripe.PaymentMethodService _paymentMethodService; - private readonly Stripe.TaxIdService _taxIdService; - private readonly Stripe.ChargeService _chargeService; - private readonly Stripe.RefundService _refundService; - private readonly Stripe.CardService _cardService; - private readonly Stripe.BankAccountService _bankAccountService; - private readonly Stripe.PlanService _planService; - private readonly Stripe.PriceService _priceService; - private readonly Stripe.SetupIntentService _setupIntentService; - private readonly Stripe.TestHelpers.TestClockService _testClockService; - private readonly CustomerBalanceTransactionService _customerBalanceTransactionService; - private readonly Stripe.Tax.RegistrationService _taxRegistrationService; - private readonly CalculationService _calculationService; - - public StripeAdapter() - { - _customerService = new Stripe.CustomerService(); - _subscriptionService = new Stripe.SubscriptionService(); - _invoiceService = new Stripe.InvoiceService(); - _paymentMethodService = new Stripe.PaymentMethodService(); - _taxIdService = new Stripe.TaxIdService(); - _chargeService = new Stripe.ChargeService(); - _refundService = new Stripe.RefundService(); - _cardService = new Stripe.CardService(); - _bankAccountService = new Stripe.BankAccountService(); - _priceService = new Stripe.PriceService(); - _planService = new Stripe.PlanService(); - _setupIntentService = new SetupIntentService(); - _testClockService = new Stripe.TestHelpers.TestClockService(); - _customerBalanceTransactionService = new CustomerBalanceTransactionService(); - _taxRegistrationService = new Stripe.Tax.RegistrationService(); - _calculationService = new CalculationService(); - } - - public Task CustomerCreateAsync(Stripe.CustomerCreateOptions options) - { - return _customerService.CreateAsync(options); - } - - public Task CustomerGetAsync(string id, Stripe.CustomerGetOptions options = null) - { - return _customerService.GetAsync(id, options); - } - - public Task CustomerUpdateAsync(string id, Stripe.CustomerUpdateOptions options = null) - { - return _customerService.UpdateAsync(id, options); - } - - public Task CustomerDeleteAsync(string id) - { - return _customerService.DeleteAsync(id); - } - - public async Task> CustomerListPaymentMethods(string id, - CustomerListPaymentMethodsOptions options = null) - { - var paymentMethods = await _customerService.ListPaymentMethodsAsync(id, options); - return paymentMethods.Data; - } - - public async Task CustomerBalanceTransactionCreate(string customerId, - CustomerBalanceTransactionCreateOptions options) - => await _customerBalanceTransactionService.CreateAsync(customerId, options); - - public Task SubscriptionCreateAsync(Stripe.SubscriptionCreateOptions options) - { - return _subscriptionService.CreateAsync(options); - } - - public Task SubscriptionGetAsync(string id, Stripe.SubscriptionGetOptions options = null) - { - return _subscriptionService.GetAsync(id, options); - } - - public async Task ProviderSubscriptionGetAsync( - string id, - Guid providerId, - SubscriptionGetOptions options = null) - { - var subscription = await _subscriptionService.GetAsync(id, options); - if (subscription.Metadata.TryGetValue("providerId", out var value) && value == providerId.ToString()) - { - return subscription; - } - - throw new InvalidOperationException("Subscription does not belong to the provider."); - } - - public Task SubscriptionUpdateAsync(string id, - Stripe.SubscriptionUpdateOptions options = null) - { - return _subscriptionService.UpdateAsync(id, options); - } - - public Task SubscriptionCancelAsync(string Id, Stripe.SubscriptionCancelOptions options = null) - { - return _subscriptionService.CancelAsync(Id, options); - } - - public Task InvoiceUpcomingAsync(Stripe.UpcomingInvoiceOptions options) - { - return _invoiceService.UpcomingAsync(options); - } - - public Task InvoiceGetAsync(string id, Stripe.InvoiceGetOptions options) - { - return _invoiceService.GetAsync(id, options); - } - - public async Task> InvoiceListAsync(StripeInvoiceListOptions options) - { - if (!options.SelectAll) - { - return (await _invoiceService.ListAsync(options.ToInvoiceListOptions())).Data; - } - - options.Limit = 100; - - var invoices = new List(); - - await foreach (var invoice in _invoiceService.ListAutoPagingAsync(options.ToInvoiceListOptions())) - { - invoices.Add(invoice); - } - - return invoices; - } - - public Task InvoiceCreatePreviewAsync(InvoiceCreatePreviewOptions options) - { - return _invoiceService.CreatePreviewAsync(options); - } - - public async Task> InvoiceSearchAsync(InvoiceSearchOptions options) - => (await _invoiceService.SearchAsync(options)).Data; - - public Task InvoiceUpdateAsync(string id, Stripe.InvoiceUpdateOptions options) - { - return _invoiceService.UpdateAsync(id, options); - } - - public Task InvoiceFinalizeInvoiceAsync(string id, Stripe.InvoiceFinalizeOptions options) - { - return _invoiceService.FinalizeInvoiceAsync(id, options); - } - - public Task InvoiceSendInvoiceAsync(string id, Stripe.InvoiceSendOptions options) - { - return _invoiceService.SendInvoiceAsync(id, options); - } - - public Task InvoicePayAsync(string id, Stripe.InvoicePayOptions options = null) - { - return _invoiceService.PayAsync(id, options); - } - - public Task InvoiceDeleteAsync(string id, Stripe.InvoiceDeleteOptions options = null) - { - return _invoiceService.DeleteAsync(id, options); - } - - public Task InvoiceVoidInvoiceAsync(string id, Stripe.InvoiceVoidOptions options = null) - { - return _invoiceService.VoidInvoiceAsync(id, options); - } - - public IEnumerable PaymentMethodListAutoPaging(Stripe.PaymentMethodListOptions options) - { - return _paymentMethodService.ListAutoPaging(options); - } - - public IAsyncEnumerable PaymentMethodListAutoPagingAsync(Stripe.PaymentMethodListOptions options) - => _paymentMethodService.ListAutoPagingAsync(options); - - public Task PaymentMethodAttachAsync(string id, Stripe.PaymentMethodAttachOptions options = null) - { - return _paymentMethodService.AttachAsync(id, options); - } - - public Task PaymentMethodDetachAsync(string id, Stripe.PaymentMethodDetachOptions options = null) - { - return _paymentMethodService.DetachAsync(id, options); - } - - public Task PlanGetAsync(string id, Stripe.PlanGetOptions options = null) - { - return _planService.GetAsync(id, options); - } - - public Task TaxIdCreateAsync(string id, Stripe.TaxIdCreateOptions options) - { - return _taxIdService.CreateAsync(id, options); - } - - public Task TaxIdDeleteAsync(string customerId, string taxIdId, - Stripe.TaxIdDeleteOptions options = null) - { - return _taxIdService.DeleteAsync(customerId, taxIdId); - } - - public Task> TaxRegistrationsListAsync(Stripe.Tax.RegistrationListOptions options = null) - { - return _taxRegistrationService.ListAsync(options); - } - - public Task> ChargeListAsync(Stripe.ChargeListOptions options) - { - return _chargeService.ListAsync(options); - } - - public Task RefundCreateAsync(Stripe.RefundCreateOptions options) - { - return _refundService.CreateAsync(options); - } - - public Task CardDeleteAsync(string customerId, string cardId, Stripe.CardDeleteOptions options = null) - { - return _cardService.DeleteAsync(customerId, cardId, options); - } - - public Task BankAccountCreateAsync(string customerId, Stripe.BankAccountCreateOptions options = null) - { - return _bankAccountService.CreateAsync(customerId, options); - } - - public Task BankAccountDeleteAsync(string customerId, string bankAccount, Stripe.BankAccountDeleteOptions options = null) - { - return _bankAccountService.DeleteAsync(customerId, bankAccount, options); - } - - public async Task> SubscriptionListAsync(StripeSubscriptionListOptions options) - { - if (!options.SelectAll) - { - return (await _subscriptionService.ListAsync(options.ToStripeApiOptions())).Data; - } - - options.Limit = 100; - var items = new List(); - await foreach (var i in _subscriptionService.ListAutoPagingAsync(options.ToStripeApiOptions())) - { - items.Add(i); - } - return items; - } - - public async Task> PriceListAsync(Stripe.PriceListOptions options = null) - { - return await _priceService.ListAsync(options); - } - - public Task SetupIntentCreate(SetupIntentCreateOptions options) - => _setupIntentService.CreateAsync(options); - - public async Task> SetupIntentList(SetupIntentListOptions options) - { - var setupIntents = await _setupIntentService.ListAsync(options); - - return setupIntents.Data; - } - - public Task SetupIntentCancel(string id, SetupIntentCancelOptions options = null) - => _setupIntentService.CancelAsync(id, options); - - public Task SetupIntentGet(string id, SetupIntentGetOptions options = null) - => _setupIntentService.GetAsync(id, options); - - public Task SetupIntentVerifyMicroDeposit(string id, SetupIntentVerifyMicrodepositsOptions options) - => _setupIntentService.VerifyMicrodepositsAsync(id, options); - - public async Task> TestClockListAsync() - { - var items = new List(); - var options = new Stripe.TestHelpers.TestClockListOptions() - { - Limit = 100 - }; - await foreach (var i in _testClockService.ListAutoPagingAsync(options)) - { - items.Add(i); - } - return items; - } - - public Task PriceGetAsync(string id, PriceGetOptions options = null) - => _priceService.GetAsync(id, options); -} diff --git a/src/Core/Services/Implementations/UserService.cs b/src/Core/Services/Implementations/UserService.cs index a36b9e37cc..498721238b 100644 --- a/src/Core/Services/Implementations/UserService.cs +++ b/src/Core/Services/Implementations/UserService.cs @@ -14,10 +14,11 @@ using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; -using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Business; using Bit.Core.Billing.Models.Sales; +using Bit.Core.Billing.Premium.Queries; +using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Billing.Tax.Models; using Bit.Core.Context; @@ -57,7 +58,7 @@ public class UserService : UserManager, IUserService private readonly ILicensingService _licenseService; private readonly IEventService _eventService; private readonly IApplicationCacheService _applicationCacheService; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly IPolicyRepository _policyRepository; private readonly IPolicyService _policyService; private readonly IFido2 _fido2; @@ -72,6 +73,8 @@ public class UserService : UserManager, IUserService private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; private readonly IDistributedCache _distributedCache; private readonly IPolicyRequirementQuery _policyRequirementQuery; + private readonly IPricingClient _pricingClient; + private readonly IHasPremiumAccessQuery _hasPremiumAccessQuery; public UserService( IUserRepository userRepository, @@ -92,7 +95,7 @@ public class UserService : UserManager, IUserService ILicensingService licenseService, IEventService eventService, IApplicationCacheService applicationCacheService, - IPaymentService paymentService, + IStripePaymentService paymentService, IPolicyRepository policyRepository, IPolicyService policyService, IFido2 fido2, @@ -106,7 +109,9 @@ public class UserService : UserManager, IUserService IRevokeNonCompliantOrganizationUserCommand revokeNonCompliantOrganizationUserCommand, ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, IDistributedCache distributedCache, - IPolicyRequirementQuery policyRequirementQuery) + IPolicyRequirementQuery policyRequirementQuery, + IPricingClient pricingClient, + IHasPremiumAccessQuery hasPremiumAccessQuery) : base( store, optionsAccessor, @@ -146,6 +151,8 @@ public class UserService : UserManager, IUserService _twoFactorIsEnabledQuery = twoFactorIsEnabledQuery; _distributedCache = distributedCache; _policyRequirementQuery = policyRequirementQuery; + _pricingClient = pricingClient; + _hasPremiumAccessQuery = hasPremiumAccessQuery; } public Guid? GetProperUserId(ClaimsPrincipal principal) @@ -337,6 +344,12 @@ public class UserService : UserManager, IUserService await _mailService.SendMasterPasswordHintEmailAsync(email, user.MasterPasswordHint); } + /// + /// Initiates WebAuthn 2FA credential registration and generates a challenge for adding a new security key. + /// + /// The current user. + /// + /// Maximum allowed number of credentials already registered. public async Task StartWebAuthnRegistrationAsync(User user) { var providers = user.GetTwoFactorProviders(); @@ -357,6 +370,17 @@ public class UserService : UserManager, IUserService provider.MetaData = new Dictionary(); } + // Boundary validation to provide a better UX. There is also second-level enforcement at persistence time. + var maximumAllowedCredentialCount = await _hasPremiumAccessQuery.HasPremiumAccessAsync(user.Id) + ? _globalSettings.WebAuthn.PremiumMaximumAllowedCredentials + : _globalSettings.WebAuthn.NonPremiumMaximumAllowedCredentials; + // Count only saved credentials ("Key{id}") toward the limit. + if (provider.MetaData.Count(k => k.Key.StartsWith("Key")) >= + maximumAllowedCredentialCount) + { + throw new BadRequestException("Maximum allowed WebAuthn credential count exceeded."); + } + var fidoUser = new Fido2User { DisplayName = user.Name, @@ -395,6 +419,17 @@ public class UserService : UserManager, IUserService return false; } + // Persistence-time validation for comprehensive enforcement. There is also boundary validation for best-possible UX. + var maximumAllowedCredentialCount = await _hasPremiumAccessQuery.HasPremiumAccessAsync(user.Id) + ? _globalSettings.WebAuthn.PremiumMaximumAllowedCredentials + : _globalSettings.WebAuthn.NonPremiumMaximumAllowedCredentials; + // Count only saved credentials ("Key{id}") toward the limit. + if (provider.MetaData.Count(k => k.Key.StartsWith("Key")) >= + maximumAllowedCredentialCount) + { + throw new BadRequestException("Maximum allowed WebAuthn credential count exceeded."); + } + var options = CredentialCreateOptions.FromJson((string)pendingValue); // Callback to ensure credential ID is unique. Always return true since we don't care if another @@ -531,7 +566,7 @@ public class UserService : UserManager, IUserService try { - await _stripeSyncService.UpdateCustomerEmailAddress(user.GatewayCustomerId, + await _stripeSyncService.UpdateCustomerEmailAddressAsync(user.GatewayCustomerId, user.BillingEmailAddress()); } catch (Exception ex) @@ -614,6 +649,7 @@ public class UserService : UserManager, IUserService return IdentityResult.Failed(_identityErrorDescriber.PasswordMismatch()); } + // TODO removed with https://bitwarden.atlassian.net/browse/PM-27328 public async Task SetKeyConnectorKeyAsync(User user, string key, string orgIdentifier) { var identityResult = CheckCanUseKeyConnector(user); @@ -864,7 +900,7 @@ public class UserService : UserManager, IUserService } string paymentIntentClientSecret = null; - IPaymentService paymentService = null; + IStripePaymentService paymentService = null; if (_globalSettings.SelfHosted) { if (license == null || !_licenseService.VerifyLicense(license)) @@ -901,7 +937,6 @@ public class UserService : UserManager, IUserService } else { - user.MaxStorageGb = (short)(1 + additionalStorageGb); user.LicenseKey = CoreHelpers.SecureRandomString(20); } @@ -972,8 +1007,10 @@ public class UserService : UserManager, IUserService throw new BadRequestException("Not a premium user."); } - var secret = await BillingHelpers.AdjustStorageAsync(_paymentService, user, storageAdjustmentGb, - StripeConstants.Prices.StoragePlanPersonal); + var premiumPlan = await _pricingClient.GetAvailablePremiumPlan(); + + var baseStorageGb = (short)premiumPlan.Storage.Provided; + var secret = await BillingHelpers.AdjustStorageAsync(_paymentService, user, storageAdjustmentGb, premiumPlan.Storage.StripePriceId, baseStorageGb); await SaveUserAsync(user); return secret; } @@ -1100,7 +1137,7 @@ public class UserService : UserManager, IUserService return success; } - public async Task CanAccessPremium(ITwoFactorProvidersUser user) + public async Task CanAccessPremium(User user) { var userId = user.GetUserId(); if (!userId.HasValue) @@ -1108,10 +1145,15 @@ public class UserService : UserManager, IUserService return false; } - return user.GetPremium() || await this.HasPremiumFromOrganization(user); + if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)) + { + return user.Premium || await _hasPremiumAccessQuery.HasPremiumFromOrganizationAsync(userId.Value); + } + + return user.Premium || await HasPremiumFromOrganization(user); } - public async Task HasPremiumFromOrganization(ITwoFactorProvidersUser user) + public async Task HasPremiumFromOrganization(User user) { var userId = user.GetUserId(); if (!userId.HasValue) @@ -1119,6 +1161,11 @@ public class UserService : UserManager, IUserService return false; } + if (_featureService.IsEnabled(FeatureFlagKeys.PremiumAccessQuery)) + { + return await _hasPremiumAccessQuery.HasPremiumFromOrganizationAsync(userId.Value); + } + // orgUsers in the Invited status are not associated with a userId yet, so this will get // orgUsers in Accepted and Confirmed states only var orgUsers = await _organizationUserRepository.GetManyByUserAsync(userId.Value); @@ -1134,6 +1181,7 @@ public class UserService : UserManager, IUserService orgAbility.UsersGetPremium && orgAbility.Enabled); } + public async Task GenerateSignInTokenAsync(User user, string purpose) { var token = await GenerateUserTokenAsync(user, Options.Tokens.PasswordResetTokenProvider, diff --git a/src/Core/Settings/GlobalSettings.cs b/src/Core/Settings/GlobalSettings.cs index 250daf0007..60a1fda19f 100644 --- a/src/Core/Settings/GlobalSettings.cs +++ b/src/Core/Settings/GlobalSettings.cs @@ -2,14 +2,12 @@ #nullable disable using Bit.Core.Auth.Settings; -using Bit.Core.Settings.LoggingSettings; namespace Bit.Core.Settings; public class GlobalSettings : IGlobalSettings { private string _mailTemplateDirectory; - private string _logDirectory; private string _licenseDirectory; public GlobalSettings() @@ -21,18 +19,10 @@ public class GlobalSettings : IGlobalSettings } public bool SelfHosted { get; set; } - public bool UnifiedDeployment { get; set; } + public bool LiteDeployment { get; set; } public virtual string KnownProxies { get; set; } public virtual string SiteName { get; set; } public virtual string ProjectName { get; set; } - public virtual string LogDirectory - { - get => BuildDirectory(_logDirectory, "/logs"); - set => _logDirectory = value; - } - public virtual bool LogDirectoryByProject { get; set; } = true; - public virtual long? LogRollBySizeLimit { get; set; } - public virtual bool EnableDevLogging { get; set; } = false; public virtual string LicenseDirectory { get => BuildDirectory(_licenseDirectory, "/core/licenses"); @@ -62,22 +52,21 @@ public class GlobalSettings : IGlobalSettings public virtual SqlSettings MySql { get; set; } = new SqlSettings(); public virtual SqlSettings Sqlite { get; set; } = new SqlSettings() { ConnectionString = "Data Source=:memory:" }; public virtual SlackSettings Slack { get; set; } = new SlackSettings(); + public virtual TeamsSettings Teams { get; set; } = new TeamsSettings(); public virtual EventLoggingSettings EventLogging { get; set; } = new EventLoggingSettings(); public virtual MailSettings Mail { get; set; } = new MailSettings(); public virtual IConnectionStringSettings Storage { get; set; } = new ConnectionStringSettings(); - public virtual ConnectionStringSettings Events { get; set; } = new ConnectionStringSettings(); + public virtual AzureQueueEventSettings Events { get; set; } = new AzureQueueEventSettings(); public virtual DistributedCacheSettings DistributedCache { get; set; } = new DistributedCacheSettings(); public virtual NotificationsSettings Notifications { get; set; } = new NotificationsSettings(); public virtual IFileStorageSettings Attachment { get; set; } public virtual FileStorageSettings Send { get; set; } public virtual IdentityServerSettings IdentityServer { get; set; } = new IdentityServerSettings(); public virtual DataProtectionSettings DataProtection { get; set; } - public virtual SentrySettings Sentry { get; set; } = new SentrySettings(); - public virtual SyslogSettings Syslog { get; set; } = new SyslogSettings(); - public virtual ILogLevelSettings MinLogLevel { get; set; } = new LogLevelSettings(); public virtual NotificationHubPoolSettings NotificationHubPool { get; set; } = new(); public virtual YubicoSettings Yubico { get; set; } = new YubicoSettings(); public virtual DuoSettings Duo { get; set; } = new DuoSettings(); + public virtual WebAuthnSettings WebAuthn { get; set; } = new WebAuthnSettings(); public virtual BraintreeSettings Braintree { get; set; } = new BraintreeSettings(); public virtual ImportCiphersLimitationSettings ImportCiphersLimitation { get; set; } = new ImportCiphersLimitationSettings(); public virtual BitPaySettings BitPay { get; set; } = new BitPaySettings(); @@ -93,7 +82,6 @@ public class GlobalSettings : IGlobalSettings public virtual ILaunchDarklySettings LaunchDarkly { get; set; } = new LaunchDarklySettings(); public virtual string DevelopmentDirectory { get; set; } public virtual IWebPushSettings WebPush { get; set; } = new WebPushSettings(); - public virtual IPhishingDomainSettings PhishingDomain { get; set; } = new PhishingDomainSettings(); public virtual int SendAccessTokenLifetimeInMinutes { get; set; } = 5; public virtual bool EnableEmailVerification { get; set; } @@ -295,6 +283,15 @@ public class GlobalSettings : IGlobalSettings public virtual string Scopes { get; set; } } + public class TeamsSettings + { + public virtual string LoginBaseUrl { get; set; } = "https://login.microsoftonline.com"; + public virtual string GraphBaseUrl { get; set; } = "https://graph.microsoft.com/v1.0"; + public virtual string ClientId { get; set; } + public virtual string ClientSecret { get; set; } + public virtual string Scopes { get; set; } + } + public class EventLoggingSettings { public AzureServiceBusSettings AzureServiceBus { get; set; } = new AzureServiceBusSettings(); @@ -320,6 +317,8 @@ public class GlobalSettings : IGlobalSettings public virtual string HecIntegrationSubscriptionName { get; set; } = "integration-hec-subscription"; public virtual string DatadogEventSubscriptionName { get; set; } = "events-datadog-subscription"; public virtual string DatadogIntegrationSubscriptionName { get; set; } = "integration-datadog-subscription"; + public virtual string TeamsEventSubscriptionName { get; set; } = "events-teams-subscription"; + public virtual string TeamsIntegrationSubscriptionName { get; set; } = "integration-teams-subscription"; public string ConnectionString { @@ -364,6 +363,9 @@ public class GlobalSettings : IGlobalSettings public virtual string DatadogEventsQueueName { get; set; } = "events-datadog-queue"; public virtual string DatadogIntegrationQueueName { get; set; } = "integration-datadog-queue"; public virtual string DatadogIntegrationRetryQueueName { get; set; } = "integration-datadog-retry-queue"; + public virtual string TeamsEventsQueueName { get; set; } = "events-teams-queue"; + public virtual string TeamsIntegrationQueueName { get; set; } = "integration-teams-queue"; + public virtual string TeamsIntegrationRetryQueueName { get; set; } = "integration-teams-retry-queue"; public string HostName { @@ -393,6 +395,24 @@ public class GlobalSettings : IGlobalSettings } } + public class AzureQueueEventSettings : IConnectionStringSettings + { + private string _connectionString; + private string _queueName; + + public string ConnectionString + { + get => _connectionString; + set => _connectionString = value?.Trim('"'); + } + + public string QueueName + { + get => _queueName; + set => _queueName = value?.Trim('"'); + } + } + public class ConnectionStringSettings : IConnectionStringSettings { private string _connectionString; @@ -481,7 +501,7 @@ public class GlobalSettings : IGlobalSettings public string CertificatePassword { get; set; } public string RedisConnectionString { get; set; } public string CosmosConnectionString { get; set; } - public string LicenseKey { get; set; } = "eyJhbGciOiJQUzI1NiIsImtpZCI6IklkZW50aXR5U2VydmVyTGljZW5zZWtleS83Y2VhZGJiNzgxMzA0NjllODgwNjg5MTAyNTQxNGYxNiIsInR5cCI6ImxpY2Vuc2Urand0In0.eyJpc3MiOiJodHRwczovL2R1ZW5kZXNvZnR3YXJlLmNvbSIsImF1ZCI6IklkZW50aXR5U2VydmVyIiwiaWF0IjoxNzM0NTY2NDAwLCJleHAiOjE3NjQ5NzkyMDAsImNvbXBhbnlfbmFtZSI6IkJpdHdhcmRlbiBJbmMuIiwiY29udGFjdF9pbmZvIjoiY29udGFjdEBkdWVuZGVzb2Z0d2FyZS5jb20iLCJlZGl0aW9uIjoiU3RhcnRlciIsImlkIjoiNjg3OCIsImZlYXR1cmUiOlsiaXN2IiwidW5saW1pdGVkX2NsaWVudHMiXSwicHJvZHVjdCI6IkJpdHdhcmRlbiJ9.TYc88W_t2t0F2AJV3rdyKwGyQKrKFriSAzm1tWFNHNR9QizfC-8bliGdT4Wgeie-ynCXs9wWaF-sKC5emg--qS7oe2iIt67Qd88WS53AwgTvAddQRA4NhGB1R7VM8GAikLieSos-DzzwLYRgjZdmcsprItYGSJuY73r-7-F97ta915majBytVxGF966tT9zF1aYk0bA8FS6DcDYkr5f7Nsy8daS_uIUAgNa_agKXtmQPqKujqtUb6rgWEpSp4OcQcG-8Dpd5jHqoIjouGvY-5LTgk5WmLxi_m-1QISjxUJrUm-UGao3_VwV5KFGqYrz8csdTl-HS40ihWcsWnrV0ug"; + public string LicenseKey { get; set; } = "eyJhbGciOiJQUzI1NiIsImtpZCI6IklkZW50aXR5U2VydmVyTGljZW5zZUtleS83Y2VhZGJiNzgxMzA0NjllODgwNjg5MTAyNTQxNGYxNiIsInR5cCI6ImxpY2Vuc2Urand0In0.eyJpc3MiOiJodHRwczovL2R1ZW5kZXNvZnR3YXJlLmNvbSIsImF1ZCI6IklkZW50aXR5U2VydmVyIiwiaWF0IjoxNzY1MDY1NjAwLCJleHAiOjE3OTY1MTUyMDAsImNvbXBhbnlfbmFtZSI6IkJpdHdhcmRlbiBJbmMuIiwiY29udGFjdF9pbmZvIjoiY29udGFjdEBkdWVuZGVzb2Z0d2FyZS5jb20iLCJlZGl0aW9uIjoiU3RhcnRlciIsImlkIjoiOTUxNSIsImZlYXR1cmUiOlsiaXN2IiwidW5saW1pdGVkX2NsaWVudHMiXSwiY2xpZW50X2xpbWl0IjowfQ.rWUsq-XBKNwPG7BRKG-vShXHuyHLHJCh0sEWdWT4Rkz4ArIPOAepEp9wNya-hxFKkBTFlPaQ5IKk4wDTvkQkuq1qaI_v6kSCdaP9fvXp0rmh4KcFEffVLB-wAOK2S2Cld5DzdyCoskUUfwNQP7xuLsz2Ydxe_whSRIdv8bsMbvTC3Kl8PYZPZ4MxqW8rSZ_mEuCpSe5-Q40sB7aiu_7YmWLJaKrfBTIqYH-XuzQj36Aemoei0efcntej-gvxovy-5SiSEsGuRZj41rjEZYOuj5KgHihJViO1VDHK6CNtlu2Ks8bkv6G2hO-TkF16Y28ywEG_beLEf_s5dzhbDBDbvA"; /// /// Sliding lifetime of a refresh token in seconds. /// @@ -533,59 +553,11 @@ public class GlobalSettings : IGlobalSettings } } - public class SentrySettings - { - public string Dsn { get; set; } - } - public class NotificationsSettings : ConnectionStringSettings { public string RedisConnectionString { get; set; } } - public class SyslogSettings - { - /// - /// The connection string used to connect to a remote syslog server over TCP or UDP, or to connect locally. - /// - /// - /// The connection string will be parsed using to extract the protocol, host name and port number. - /// - /// - /// Supported protocols are: - /// - /// UDP (use udp://) - /// TCP (use tcp://) - /// TLS over TCP (use tls://) - /// - /// - /// - /// - /// A remote server (logging.dev.example.com) is listening on UDP (port 514): - /// - /// udp://logging.dev.example.com:514. - /// - public string Destination { get; set; } - /// - /// The absolute path to a Certificate (DER or Base64 encoded with private key). - /// - /// - /// The certificate path and are passed into the . - /// The file format of the certificate may be binary encoded (DER) or base64. If the private key is encrypted, provide the password in , - /// - public string CertificatePath { get; set; } - /// - /// The password for the encrypted private key in the certificate supplied in . - /// - /// - public string CertificatePassword { get; set; } - /// - /// The thumbprint of the certificate in the X.509 certificate store for personal certificates for the user account running Bitwarden. - /// - /// - public string CertificateThumbprint { get; set; } - } - public class NotificationHubSettings { private string _connectionString; @@ -642,6 +614,12 @@ public class GlobalSettings : IGlobalSettings public string AKey { get; set; } } + public class WebAuthnSettings + { + public int PremiumMaximumAllowedCredentials { get; set; } = 10; + public int NonPremiumMaximumAllowedCredentials { get; set; } = 5; + } + public class BraintreeSettings { public bool Production { get; set; } @@ -662,6 +640,7 @@ public class GlobalSettings : IGlobalSettings public bool Production { get; set; } public string Token { get; set; } public string NotificationUrl { get; set; } + public string WebhookKey { get; set; } } public class InstallationSettings : IInstallationSettings @@ -717,12 +696,6 @@ public class GlobalSettings : IGlobalSettings public int MaxNetworkRetries { get; set; } = 2; } - public class PhishingDomainSettings : IPhishingDomainSettings - { - public string UpdateUrl { get; set; } - public string ChecksumUrl { get; set; } - } - public class DistributedIpRateLimitingSettings { public string RedisConnectionString { get; set; } @@ -767,6 +740,30 @@ public class GlobalSettings : IGlobalSettings { public virtual IConnectionStringSettings Redis { get; set; } = new ConnectionStringSettings(); public virtual IConnectionStringSettings Cosmos { get; set; } = new ConnectionStringSettings(); + public ExtendedCacheSettings DefaultExtendedCache { get; set; } = new ExtendedCacheSettings(); + } + + /// + /// A collection of Settings for customizing the FusionCache used in extended caching. Defaults are + /// provided for every attribute so that only specific values need to be overridden if needed. + /// + public class ExtendedCacheSettings + { + public bool EnableDistributedCache { get; set; } = true; + public bool UseSharedDistributedCache { get; set; } = true; + public IConnectionStringSettings Redis { get; set; } = new ConnectionStringSettings(); + public TimeSpan Duration { get; set; } = TimeSpan.FromMinutes(30); + public bool IsFailSafeEnabled { get; set; } = true; + public TimeSpan FailSafeMaxDuration { get; set; } = TimeSpan.FromHours(2); + public TimeSpan FailSafeThrottleDuration { get; set; } = TimeSpan.FromSeconds(30); + public float? EagerRefreshThreshold { get; set; } = 0.9f; + public TimeSpan FactorySoftTimeout { get; set; } = TimeSpan.FromMilliseconds(100); + public TimeSpan FactoryHardTimeout { get; set; } = TimeSpan.FromMilliseconds(1500); + public TimeSpan DistributedCacheSoftTimeout { get; set; } = TimeSpan.FromSeconds(1); + public TimeSpan DistributedCacheHardTimeout { get; set; } = TimeSpan.FromSeconds(2); + public bool AllowBackgroundDistributedCacheOperations { get; set; } = true; + public TimeSpan JitterMaxDuration { get; set; } = TimeSpan.FromSeconds(2); + public TimeSpan DistributedCacheCircuitBreakerDuration { get; set; } = TimeSpan.FromSeconds(30); } public class WebPushSettings : IWebPushSettings diff --git a/src/Core/Settings/IGlobalSettings.cs b/src/Core/Settings/IGlobalSettings.cs index d77842373e..c316836d09 100644 --- a/src/Core/Settings/IGlobalSettings.cs +++ b/src/Core/Settings/IGlobalSettings.cs @@ -6,7 +6,7 @@ public interface IGlobalSettings { // This interface exists for testing. Add settings here as needed for testing bool SelfHosted { get; set; } - bool UnifiedDeployment { get; set; } + bool LiteDeployment { get; set; } string KnownProxies { get; set; } string ProjectName { get; set; } bool EnableCloudCommunication { get; set; } @@ -20,7 +20,6 @@ public interface IGlobalSettings IConnectionStringSettings Storage { get; set; } IBaseServiceUriSettings BaseServiceUri { get; set; } ISsoSettings Sso { get; set; } - ILogLevelSettings MinLogLevel { get; set; } IPasswordlessAuthSettings PasswordlessAuth { get; set; } IDomainVerificationSettings DomainVerification { get; set; } ILaunchDarklySettings LaunchDarkly { get; set; } @@ -29,5 +28,5 @@ public interface IGlobalSettings string DevelopmentDirectory { get; set; } IWebPushSettings WebPush { get; set; } GlobalSettings.EventLoggingSettings EventLogging { get; set; } - IPhishingDomainSettings PhishingDomain { get; set; } + GlobalSettings.WebAuthnSettings WebAuthn { get; set; } } diff --git a/src/Core/Settings/ILogLevelSettings.cs b/src/Core/Settings/ILogLevelSettings.cs deleted file mode 100644 index b3cedf083c..0000000000 --- a/src/Core/Settings/ILogLevelSettings.cs +++ /dev/null @@ -1,74 +0,0 @@ -using Serilog.Events; - -namespace Bit.Core.Settings; - -public interface ILogLevelSettings -{ - IBillingLogLevelSettings BillingSettings { get; set; } - IApiLogLevelSettings ApiSettings { get; set; } - IIdentityLogLevelSettings IdentitySettings { get; set; } - IScimLogLevelSettings ScimSettings { get; set; } - ISsoLogLevelSettings SsoSettings { get; set; } - IAdminLogLevelSettings AdminSettings { get; set; } - IEventsLogLevelSettings EventsSettings { get; set; } - IEventsProcessorLogLevelSettings EventsProcessorSettings { get; set; } - IIconsLogLevelSettings IconsSettings { get; set; } - INotificationsLogLevelSettings NotificationsSettings { get; set; } -} - -public interface IBillingLogLevelSettings -{ - LogEventLevel Default { get; set; } - LogEventLevel Jobs { get; set; } -} - -public interface IApiLogLevelSettings -{ - LogEventLevel Default { get; set; } - LogEventLevel IdentityToken { get; set; } - LogEventLevel IpRateLimit { get; set; } -} - -public interface IIdentityLogLevelSettings -{ - LogEventLevel Default { get; set; } - LogEventLevel IdentityToken { get; set; } - LogEventLevel IpRateLimit { get; set; } -} - -public interface IScimLogLevelSettings -{ - LogEventLevel Default { get; set; } -} - -public interface ISsoLogLevelSettings -{ - LogEventLevel Default { get; set; } -} - -public interface IAdminLogLevelSettings -{ - LogEventLevel Default { get; set; } -} - -public interface IEventsLogLevelSettings -{ - LogEventLevel Default { get; set; } - LogEventLevel IdentityToken { get; set; } -} - -public interface IEventsProcessorLogLevelSettings -{ - LogEventLevel Default { get; set; } -} - -public interface IIconsLogLevelSettings -{ - LogEventLevel Default { get; set; } -} - -public interface INotificationsLogLevelSettings -{ - LogEventLevel Default { get; set; } - LogEventLevel IdentityToken { get; set; } -} diff --git a/src/Core/Settings/IPhishingDomainSettings.cs b/src/Core/Settings/IPhishingDomainSettings.cs deleted file mode 100644 index 2e4a901a5a..0000000000 --- a/src/Core/Settings/IPhishingDomainSettings.cs +++ /dev/null @@ -1,7 +0,0 @@ -namespace Bit.Core.Settings; - -public interface IPhishingDomainSettings -{ - string UpdateUrl { get; set; } - string ChecksumUrl { get; set; } -} diff --git a/src/Core/Settings/LoggingSettings/AdminLogLevelSettings.cs b/src/Core/Settings/LoggingSettings/AdminLogLevelSettings.cs deleted file mode 100644 index d2c74dd076..0000000000 --- a/src/Core/Settings/LoggingSettings/AdminLogLevelSettings.cs +++ /dev/null @@ -1,8 +0,0 @@ -using Serilog.Events; - -namespace Bit.Core.Settings.LoggingSettings; - -public class AdminLogLevelSettings : IAdminLogLevelSettings -{ - public LogEventLevel Default { get; set; } = LogEventLevel.Error; -} diff --git a/src/Core/Settings/LoggingSettings/ApiLogLevelSettings.cs b/src/Core/Settings/LoggingSettings/ApiLogLevelSettings.cs deleted file mode 100644 index 7961ab7e3b..0000000000 --- a/src/Core/Settings/LoggingSettings/ApiLogLevelSettings.cs +++ /dev/null @@ -1,10 +0,0 @@ -using Serilog.Events; - -namespace Bit.Core.Settings.LoggingSettings; - -public class ApiLogLevelSettings : IApiLogLevelSettings -{ - public LogEventLevel Default { get; set; } = LogEventLevel.Error; - public LogEventLevel IdentityToken { get; set; } = LogEventLevel.Fatal; - public LogEventLevel IpRateLimit { get; set; } = LogEventLevel.Information; -} diff --git a/src/Core/Settings/LoggingSettings/BillingLogLevelSettings.cs b/src/Core/Settings/LoggingSettings/BillingLogLevelSettings.cs deleted file mode 100644 index b9e53e6bca..0000000000 --- a/src/Core/Settings/LoggingSettings/BillingLogLevelSettings.cs +++ /dev/null @@ -1,9 +0,0 @@ -using Serilog.Events; - -namespace Bit.Core.Settings.LoggingSettings; - -public class BillingLogLevelSettings : IBillingLogLevelSettings -{ - public LogEventLevel Default { get; set; } = LogEventLevel.Warning; - public LogEventLevel Jobs { get; set; } = LogEventLevel.Information; -} diff --git a/src/Core/Settings/LoggingSettings/EventsLogLevelSettings.cs b/src/Core/Settings/LoggingSettings/EventsLogLevelSettings.cs deleted file mode 100644 index 3201748550..0000000000 --- a/src/Core/Settings/LoggingSettings/EventsLogLevelSettings.cs +++ /dev/null @@ -1,9 +0,0 @@ -using Serilog.Events; - -namespace Bit.Core.Settings.LoggingSettings; - -public class EventsLogLevelSettings : IEventsLogLevelSettings -{ - public LogEventLevel Default { get; set; } = LogEventLevel.Error; - public LogEventLevel IdentityToken { get; set; } = LogEventLevel.Fatal; -} diff --git a/src/Core/Settings/LoggingSettings/EventsProcessorLogLevelSettings.cs b/src/Core/Settings/LoggingSettings/EventsProcessorLogLevelSettings.cs deleted file mode 100644 index 5aff18a216..0000000000 --- a/src/Core/Settings/LoggingSettings/EventsProcessorLogLevelSettings.cs +++ /dev/null @@ -1,8 +0,0 @@ -using Serilog.Events; - -namespace Bit.Core.Settings.LoggingSettings; - -public class EventsProcessorLogLevelSettings : IEventsProcessorLogLevelSettings -{ - public LogEventLevel Default { get; set; } = LogEventLevel.Warning; -} diff --git a/src/Core/Settings/LoggingSettings/IconsLogLevelSettings.cs b/src/Core/Settings/LoggingSettings/IconsLogLevelSettings.cs deleted file mode 100644 index c7b73ba687..0000000000 --- a/src/Core/Settings/LoggingSettings/IconsLogLevelSettings.cs +++ /dev/null @@ -1,8 +0,0 @@ -using Serilog.Events; - -namespace Bit.Core.Settings.LoggingSettings; - -public class IconsLogLevelSettings : IIconsLogLevelSettings -{ - public LogEventLevel Default { get; set; } = LogEventLevel.Error; -} diff --git a/src/Core/Settings/LoggingSettings/IdentityLogLevelSettings.cs b/src/Core/Settings/LoggingSettings/IdentityLogLevelSettings.cs deleted file mode 100644 index a823cb5109..0000000000 --- a/src/Core/Settings/LoggingSettings/IdentityLogLevelSettings.cs +++ /dev/null @@ -1,10 +0,0 @@ -using Serilog.Events; - -namespace Bit.Core.Settings.LoggingSettings; - -public class IdentityLogLevelSettings : IIdentityLogLevelSettings -{ - public LogEventLevel Default { get; set; } = LogEventLevel.Error; - public LogEventLevel IdentityToken { get; set; } = LogEventLevel.Fatal; - public LogEventLevel IpRateLimit { get; set; } = LogEventLevel.Information; -} diff --git a/src/Core/Settings/LoggingSettings/LogLevelSettings.cs b/src/Core/Settings/LoggingSettings/LogLevelSettings.cs deleted file mode 100644 index 1af05ebfde..0000000000 --- a/src/Core/Settings/LoggingSettings/LogLevelSettings.cs +++ /dev/null @@ -1,16 +0,0 @@ - -namespace Bit.Core.Settings.LoggingSettings; - -public class LogLevelSettings : ILogLevelSettings -{ - public IBillingLogLevelSettings BillingSettings { get; set; } = new BillingLogLevelSettings(); - public IApiLogLevelSettings ApiSettings { get; set; } = new ApiLogLevelSettings(); - public IIdentityLogLevelSettings IdentitySettings { get; set; } = new IdentityLogLevelSettings(); - public IScimLogLevelSettings ScimSettings { get; set; } = new ScimLogLevelSettings(); - public ISsoLogLevelSettings SsoSettings { get; set; } = new SsoLogLevelSettings(); - public IAdminLogLevelSettings AdminSettings { get; set; } = new AdminLogLevelSettings(); - public IEventsLogLevelSettings EventsSettings { get; set; } = new EventsLogLevelSettings(); - public IEventsProcessorLogLevelSettings EventsProcessorSettings { get; set; } = new EventsProcessorLogLevelSettings(); - public IIconsLogLevelSettings IconsSettings { get; set; } = new IconsLogLevelSettings(); - public INotificationsLogLevelSettings NotificationsSettings { get; set; } = new NotificationsLogLevelSettings(); -} diff --git a/src/Core/Settings/LoggingSettings/NotificationsLogLevelSettings.cs b/src/Core/Settings/LoggingSettings/NotificationsLogLevelSettings.cs deleted file mode 100644 index 3494fbfcca..0000000000 --- a/src/Core/Settings/LoggingSettings/NotificationsLogLevelSettings.cs +++ /dev/null @@ -1,9 +0,0 @@ -using Serilog.Events; - -namespace Bit.Core.Settings.LoggingSettings; - -public class NotificationsLogLevelSettings : INotificationsLogLevelSettings -{ - public LogEventLevel Default { get; set; } = LogEventLevel.Warning; - public LogEventLevel IdentityToken { get; set; } = LogEventLevel.Fatal; -} diff --git a/src/Core/Settings/LoggingSettings/ScimLogLevelSettings.cs b/src/Core/Settings/LoggingSettings/ScimLogLevelSettings.cs deleted file mode 100644 index f297b17e95..0000000000 --- a/src/Core/Settings/LoggingSettings/ScimLogLevelSettings.cs +++ /dev/null @@ -1,8 +0,0 @@ -using Serilog.Events; - -namespace Bit.Core.Settings.LoggingSettings; - -public class ScimLogLevelSettings : IScimLogLevelSettings -{ - public LogEventLevel Default { get; set; } = LogEventLevel.Warning; -} diff --git a/src/Core/Settings/LoggingSettings/SsoLogLevelSettings.cs b/src/Core/Settings/LoggingSettings/SsoLogLevelSettings.cs deleted file mode 100644 index 495ec41fd0..0000000000 --- a/src/Core/Settings/LoggingSettings/SsoLogLevelSettings.cs +++ /dev/null @@ -1,8 +0,0 @@ -using Serilog.Events; - -namespace Bit.Core.Settings.LoggingSettings; - -public class SsoLogLevelSettings : ISsoLogLevelSettings -{ - public LogEventLevel Default { get; set; } = LogEventLevel.Error; -} diff --git a/src/Core/Tools/ImportFeatures/ImportCiphersCommand.cs b/src/Core/Tools/ImportFeatures/ImportCiphersCommand.cs index c7f7e3aff7..fa558f5963 100644 --- a/src/Core/Tools/ImportFeatures/ImportCiphersCommand.cs +++ b/src/Core/Tools/ImportFeatures/ImportCiphersCommand.cs @@ -150,17 +150,34 @@ public class ImportCiphersCommand : IImportCiphersCommand foreach (var collection in collections) { - if (!organizationCollectionsIds.Contains(collection.Id)) + // If the collection already exists, skip it + if (organizationCollectionsIds.Contains(collection.Id)) { - collection.SetNewId(); - newCollections.Add(collection); - newCollectionUsers.Add(new CollectionUser - { - CollectionId = collection.Id, - OrganizationUserId = importingOrgUser.Id, - Manage = true - }); + continue; } + + // Create new collections if not already present + collection.SetNewId(); + newCollections.Add(collection); + + /* + * If the organization was created by a Provider, the organization may have zero members (users) + * In this situation importingOrgUser will be null, and accessing importingOrgUser.Id will + * result in a null reference exception. + * + * Avoid user assignment, but proceed with adding the collection. + */ + if (importingOrgUser == null) + { + continue; + } + + newCollectionUsers.Add(new CollectionUser + { + CollectionId = collection.Id, + OrganizationUserId = importingOrgUser.Id, + Manage = true + }); } // Create associations based on the newly assigned ids diff --git a/src/Core/Utilities/AssemblyHelpers.cs b/src/Core/Utilities/AssemblyHelpers.cs index 0cc01efdf3..03f7ff986d 100644 --- a/src/Core/Utilities/AssemblyHelpers.cs +++ b/src/Core/Utilities/AssemblyHelpers.cs @@ -1,46 +1,46 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - +using System.Diagnostics; using System.Reflection; namespace Bit.Core.Utilities; public static class AssemblyHelpers { - private static readonly IEnumerable _assemblyMetadataAttributes; - private static readonly AssemblyInformationalVersionAttribute _assemblyInformationalVersionAttributes; - private const string GIT_HASH_ASSEMBLY_KEY = "GitHash"; - private static string _version; - private static string _gitHash; + private static string? _version; + private static string? _gitHash; static AssemblyHelpers() { - _assemblyMetadataAttributes = Assembly.GetEntryAssembly().GetCustomAttributes(); - _assemblyInformationalVersionAttributes = Assembly.GetEntryAssembly().GetCustomAttribute(); - } - - public static string GetVersion() - { - if (string.IsNullOrWhiteSpace(_version)) + var assemblyInformationalVersionAttribute = typeof(AssemblyHelpers).Assembly.GetCustomAttribute(); + if (assemblyInformationalVersionAttribute == null) { - _version = _assemblyInformationalVersionAttributes.InformationalVersion; + Debug.Fail("The AssemblyInformationalVersionAttribute is expected to exist in this assembly, possibly its generation was turned off."); + return; } + var informationalVersion = assemblyInformationalVersionAttribute.InformationalVersion.AsSpan(); + + if (!informationalVersion.TrySplitBy('+', out var version, out var gitHash)) + { + // Treat the whole thing as the version + _version = informationalVersion.ToString(); + return; + } + + _version = version.ToString(); + if (gitHash.Length < 8) + { + return; + } + _gitHash = gitHash[..8].ToString(); + } + + public static string? GetVersion() + { return _version; } - public static string GetGitHash() + public static string? GetGitHash() { - if (string.IsNullOrWhiteSpace(_gitHash)) - { - try - { - _gitHash = _assemblyMetadataAttributes.Where(i => i.Key == GIT_HASH_ASSEMBLY_KEY).First().Value; - } - catch (Exception) - { } - } - return _gitHash; } } diff --git a/src/Core/Utilities/BillingHelpers.cs b/src/Core/Utilities/BillingHelpers.cs index e7ccfc3547..ef0fdf010b 100644 --- a/src/Core/Utilities/BillingHelpers.cs +++ b/src/Core/Utilities/BillingHelpers.cs @@ -1,13 +1,13 @@ -using Bit.Core.Entities; +using Bit.Core.Billing.Services; +using Bit.Core.Entities; using Bit.Core.Exceptions; -using Bit.Core.Services; namespace Bit.Core.Utilities; public static class BillingHelpers { - internal static async Task AdjustStorageAsync(IPaymentService paymentService, IStorableSubscriber storableSubscriber, - short storageAdjustmentGb, string storagePlanId) + internal static async Task AdjustStorageAsync(IStripePaymentService paymentService, IStorableSubscriber storableSubscriber, + short storageAdjustmentGb, string storagePlanId, short baseStorageGb) { if (storableSubscriber == null) { @@ -30,9 +30,9 @@ public static class BillingHelpers } var newStorageGb = (short)(storableSubscriber.MaxStorageGb.Value + storageAdjustmentGb); - if (newStorageGb < 1) + if (newStorageGb < baseStorageGb) { - newStorageGb = 1; + newStorageGb = baseStorageGb; } if (newStorageGb > 100) @@ -48,7 +48,7 @@ public static class BillingHelpers "Delete some stored data first."); } - var additionalStorage = newStorageGb - 1; + var additionalStorage = newStorageGb - baseStorageGb; var paymentIntentClientSecret = await paymentService.AdjustStorageAsync(storableSubscriber, additionalStorage, storagePlanId); storableSubscriber.MaxStorageGb = newStorageGb; diff --git a/src/Core/Utilities/BitPayClient.cs b/src/Core/Utilities/BitPayClient.cs deleted file mode 100644 index cf241d5723..0000000000 --- a/src/Core/Utilities/BitPayClient.cs +++ /dev/null @@ -1,30 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.Settings; - -namespace Bit.Core.Utilities; - -public class BitPayClient -{ - private readonly BitPayLight.BitPay _bpClient; - - public BitPayClient(GlobalSettings globalSettings) - { - if (CoreHelpers.SettingHasValue(globalSettings.BitPay.Token)) - { - _bpClient = new BitPayLight.BitPay(globalSettings.BitPay.Token, - globalSettings.BitPay.Production ? BitPayLight.Env.Prod : BitPayLight.Env.Test); - } - } - - public Task GetInvoiceAsync(string id) - { - return _bpClient.GetInvoice(id); - } - - public Task CreateInvoiceAsync(BitPayLight.Models.Invoice.Invoice invoice) - { - return _bpClient.CreateInvoice(invoice); - } -} diff --git a/src/Core/Utilities/CACHING.md b/src/Core/Utilities/CACHING.md new file mode 100644 index 0000000000..c29a14d751 --- /dev/null +++ b/src/Core/Utilities/CACHING.md @@ -0,0 +1,1123 @@ +# Bitwarden Server Caching + +Caching options available in Bitwarden's server. The server uses multiple caching layers and backends to balance performance, scalability, and operational simplicity across both cloud and self-hosted deployments. + +--- + +## Choosing a Caching Option + +Use this decision tree to identify the appropriate caching option for your feature: + +``` +Does your data need to be shared across all instances in a horizontally-scaled deployment? +├─ YES +│ │ +│ Do you need long-term persistence with TTL (days/weeks)? +│ ├─ YES → Use `IDistributedCache` with persistent keyed service +│ └─ NO → Use `ExtendedCache` +│ │ +│ Notes: +│ - With Redis configured: memory + distributed + backplane +│ - Without Redis: memory-only with stampede protection +│ - Provides fail-safe, eager refresh, circuit breaker +│ - For org/provider abilities: Use GetOrSetAsync with preloading pattern +│ +└─ NO (single instance or manual sync acceptable) + │ + Use `ExtendedCache` with memory-only mode (EnableDistributedCache = false) + │ + Notes: + - Same performance as raw IMemoryCache + - Built-in stampede protection, eager refresh, fail-safe + - "Free" Redis/backplane if needed at a later date (but not required) + - Only use specialized in-memory cache if ExtendedCache API doesn't fit + +*Stampede protection = prevents cache stampedes (multiple simultaneous requests for the same expired/missing key triggering redundant backend calls) +``` + +--- + +## Caching Options Overview + +| Option | Best For | Horizontal Scale | TTL Support | Backend Options | +| -------------------------------------- | ---------------------------------------------- | ---------------- | ----------- | ---------------------- | +| **ExtendedCache** | General-purpose caching with advanced features | ✅ Yes | ✅ Yes | Redis, Memory | +| **IDistributedCache** (default) | Short-lived key-value caching | ✅ Yes | ⚠️ Manual | Redis, SQL, EF | +| **IDistributedCache** (`"persistent"`) | Long-lived data with TTL | ✅ Yes | ✅ Yes | Cosmos, Redis, SQL, EF | +| **In-Memory Cache** | High-frequency reads, single instance | ❌ No | ⚠️ Manual | Memory | + +--- + +## `ExtendedCache` + +`ExtendedCache` is a wrapper around [FusionCache](https://github.com/ZiggyCreatures/FusionCache) that provides a simple way to register **named, isolated caches** with sensible defaults. The goal is to make it trivial for each subsystem or feature to have its own cache - with optional distributed caching and backplane support - without repeatedly wiring up FusionCache, Redis, and related infrastructure. + +Each named cache automatically receives: + +- Its own `FusionCache` instance +- Its own configuration (default or overridden) +- Its own key prefix +- Optional distributed store +- Optional backplane + +`ExtendedCache` supports three deployment modes: + +- **Memory-only caching** (with stampede protection: prevents multiple concurrent requests for the same key from hitting the backend) +- **Memory + distributed cache + backplane** using the **shared** application Redis +- **Memory + distributed cache + backplane** using a **fully isolated** Redis instance + +### When to Use + +- **General-purpose caching** for any domain data +- Features requiring **stampede protection** (when multiple concurrent requests for the same cache key should result in only a single backend call, with all requesters waiting for the same result) +- Data that benefits from **fail-safe mode** (serve stale data on backend failures) +- Multi-instance applications requiring **cache synchronization** via backplane +- You want **isolated cache configuration** per feature + +### Pros + +✅ **Advanced features out-of-the-box**: + +- Stampede protection (multiple requests for same key = single backend call) +- Fail-safe mode with stale data serving +- Adaptive caching with eager refresh +- Automatic backplane support for multi-instance sync +- Circuit breaker for backend failures + +✅ **Named, isolated caches**: Each feature gets its own cache instance with independent configuration + +✅ **Flexible deployment modes**: + +- Memory-only (development, testing) +- Memory + Redis (production cloud) +- Memory + isolated Redis (specialized features) + +✅ **Simple API**: Uses `FusionCache`'s intuitive `GetOrSet` pattern + +✅ **Built-in serialization**: Automatic JSON serialization/deserialization + +### Cons + +❌ Requires understanding of `FusionCache` configuration options + +❌ Slightly more overhead than raw `IDistributedCache` + +❌ IDistributedCache dependency for multi-instance deployments (typically Redis, but degrades gracefully to memory-only) + +### Example Usage + +**Note**: When using the shared Redis cache option (which is on by default, if the Redis connection string is configured), it is expected to call `services.AddDistributedCache(globalSettings)` **before** calling `AddExtendedCache`. The idea is to set up the distributed cache in our normal pattern and then "extend" it to include more functionality. + +#### 1. Register the cache (in Startup.cs): + +```csharp +// Option 1: Use default settings with shared Redis (if available) +services.AddDistributedCache(globalSettings); +services.AddExtendedCache("MyFeatureCache", globalSettings); + +// Option 2: Memory-only mode for high-performance single-instance caching +services.AddExtendedCache("MyFeatureCache", globalSettings, new GlobalSettings.ExtendedCacheSettings +{ + EnableDistributedCache = false, // Memory-only, same performance as IMemoryCache + Duration = TimeSpan.FromHours(1), + IsFailSafeEnabled = true, + EagerRefreshThreshold = 0.9 // Refresh at 90% of TTL +}); +// When EnableDistributedCache = false: +// - Uses memory-only caching (same performance as raw IMemoryCache) +// - Still provides stampede protection, eager refresh, fail-safe +// - Redis/backplane can be enabled later by setting EnableDistributedCache = true + +// Option 3: Override default settings with Redis +services.AddExtendedCache("MyFeatureCache", globalSettings, new GlobalSettings.ExtendedCacheSettings +{ + Duration = TimeSpan.FromHours(1), + IsFailSafeEnabled = true, + FailSafeMaxDuration = TimeSpan.FromHours(2), + EagerRefreshThreshold = 0.9 // Refresh at 90% of TTL +}); + +// Option 4: Isolated Redis for specialized features +services.AddExtendedCache("SpecializedCache", globalSettings, new GlobalSettings.ExtendedCacheSettings +{ + UseSharedDistributedCache = false, + Redis = new GlobalSettings.ConnectionStringSettings + { + ConnectionString = "localhost:6379,ssl=false" + } +}); +// When configured this way: +// - A dedicated IConnectionMultiplexer is created +// - A dedicated IDistributedCache is created +// - A dedicated FusionCache backplane is created +// - All three are exposed to DI as keyed services (using the cache name as service key) +``` + +#### 2. Inject and use the cache: + +A named cache is retrieved via DI using keyed services (similar to how [IHttpClientFactory](https://learn.microsoft.com/en-us/aspnet/core/fundamentals/http-requests?view=aspnetcore-7.0#named-clients) works with named clients): + +```csharp +public class MyService +{ + private readonly IFusionCache _cache; + private readonly IItemRepository _itemRepository; + + // Option A: Inject via keyed service in constructor + public MyService( + [FromKeyedServices("MyFeatureCache")] IFusionCache cache, + IItemRepository itemRepository) + { + _cache = cache; + _itemRepository = itemRepository; + } + + // Option B: Request manually from service provider + // cache = provider.GetRequiredKeyedService(serviceKey: "MyFeatureCache") + + // Option C: Inject IFusionCacheProvider and request the named cache + // (similar to IHttpClientFactory pattern) + public MyService( + IFusionCacheProvider cacheProvider, + IItemRepository itemRepository) + { + _cache = cacheProvider.GetCache("MyFeatureCache"); + _itemRepository = itemRepository; + } + + public async Task GetItemAsync(Guid id) + { + return await _cache.GetOrSetAsync( + $"item:{id}", + async _ => await _itemRepository.GetByIdAsync(id), + options => options.SetDuration(TimeSpan.FromMinutes(30)) + ); + } +} +``` + +`ExtendedCache` doesn't change how `FusionCache` is used in code, which means all the functionality and full `FusionCache` API is available. See the [FusionCache docs](https://github.com/ZiggyCreatures/FusionCache/blob/main/docs/CoreMethods.md) for more details. + +### Specific Example: SSO Authorization Grants + +SSO authorization grants are **ephemeral, short-lived data** (typically ≤5 minutes) used to coordinate authorization flows across horizontally-scaled instances. `ExtendedCache` is ideal for this use case: + +```csharp +services.AddExtendedCache("SsoGrants", globalSettings, new GlobalSettings.ExtendedCacheSettings +{ + Duration = TimeSpan.FromMinutes(5), + IsFailSafeEnabled = false // Re-initiate flow rather than serve stale grants +}); + +public class SsoAuthorizationService +{ + private readonly IFusionCache _cache; + + public SsoAuthorizationService([FromKeyedServices("SsoGrants")] IFusionCache cache) + { + _cache = cache; + } + + public async Task GetGrantAsync(string authorizationCode) + { + return await _cache.GetOrDefaultAsync($"sso:grant:{authorizationCode}"); + } + + public async Task StoreGrantAsync(string authorizationCode, SsoGrant grant) + { + await _cache.SetAsync($"sso:grant:{authorizationCode}", grant); + } +} +``` + +**Why `ExtendedCache` for SSO grants:** + +- **Not critical if lost**: User can re-initiate SSO flow +- **Lower latency**: Redis backplane is faster than persistent storage +- **Simpler infrastructure**: Reuses existing Redis connection +- **Horizontal scaling**: Redis backplane automatically synchronizes across instances + +### Backend Configuration + +`ExtendedCache` automatically uses the configured backend: + +**Cloud (Bitwarden-hosted)**: + +1. Redis (primary, if `GlobalSettings.DistributedCache.Redis.ConnectionString` configured) +2. Memory-only (fallback if Redis unavailable) + +**Self-hosted**: + +1. Redis (if configured in `appsettings.json`) +2. SQL Server / EF Cache (if `IDistributedCache` is registered and no Redis) +3. Memory-only (default fallback) + +> **Note**: ExtendedCache works seamlessly with any `IDistributedCache` backend. In self-hosted scenarios without Redis, you can configure ExtendedCache to use SQL Server or Entity Framework cache as its distributed layer. This provides local memory caching in front of the database cache, with the option to add Redis later if needed. You won't get the backplane (cross-instance invalidation) without Redis, but you still get stampede protection, eager refresh, and fail-safe mode. + +### Specific Example: Organization/Provider Abilities + +Organization and provider abilities are read extremely frequently (on every request that checks permissions) but change infrequently. `ExtendedCache` is ideal for this access pattern with its eager refresh and Redis backplane support: + +```csharp +services.AddExtendedCache("OrganizationAbilities", globalSettings, new GlobalSettings.ExtendedCacheSettings +{ + Duration = TimeSpan.FromMinutes(10), + EagerRefreshThreshold = 0.9, // Refresh at 90% of TTL + IsFailSafeEnabled = true, + FailSafeMaxDuration = TimeSpan.FromHours(1) // Serve stale data up to 1 hour on backend failures +}); + +public class OrganizationAbilityService +{ + private readonly IFusionCache _cache; + private readonly IOrganizationRepository _organizationRepository; + + public OrganizationAbilityService( + [FromKeyedServices("OrganizationAbilities")] IFusionCache cache, + IOrganizationRepository organizationRepository) + { + _cache = cache; + _organizationRepository = organizationRepository; + } + + public async Task> GetOrganizationAbilitiesAsync() + { + return await _cache.GetOrSetAsync>( + "all-org-abilities", + async _ => + { + var abilities = await _organizationRepository.GetManyAbilitiesAsync(); + return abilities.ToDictionary(a => a.Id); + } + ); + } + + public async Task GetOrganizationAbilityAsync(Guid orgId) + { + var abilities = await GetOrganizationAbilitiesAsync(); + abilities.TryGetValue(orgId, out var ability); + return ability; + } + + public async Task UpsertOrganizationAbilityAsync(Organization organization) + { + // Update database + await _organizationRepository.ReplaceAsync(organization); + + // Invalidate cache - with Redis backplane, this broadcasts to all instances + await _cache.RemoveAsync("all-org-abilities"); + } +} +``` + +**Why `ExtendedCache` for org/provider abilities:** + +- **High-frequency reads**: Every permission check reads abilities +- **Infrequent writes**: Abilities change rarely +- **Eager refresh**: Automatically refreshes at 90% of TTL to prevent cache misses +- **Fail-safe mode**: Serves stale data if database temporarily unavailable +- **Redis backplane**: Automatically invalidates across all instances when abilities change +- **No Service Bus dependency**: Simpler infrastructure (one Redis instead of Redis + Service Bus) + +### When NOT to Use + +- **Long-term persistent data** (days/weeks) - Use `IDistributedCache` with persistent keyed service for structured TTL support +- **Custom caching logic** - If ExtendedCache's API doesn't fit your use case, consider specialized in-memory cache + +--- + +## `IDistributedCache` + +`IDistributedCache` provides two service registrations for different use cases: + +1. **Default (unnamed) service** - For ephemeral, short-lived data +2. **Persistent cache** (keyed service: `"persistent"`) - For longer-lived data with structured TTL + +### When to Use + +**Default `IDistributedCache`**: + +- **Legacy code** already using `IDistributedCache` (consider migrating to `ExtendedCache`) +- **Third-party integrations** requiring `IDistributedCache` interface +- **ASP.NET Core session storage** (framework dependency) +- You have **specific requirements** that ExtendedCache doesn't support + +> **Note**: For new code, prefer `ExtendedCache` over default `IDistributedCache`. ExtendedCache can be configured with `EnableDistributedCache = false` to use memory-only caching with the same performance as raw `IMemoryCache`, while still providing stampede protection, fail-safe, and eager refresh. + +**Persistent cache** (keyed service: `"persistent"`): + +- **Critical data where memory loss would impact users** (refresh tokens, consent grants) +- **Long-lived structured data** with automatic TTL (days to weeks) +- **Long-lived OAuth/OIDC grants** that must survive application restarts +- **Payment intents** or workflow state that spans multiple requests +- Data requiring **automatic expiration** without manual cleanup +- **Large cache datasets** that benefit from external storage (e.g., thousands of refresh tokens) + +### Pros + +✅ **Standard ASP.NET Core interface**: Widely understood, well-documented + +✅ **Multiple backend support**: Redis, SQL Server, Entity Framework, Cosmos DB + +✅ **Automatic backend selection**: Picks the right backend based on configuration + +✅ **Simple API**: Just `Get`, `Set`, `Remove`, `Refresh` + +✅ **Minimal overhead**: No additional layers beyond the backend + +✅ **Keyed services**: Separate configurations for different use cases + +### Cons + +❌ **No stampede protection**: Multiple requests = multiple backend calls + +❌ **No fail-safe mode**: Backend unavailable = cache miss + +❌ **No backplane**: Manual cache invalidation across instances + +❌ **Manual serialization**: You handle JSON serialization (or use helpers) + +❌ **Manual TTL management** (default service): Must track expiration manually + +### Example Usage: Default (Ephemeral Data) + +#### 1. Registration (already done in Api, Admin, Billing, Events, EventsProcessor, Identity, and Notifications Startup.cs files): + +```csharp +services.AddDistributedCache(globalSettings); +``` + +#### 2. Inject and use for short-lived tokens: + +```csharp +public class TwoFactorService +{ + private readonly IDistributedCache _cache; + + public TwoFactorService(IDistributedCache cache) + { + _cache = cache; + } + + public async Task GetEmailTokenAsync(Guid userId) + { + var key = $"email-2fa:{userId}"; + var cached = await _cache.GetStringAsync(key); + return cached; + } + + public async Task SetEmailTokenAsync(Guid userId, string token) + { + var key = $"email-2fa:{userId}"; + await _cache.SetStringAsync(key, token, new DistributedCacheEntryOptions + { + AbsoluteExpirationRelativeToNow = TimeSpan.FromMinutes(5) + }); + } +} +``` + +#### 3. Using JSON helpers: + +```csharp +using Bit.Core.Utilities; + +public async Task GetDataAsync(string key) +{ + return await _cache.TryGetValue(key); +} + +public async Task SetDataAsync(string key, MyData data) +{ + await _cache.SetAsync(key, data, new DistributedCacheEntryOptions + { + AbsoluteExpirationRelativeToNow = TimeSpan.FromMinutes(30) + }); +} +``` + +### Example Usage: Persistent (Long-Lived Data) + +The persistent cache is accessed via keyed service injection and is optimized for long-lived structured data with automatic TTL support. + +#### Specific Example: Payment Workflow State + +The persistent `IDistributedCache` service is appropriate for workflow state that spans multiple requests and needs automatic TTL cleanup. + +```csharp +public class SetupIntentDistributedCache( + [FromKeyedServices("persistent")] IDistributedCache distributedCache) : ISetupIntentCache +{ + public async Task Set(Guid subscriberId, string setupIntentId) + { + // Bidirectional mapping for payment flow + var bySubscriberIdCacheKey = $"setup_intent_id_for_subscriber_id_{subscriberId}"; + var bySetupIntentIdCacheKey = $"subscriber_id_for_setup_intent_id_{setupIntentId}"; + + // Note: No explicit TTL set here. Cosmos DB uses container-level TTL for automatic cleanup. + // In cloud, Cosmos TTL handles expiration. In self-hosted, the cache backend manages TTL. + await Task.WhenAll( + distributedCache.SetStringAsync(bySubscriberIdCacheKey, setupIntentId), + distributedCache.SetStringAsync(bySetupIntentIdCacheKey, subscriberId.ToString())); + } + + public async Task GetSetupIntentIdForSubscriber(Guid subscriberId) + { + var cacheKey = $"setup_intent_id_for_subscriber_id_{subscriberId}"; + return await distributedCache.GetStringAsync(cacheKey); + } + + public async Task GetSubscriberIdForSetupIntent(string setupIntentId) + { + var cacheKey = $"subscriber_id_for_setup_intent_id_{setupIntentId}"; + var value = await distributedCache.GetStringAsync(cacheKey); + if (string.IsNullOrEmpty(value) || !Guid.TryParse(value, out var subscriberId)) + { + return null; + } + return subscriberId; + } + + public async Task RemoveSetupIntentForSubscriber(Guid subscriberId) + { + var cacheKey = $"setup_intent_id_for_subscriber_id_{subscriberId}"; + await distributedCache.RemoveAsync(cacheKey); + } +} +``` + +#### Specific Example: Long-Lived OAuth Grants + +Long-lived OAuth grants (refresh tokens, consent grants, device codes) use the persistent `IDistributedCache` in **cloud** and `IGrantRepository` as a **database fallback for self-hosted** when persistent cache is not configured: + +**Cloud (Bitwarden-hosted)**: + +- Uses persistent `IDistributedCache` directly (backed by Cosmos DB) +- Automatic TTL via Cosmos DB container-level TTL + +**Self-hosted**: + +- Uses `IGrantRepository` as a database fallback when persistent cache backend is not available +- Stores grants in `Grant` database table with automatic expiration + +**Grant type recommendations:** + +| Grant Type | Lifetime | Durability Requirement | Recommended Storage | Rationale | +| ------------------------ | ------------ | ---------------------- | ------------------- | ------------------------------------------------------------------------------------------- | +| SSO authorization codes | ≤5 min | Ephemeral, can be lost | `ExtendedCache` | User can re-initiate SSO flow if code is lost; short lifetime limits exposure window | +| OIDC authorization codes | ≤5 min | Ephemeral, can be lost | `ExtendedCache` | OAuth spec allows user to retry authorization; code is single-use and short-lived | +| PKCE code verifiers | ≤5 min | Ephemeral, can be lost | `ExtendedCache` | Tied to authorization code lifecycle; can be regenerated if authorization is retried | +| Refresh tokens | Days-weeks | Must persist | Persistent cache | Losing these forces user re-authentication; critical for seamless user experience | +| Consent grants | Weeks-months | Must persist | Persistent cache | User shouldn't have to re-consent frequently; loss degrades UX and trust | +| Device codes | Days | Must persist | Persistent cache | Device flow is async; losing codes breaks pending device authorizations with no recovery UX | + +### Backend Configuration + +The backend is automatically selected based on configuration and service key: + +#### Default `IDistributedCache` (ephemeral) + +**Cloud (Bitwarden-hosted)**: + +- **Redis** only (always configured in cloud environments) + +**Self-hosted priority order**: + +1. **Redis** (if `GlobalSettings.DistributedCache.Redis.ConnectionString` is configured) +2. **SQL Server Cache table** (if database provider is SQL Server) +3. **Entity Framework Cache table** (for PostgreSQL, MySQL, SQLite) + +#### Persistent cache (keyed service: `"persistent"`) + +**Cloud (Bitwarden-hosted)**: + +1. **Cosmos DB** (if `GlobalSettings.DistributedCache.Cosmos.ConnectionString` is configured) + - Database: `cache` + - Container: `default` +2. **Falls back to Redis** + +**Self-hosted priority order**: + +1. **Redis** (if configured) +2. **SQL Server Cache table** (if database provider is SQL Server) +3. **Entity Framework Cache table** (for PostgreSQL, MySQL, SQLite) + +### Backend Details + +#### Redis + +```csharp +services.AddStackExchangeRedisCache(options => +{ + options.Configuration = globalSettings.DistributedCache.Redis.ConnectionString; +}); +``` + +**Used for**: Cloud (always), self-hosted (if configured) + +- **Pros**: Fast, horizontally scalable, battle-tested +- **Cons**: Additional infrastructure dependency (self-hosted only) +- **TTL**: Via `AbsoluteExpiration` in cache entry options + +#### SQL Server Cache Table (Self-hosted only) + +```csharp +services.AddDistributedSqlServerCache(options => +{ + options.ConnectionString = globalSettings.SqlServer.ConnectionString; + options.SchemaName = "dbo"; + options.TableName = "Cache"; +}); +``` + +**Used for**: Self-hosted deployments without Redis + +- **Pros**: No additional infrastructure, works with existing database +- **Cons**: Slower than Redis, adds load to database, less scalable +- **TTL**: Via `ExpiresAtTime` and `AbsoluteExpiration` columns + +#### Entity Framework Cache (Self-hosted only) + +```csharp +services.AddSingleton(); +``` + +**Used for**: Self-hosted deployments with PostgreSQL, MySQL, or SQLite + +- **Pros**: Works with any EF-supported database (PostgreSQL, MySQL, SQLite) +- **Cons**: Slower than Redis, requires periodic expiration scanning, adds DB load + +**Features**: + +- Thread-safe operations with mutex locks +- Automatic expiration scanning every 30 minutes +- Sliding and absolute expiration support +- Provider-specific duplicate key handling + +**TTL**: Via `ExpiresAtTime` and `AbsoluteExpiration` columns with background scanning + +#### Cosmos DB (Cloud only, persistent cache) + +```csharp +services.AddKeyedSingleton("persistent", (provider, _) => +{ + return new CosmosCache(new CosmosCacheOptions + { + DatabaseName = "cache", + ContainerName = "default", + ClientBuilder = cosmosClientBuilder + }); +}); +``` + +**Used for**: Cloud persistent keyed service only + +- **Pros**: Globally distributed, automatic TTL support via container-level TTL, optimized for long-lived data +- **Cons**: Cloud-only, higher latency than Redis + +**TTL**: Cosmos DB container-level TTL (automatic cleanup, no scanning required) + +### Comparison: Default vs Persistent + +| Characteristic | Default | Persistent cache (`"persistent"`) | +| ----------------------- | ------------------------------ | ---------------------------------------------- | +| **Primary Use Case** | Ephemeral tokens, session data | Long-lived grants, workflow state | +| **Typical TTL** | 5-15 minutes | Hours to weeks | +| **User Impact if Lost** | Low (user can retry) | High (forces re-auth, interrupts workflows) | +| **Scale Consideration** | Small datasets | Large/growing datasets (thousands to millions) | +| **Cloud Backend** | Redis | Cosmos DB → Redis | +| **Self-Hosted Backend** | Redis → SQL → EF | Redis → SQL → EF | +| **Automatic Cleanup** | Manual expiration | Automatic TTL (Cosmos) | +| **Data Structure** | Simple key-value | Supports structured data | +| **Example** | 2FA codes, TOTP tokens | Refresh tokens, payment intents | + +### Choosing Default vs Persistent + +**Use Default when**: + +- Data lifetime < 15 minutes +- Ephemeral authentication tokens +- Simple key-value pairs +- Cost optimization is important +- Data loss on restart is acceptable + +**Use Persistent when**: + +- **Data loss would have user impact** (e.g., losing refresh tokens forces re-authentication) +- Data lifetime > 15 minutes +- **Cache size is large or growing** (thousands of items that exceed memory constraints) +- Structured data with relationships +- Automatic TTL cleanup is required +- Data must survive restarts and deployments +- Query capabilities are needed (via Cosmos DB) + +### When NOT to Use + +- **New general-purpose caching** - Use `ExtendedCache` instead for stampede protection, fail-safe, and backplane support +- **Organization/Provider abilities** - Use `ExtendedCache` with preloading pattern (see example above) +- **Short-lived ephemeral data** without persistence requirements - Use `ExtendedCache` (simpler, more features) + +--- + +## `IApplicationCacheService` (Deprecated) + +> **⚠️ Deprecated**: This service is being phased out in favor of `ExtendedCache`. New code should use `ExtendedCache` with the preloading pattern shown in the [Organization/Provider Abilities example](#specific-example-organizationprovider-abilities) above. + +### Background + +`IApplicationCacheService` was a **highly domain-specific caching service** built for Bitwarden organization and provider abilities. It used in-memory cache with Azure Service Bus for cross-instance invalidation. + +**Why it's being replaced:** + +- **Infrastructure complexity**: Required both Redis and Azure Service Bus +- **Limited applicability**: Only worked for org/provider abilities +- **Maintenance burden**: Custom implementation instead of leveraging standard caching primitives +- **Better alternative exists**: `ExtendedCache` with Redis backplane provides the same functionality with simpler infrastructure + +### Migration Path + +**Old approach** (IApplicationCacheService): + +- In-memory cache with periodic refresh +- Azure Service Bus for cross-instance invalidation +- Custom implementation for each domain + +**New approach** (ExtendedCache): + +- Memory + Redis distributed cache with backplane +- Eager refresh for automatic background updates +- Fail-safe mode for resilience +- Standard FusionCache API +- One Redis instance instead of Redis + Service Bus + +See the [Organization/Provider Abilities example](#specific-example-organizationprovider-abilities) for the recommended migration pattern. + +### When NOT to Use + +❌ **Do not use for new code** - Use `ExtendedCache` instead + +For existing code using `IApplicationCacheService`, plan migration to `ExtendedCache` using the pattern shown above. + +--- + +## Specialized In-Memory Cache + +> **Recommendation**: In most cases, use `ExtendedCache` with `EnableDistributedCache = false` instead of implementing a specialized in-memory cache. ExtendedCache provides the same memory-only performance with built-in stampede protection, eager refresh, and fail-safe capabilities. + +### When to Use + +Use a specialized in-memory cache only when: + +- **ExtendedCache's API doesn't fit** your specific use case +- **Custom eviction logic** is required beyond TTL-based expiration +- **Non-standard data structures** (e.g., priority queues, LRU with custom scoring) +- **Direct memory access patterns** that bypass serialization entirely + +For general high-performance caching, prefer `ExtendedCache` with memory-only mode. + +### Pros + +✅ **Maximum performance**: No serialization, no network calls, no locking overhead + +✅ **Simple implementation**: Just a `Dictionary` or `ConcurrentDictionary` + +✅ **Zero infrastructure**: No Redis, no database, no additional dependencies + +### Cons + +❌ **No horizontal scaling**: Each instance has separate cache state + +❌ **Manual invalidation**: No built-in cache invalidation mechanism + +❌ **Manual TTL**: You implement expiration logic + +❌ **Memory pressure**: Large datasets can cause GC issues + +### Example Implementation + +#### Simple in-memory cache: + +```csharp +public class MyFeatureCache +{ + private readonly ConcurrentDictionary> _cache = new(); + private readonly TimeSpan _defaultExpiration = TimeSpan.FromMinutes(30); + + public MyData GetOrAdd(string key, Func factory) + { + var entry = _cache.GetOrAdd(key, _ => new CacheEntry + { + Value = factory(), + ExpiresAt = DateTime.UtcNow + _defaultExpiration + }); + + // WARNING: This implementation has a race condition. Multiple threads detecting + // expiration simultaneously may each call TryRemove and then recursively call + // GetOrAdd, potentially causing the factory to execute multiple times. For + // production use cases requiring thread-safe expiration, consider using + // IMemoryCache with GetOrCreateAsync or ExtendedCache with stampede protection. + if (entry.ExpiresAt < DateTime.UtcNow) + { + _cache.TryRemove(key, out _); + return GetOrAdd(key, factory); + } + + return entry.Value; + } + + private class CacheEntry + { + public T Value { get; set; } + public DateTime ExpiresAt { get; set; } + } +} +``` + +#### Using `IMemoryCache`: + +```csharp +public class MyService +{ + private readonly IMemoryCache _memoryCache; + + public MyService(IMemoryCache memoryCache) + { + _memoryCache = memoryCache; + } + + public async Task GetDataAsync(string key) + { + return await _memoryCache.GetOrCreateAsync(key, async entry => + { + entry.AbsoluteExpirationRelativeToNow = TimeSpan.FromMinutes(30); + entry.SetPriority(CacheItemPriority.High); + + return await _repository.GetDataAsync(key); + }); + } +} +``` + +### When NOT to Use + +- **Most general-purpose caching** - Use `ExtendedCache` with memory-only mode instead +- **Data requiring stampede protection** - Use `ExtendedCache` +- **Multi-instance deployments** requiring consistency - Use `ExtendedCache` with Redis +- **Long-lived OAuth grants** - Use persistent `IDistributedCache` + +> **Important**: Before implementing a custom in-memory cache, first try `ExtendedCache` with `EnableDistributedCache = false`. This gives you memory-only performance with automatic stampede protection, eager refresh, and fail-safe mode. + +--- + +## Backend Configuration + +### Configuration Priority + +The following table shows how different caching options resolve to storage backends based on configuration: + +| Cache Option | Cloud Backend | Self-Hosted Backend | Config Setting | +| -------------------------------------- | ------------------------- | --------------------------- | --------------------------------------------------------- | +| **ExtendedCache** | Redis → Memory | Redis → Memory | `GlobalSettings.DistributedCache.Redis.ConnectionString` | +| **IDistributedCache** (default) | Redis | Redis → SQL → EF | `GlobalSettings.DistributedCache.Redis.ConnectionString` | +| **IDistributedCache** (`"persistent"`) | Cosmos → Redis | Redis → SQL → EF | `GlobalSettings.DistributedCache.Cosmos.ConnectionString` | +| **OAuth Grants** (long-lived) | Persistent cache (Cosmos) | `IGrantRepository` (SQL/EF) | Various (see above) | + +### Redis Configuration + +**Cloud (Bitwarden-hosted)**: + +```json +{ + "GlobalSettings": { + "DistributedCache": { + "Redis": { + "ConnectionString": "redis.example.com:6379,ssl=true,password=..." + } + } + } +} +``` + +**Self-hosted** (`appsettings.json`): + +```json +{ + "globalSettings": { + "distributedCache": { + "redis": { + "connectionString": "localhost:6379" + } + } + } +} +``` + +### Cosmos DB Configuration + +**Persistent `IDistributedCache`** (cloud only): + +```json +{ + "GlobalSettings": { + "DistributedCache": { + "Cosmos": { + "ConnectionString": "AccountEndpoint=https://...;AccountKey=..." + } + } + } +} +``` + +- Database: `cache` +- Container: `default` +- Used for long-lived grants in cloud deployments + +### SQL Server Cache + +**Automatic configuration** (if SQL Server is database provider): + +```json +{ + "globalSettings": { + "sqlServer": { + "connectionString": "Server=...;Database=...;User Id=...;Password=..." + } + } +} +``` + +- Schema: `dbo` +- Table: `Cache` +- Migrations: Applied automatically + +### Entity Framework Cache + +**Automatic fallback** for PostgreSQL, MySQL, SQLite: + +No additional configuration required. Uses existing database connection. + +- Table: `Cache` +- Migrations: Applied automatically + +--- + +## Performance Considerations + +### Performance Characteristics + +| Backend | Read Latency | Write Latency | Throughput | +| -------------------- | ------------ | ------------- | ------------- | +| **Memory** | <1ms | <1ms | >100K req/s | +| **Redis** | 1-5ms | 1-5ms | 10K-50K req/s | +| **SQL Server** | 5-20ms | 10-50ms | 1K-5K req/s | +| **Entity Framework** | 5-20ms | 10-50ms | 1K-5K req/s | +| **Cosmos DB** | 5-15ms | 5-15ms | 10K+ req/s | + +**Note**: Latencies represent typical p95 values in production environments. Redis latencies assume same-datacenter deployment and include serialization overhead. Actual performance varies based on network topology, data size, and load. + +### Recommendations + +**For high-frequency reads (>1K req/s)**: + +1. `ExtendedCache` with Redis (cloud) +2. `ExtendedCache` memory-only (self-hosted, single instance) +3. Specialized in-memory cache (extreme performance requirements) + +**For moderate traffic (100-1K req/s)**: + +1. `ExtendedCache` with shared Redis +2. `IDistributedCache` with SQL Server cache + +**For low traffic (<100 req/s)**: + +1. `IDistributedCache` with SQL Server / EF cache +2. `ExtendedCache` memory-only + +--- + +## Testing Caches + +### Unit Testing + +**`ExtendedCache`**: + +```csharp +[Fact] +public async Task TestCacheHit() +{ + var services = new ServiceCollection(); + services.AddMemoryCache(); + services.AddExtendedCache("TestCache", new GlobalSettings + { + DistributedCache = new GlobalSettings.DistributedCacheSettings() + }); + + var provider = services.BuildServiceProvider(); + var cache = provider.GetRequiredKeyedService("TestCache"); + + await cache.SetAsync("key", "value"); + var result = await cache.GetOrDefaultAsync("key"); + + Assert.Equal("value", result); +} +``` + +**`IDistributedCache`**: + +```csharp +[Fact] +public async Task TestDistributedCache() +{ + var cache = new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())); + + await cache.SetStringAsync("key", "value"); + var result = await cache.GetStringAsync("key"); + + Assert.Equal("value", result); +} +``` + +### Integration Testing + +**Example**: + +```csharp +[DatabaseTheory, DatabaseData] +public async Task Cache_ExpirationScanning_RemovesExpiredItems(IDistributedCache cache) +{ + // Set item with 1-second expiration + await cache.SetAsync("key", Encoding.UTF8.GetBytes("value"), new DistributedCacheEntryOptions + { + AbsoluteExpirationRelativeToNow = TimeSpan.FromSeconds(1) + }); + + // Wait for expiration + await Task.Delay(TimeSpan.FromSeconds(2)); + + // Trigger expiration scan + var entityCache = cache as EntityFrameworkCache; + await entityCache.ScanForExpiredItemsAsync(); + + // Verify item is removed + var result = await cache.GetAsync("key"); + Assert.Null(result); +} +``` + +--- + +## Migration Examples + +Examples of migrating from one caching option to another: + +### From `IDistributedCache` → `ExtendedCache` + +**Before**: + +```csharp +// Registration +services.AddDistributedCache(globalSettings); + +// Constructor +public MyService(IDistributedCache cache, IRepository repository) +{ + _cache = cache; + _repository = repository; +} + +// Usage +public async Task GetDataAsync(string key) +{ + var data = await _cache.TryGetValue(key); + if (data == null) + { + data = await _repository.GetAsync(key); + await _cache.SetAsync(key, data, new DistributedCacheEntryOptions + { + AbsoluteExpirationRelativeToNow = TimeSpan.FromMinutes(30) + }); + } + return data; +} +``` + +**After**: + +```csharp +// Registration +services.AddDistributedCache(globalSettings); +services.AddExtendedCache("MyFeature", globalSettings); + +// Constructor +public MyService( + [FromKeyedServices("MyFeature")] IFusionCache cache, + IRepository repository) +{ + _cache = cache; + _repository = repository; +} + +// Usage +public async Task GetDataAsync(string key) +{ + return await _cache.GetOrSetAsync( + key, + async _ => await _repository.GetAsync(key), + options => options.SetDuration(TimeSpan.FromMinutes(30)) + ); +} +``` + +### From In-Memory → `ExtendedCache` + +**Before**: + +```csharp +// Field +private readonly ConcurrentDictionary _cache = new(); +private readonly IRepository _repository; + +// Constructor +public MyService(IRepository repository) +{ + _repository = repository; +} + +// Usage +public async Task GetDataAsync(string key) +{ + if (_cache.TryGetValue(key, out var cached)) + { + return cached; + } + + var data = await _repository.GetAsync(key); + _cache.TryAdd(key, data); + return data; +} +``` + +**After**: + +```csharp +// Registration +services.AddExtendedCache("MyFeature", globalSettings); + +// Constructor +public MyService( + [FromKeyedServices("MyFeature")] IFusionCache cache, + IRepository repository) +{ + _cache = cache; + _repository = repository; +} + +// Usage +public async Task GetDataAsync(string key) +{ + return await _cache.GetOrSetAsync( + key, + async _ => await _repository.GetAsync(key) + ); +} +``` diff --git a/src/Core/Utilities/EmailValidation.cs b/src/Core/Utilities/EmailValidation.cs index f6832945af..10892f85c4 100644 --- a/src/Core/Utilities/EmailValidation.cs +++ b/src/Core/Utilities/EmailValidation.cs @@ -1,4 +1,6 @@ -using System.Text.RegularExpressions; +using System.Net.Mail; +using System.Text.RegularExpressions; +using Bit.Core.Exceptions; using MimeKit; namespace Bit.Core.Utilities; @@ -41,4 +43,22 @@ public static class EmailValidation return true; } + + /// + /// Extracts the domain portion from an email address and normalizes it to lowercase. + /// + /// The email address to extract the domain from. + /// The domain portion of the email address in lowercase (e.g., "example.com"). + /// Thrown when the email address format is invalid. + public static string GetDomain(string email) + { + try + { + return new MailAddress(email).Host.ToLower(); + } + catch (Exception ex) when (ex is FormatException || ex is ArgumentException) + { + throw new BadRequestException("Invalid email address format."); + } + } } diff --git a/src/Core/Utilities/EventIntegrationsCacheConstants.cs b/src/Core/Utilities/EventIntegrationsCacheConstants.cs new file mode 100644 index 0000000000..000a9c230e --- /dev/null +++ b/src/Core/Utilities/EventIntegrationsCacheConstants.cs @@ -0,0 +1,85 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Enums; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; + +namespace Bit.Core.Utilities; + +/// +/// Provides cache key generation helpers and cache name constants for event integration–related entities. +/// +public static class EventIntegrationsCacheConstants +{ + /// + /// The base cache name used for storing event integration data. + /// + public const string CacheName = "EventIntegrations"; + + /// + /// Duration TimeSpan for adding OrganizationIntegrationConfigurationDetails to the cache. + /// + public static readonly TimeSpan DurationForOrganizationIntegrationConfigurationDetails = TimeSpan.FromDays(1); + + /// + /// Builds a deterministic cache key for a . + /// + /// The unique identifier of the group. + /// + /// A cache key for this Group. + /// + public static string BuildCacheKeyForGroup(Guid groupId) => + $"Group:{groupId:N}"; + + /// + /// Builds a deterministic cache key for an . + /// + /// The unique identifier of the organization. + /// + /// A cache key for the Organization. + /// + public static string BuildCacheKeyForOrganization(Guid organizationId) => + $"Organization:{organizationId:N}"; + + /// + /// Builds a deterministic cache key for an organization user . + /// + /// The unique identifier of the organization to which the user belongs. + /// The unique identifier of the user. + /// + /// A cache key for the user. + /// + public static string BuildCacheKeyForOrganizationUser(Guid organizationId, Guid userId) => + $"OrganizationUserUserDetails:{organizationId:N}:{userId:N}"; + + /// + /// Builds a deterministic cache key for an organization's integration configuration details + /// . + /// + /// The unique identifier of the organization. + /// The of the integration. + /// The specific of the event configured. + /// + /// A cache key for the configuration details. + /// + public static string BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + Guid organizationId, + IntegrationType integrationType, + EventType eventType + ) => $"OrganizationIntegrationConfigurationDetails:{organizationId:N}:{integrationType}:{eventType}"; + + /// + /// Builds a deterministic tag for tagging an organization's integration configuration details. This tag is then + /// used to tag all of the that result from this + /// integration, which allows us to remove all relevant entries when an integration is changed or removed. + /// + /// The unique identifier of the organization to which the user belongs. + /// The of the integration. + /// + /// A cache tag to use for the configuration details. + /// + public static string BuildCacheTagForOrganizationIntegration( + Guid organizationId, + IntegrationType integrationType + ) => $"OrganizationIntegration:{organizationId:N}:{integrationType}"; +} diff --git a/src/Core/Utilities/ExtendedCacheServiceCollectionExtensions.cs b/src/Core/Utilities/ExtendedCacheServiceCollectionExtensions.cs new file mode 100644 index 0000000000..f287f64e54 --- /dev/null +++ b/src/Core/Utilities/ExtendedCacheServiceCollectionExtensions.cs @@ -0,0 +1,186 @@ +using Bit.Core.Settings; +using Bit.Core.Utilities; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.StackExchangeRedis; +using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Logging; +using StackExchange.Redis; +using ZiggyCreatures.Caching.Fusion; +using ZiggyCreatures.Caching.Fusion.Backplane; +using ZiggyCreatures.Caching.Fusion.Backplane.StackExchangeRedis; +using ZiggyCreatures.Caching.Fusion.Serialization.SystemTextJson; + +namespace Microsoft.Extensions.DependencyInjection; + +public static class ExtendedCacheServiceCollectionExtensions +{ + /// + /// Adds a new, named Fusion Cache to the service + /// collection. If an existing cache of the same name is found, it will do nothing.
    + ///
    + /// Note: When re-using an existing distributed cache, it is expected to call this method after calling + /// services.AddDistributedCache(globalSettings)
    This ensures that DI correctly finds + /// and re-uses the shared distributed cache infrastructure.
    + ///
    + /// Backplane: Cross-instance cache invalidation is only available when using Redis. + /// Non-Redis distributed caches operate with eventual consistency across multiple instances. + ///
    + public static IServiceCollection AddExtendedCache( + this IServiceCollection services, + string cacheName, + GlobalSettings globalSettings, + GlobalSettings.ExtendedCacheSettings? settings = null) + { + settings ??= globalSettings.DistributedCache.DefaultExtendedCache; + if (settings is null || string.IsNullOrEmpty(cacheName)) + { + return services; + } + + // If a cache already exists with this key, do nothing + if (services.Any(s => s.ServiceType == typeof(IFusionCache) && + s.ServiceKey?.Equals(cacheName) == true)) + { + return services; + } + + if (services.All(s => s.ServiceType != typeof(FusionCacheSystemTextJsonSerializer))) + { + services.AddFusionCacheSystemTextJsonSerializer(); + } + var fusionCacheBuilder = services + .AddFusionCache(cacheName) + .WithCacheKeyPrefix($"{cacheName}:") + .AsKeyedServiceByCacheName() + .WithOptions(opt => + { + opt.DistributedCacheCircuitBreakerDuration = settings.DistributedCacheCircuitBreakerDuration; + }) + .WithDefaultEntryOptions(new FusionCacheEntryOptions + { + Duration = settings.Duration, + IsFailSafeEnabled = settings.IsFailSafeEnabled, + FailSafeMaxDuration = settings.FailSafeMaxDuration, + FailSafeThrottleDuration = settings.FailSafeThrottleDuration, + EagerRefreshThreshold = settings.EagerRefreshThreshold, + FactorySoftTimeout = settings.FactorySoftTimeout, + FactoryHardTimeout = settings.FactoryHardTimeout, + DistributedCacheSoftTimeout = settings.DistributedCacheSoftTimeout, + DistributedCacheHardTimeout = settings.DistributedCacheHardTimeout, + AllowBackgroundDistributedCacheOperations = settings.AllowBackgroundDistributedCacheOperations, + JitterMaxDuration = settings.JitterMaxDuration + }) + .WithRegisteredSerializer(); + + if (!settings.EnableDistributedCache) + return services; + + if (settings.UseSharedDistributedCache) + { + if (!CoreHelpers.SettingHasValue(globalSettings.DistributedCache.Redis.ConnectionString)) + { + // Using Shared Non-Redis Distributed Cache: + // 1. Assume IDistributedCache is already registered (e.g., Cosmos, SQL Server) + // 2. Backplane not supported (Redis-only feature, requires pub/sub) + + fusionCacheBuilder + .TryWithRegisteredDistributedCache(); + + return services; + } + + // Using Shared Redis, TryAdd and reuse all pieces (multiplexer, distributed cache and backplane) + + services.TryAddSingleton(sp => + CreateConnectionMultiplexer(sp, cacheName, globalSettings.DistributedCache.Redis.ConnectionString)); + + services.TryAddSingleton(sp => + { + var mux = sp.GetRequiredService(); + return new RedisCache(new RedisCacheOptions + { + ConnectionMultiplexerFactory = () => Task.FromResult(mux) + }); + }); + + services.TryAddSingleton(sp => + { + var mux = sp.GetRequiredService(); + return new RedisBackplane(new RedisBackplaneOptions + { + ConnectionMultiplexerFactory = () => Task.FromResult(mux) + }); + }); + + fusionCacheBuilder + .WithRegisteredDistributedCache() + .WithRegisteredBackplane(); + + return services; + } + + // Using keyed Distributed Cache. Create/Reuse all pieces as keyed services. + + if (!CoreHelpers.SettingHasValue(settings.Redis.ConnectionString)) + { + // Using Keyed Non-Redis Distributed Cache: + // 1. Assume IDistributedCache (e.g., Cosmos, SQL Server) is already registered with cacheName as key + // 2. Backplane not supported (Redis-only feature, requires pub/sub) + + fusionCacheBuilder + .TryWithRegisteredKeyedDistributedCache(serviceKey: cacheName); + + return services; + } + + // Using Keyed Redis: TryAdd and reuse all pieces (multiplexer, distributed cache and backplane) + + services.TryAddKeyedSingleton( + cacheName, + (sp, _) => CreateConnectionMultiplexer(sp, cacheName, settings.Redis.ConnectionString) + ); + services.TryAddKeyedSingleton( + cacheName, + (sp, _) => + { + var mux = sp.GetRequiredKeyedService(cacheName); + return new RedisCache(new RedisCacheOptions + { + ConnectionMultiplexerFactory = () => Task.FromResult(mux) + }); + } + ); + services.TryAddKeyedSingleton( + cacheName, + (sp, _) => + { + var mux = sp.GetRequiredKeyedService(cacheName); + return new RedisBackplane(new RedisBackplaneOptions + { + ConnectionMultiplexerFactory = () => Task.FromResult(mux) + }); + } + ); + + fusionCacheBuilder + .WithRegisteredKeyedDistributedCacheByCacheName() + .WithRegisteredKeyedBackplaneByCacheName(); + + return services; + } + + private static ConnectionMultiplexer CreateConnectionMultiplexer(IServiceProvider sp, string cacheName, + string connectionString) + { + try + { + return ConnectionMultiplexer.Connect(connectionString); + } + catch (Exception ex) + { + var logger = sp.GetService(); + logger?.LogError(ex, "Failed to connect to Redis for cache {CacheName}", cacheName); + throw; + } + } +} diff --git a/src/Core/Utilities/LoggerFactoryExtensions.cs b/src/Core/Utilities/LoggerFactoryExtensions.cs index 54bd84df6f..b950e30d5d 100644 --- a/src/Core/Utilities/LoggerFactoryExtensions.cs +++ b/src/Core/Utilities/LoggerFactoryExtensions.cs @@ -1,165 +1,78 @@ -using System.Security.Cryptography.X509Certificates; -using Bit.Core.Settings; -using Microsoft.AspNetCore.Builder; -using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; -using Serilog; -using Serilog.Events; -using Serilog.Sinks.Syslog; namespace Bit.Core.Utilities; public static class LoggerFactoryExtensions { - public static void UseSerilog( - this IApplicationBuilder appBuilder, - IWebHostEnvironment env, - IHostApplicationLifetime applicationLifetime, - GlobalSettings globalSettings) + /// + /// + /// + /// + /// + public static IHostBuilder AddSerilogFileLogging(this IHostBuilder hostBuilder) { - if (env.IsDevelopment() && !globalSettings.EnableDevLogging) + return hostBuilder.ConfigureLogging((context, logging) => { - return; - } - - applicationLifetime.ApplicationStopped.Register(Log.CloseAndFlush); - } - - public static ILoggingBuilder AddSerilog( - this ILoggingBuilder builder, - WebHostBuilderContext context, - Func? filter = null) - { - var globalSettings = new GlobalSettings(); - ConfigurationBinder.Bind(context.Configuration.GetSection("GlobalSettings"), globalSettings); - - if (context.HostingEnvironment.IsDevelopment() && !globalSettings.EnableDevLogging) - { - return builder; - } - - bool inclusionPredicate(LogEvent e) - { - if (filter == null) + if (context.HostingEnvironment.IsDevelopment()) { - return true; + return; } - var eventId = e.Properties.TryGetValue("EventId", out var eventIdValue) ? eventIdValue.ToString() : null; - if (eventId?.Contains(Constants.BypassFiltersEventId.ToString()) ?? false) + + // If they have begun using the new settings location, use that + if (!string.IsNullOrEmpty(context.Configuration["Logging:PathFormat"])) { - return true; - } - return filter(e, globalSettings); - } - - var logSentryWarning = false; - var logSyslogWarning = false; - - // Path format is the only required option for file logging, we will use that as - // the keystone for if they have configured the new location. - var newPathFormat = context.Configuration["Logging:PathFormat"]; - - var config = new LoggerConfiguration() - .MinimumLevel.Verbose() - .Enrich.FromLogContext() - .Filter.ByIncludingOnly(inclusionPredicate); - - if (CoreHelpers.SettingHasValue(globalSettings.Sentry.Dsn)) - { - config.WriteTo.Sentry(globalSettings.Sentry.Dsn) - .Enrich.FromLogContext() - .Enrich.WithProperty("Project", globalSettings.ProjectName); - } - else if (CoreHelpers.SettingHasValue(globalSettings.Syslog.Destination)) - { - logSyslogWarning = true; - // appending sitename to project name to allow easier identification in syslog. - var appName = $"{globalSettings.SiteName}-{globalSettings.ProjectName}"; - if (globalSettings.Syslog.Destination.Equals("local", StringComparison.OrdinalIgnoreCase)) - { - config.WriteTo.LocalSyslog(appName); - } - else if (Uri.TryCreate(globalSettings.Syslog.Destination, UriKind.Absolute, out var syslogAddress)) - { - // Syslog's standard port is 514 (both UDP and TCP). TLS does not have a standard port, so assume 514. - int port = syslogAddress.Port >= 0 - ? syslogAddress.Port - : 514; - - if (syslogAddress.Scheme.Equals("udp")) - { - config.WriteTo.UdpSyslog(syslogAddress.Host, port, appName); - } - else if (syslogAddress.Scheme.Equals("tcp")) - { - config.WriteTo.TcpSyslog(syslogAddress.Host, port, appName); - } - else if (syslogAddress.Scheme.Equals("tls")) - { - if (CoreHelpers.SettingHasValue(globalSettings.Syslog.CertificateThumbprint)) - { - config.WriteTo.TcpSyslog(syslogAddress.Host, port, appName, - useTls: true, - certProvider: new CertificateStoreProvider(StoreName.My, StoreLocation.CurrentUser, - globalSettings.Syslog.CertificateThumbprint)); - } - else - { - config.WriteTo.TcpSyslog(syslogAddress.Host, port, appName, - useTls: true, - certProvider: new CertificateFileProvider(globalSettings.Syslog.CertificatePath, - globalSettings.Syslog?.CertificatePassword ?? string.Empty)); - } - } - } - } - else if (!string.IsNullOrEmpty(newPathFormat)) - { - // Use new location - builder.AddFile(context.Configuration.GetSection("Logging")); - } - else if (CoreHelpers.SettingHasValue(globalSettings.LogDirectory)) - { - if (globalSettings.LogRollBySizeLimit.HasValue) - { - var pathFormat = Path.Combine(globalSettings.LogDirectory, $"{globalSettings.ProjectName.ToLowerInvariant()}.log"); - if (globalSettings.LogDirectoryByProject) - { - pathFormat = Path.Combine(globalSettings.LogDirectory, globalSettings.ProjectName, "log.txt"); - } - config.WriteTo.File(pathFormat, rollOnFileSizeLimit: true, - fileSizeLimitBytes: globalSettings.LogRollBySizeLimit); + logging.AddFile(context.Configuration.GetSection("Logging")); } else { - var pathFormat = Path.Combine(globalSettings.LogDirectory, $"{globalSettings.ProjectName.ToLowerInvariant()}_{{Date}}.log"); - if (globalSettings.LogDirectoryByProject) + var globalSettingsSection = context.Configuration.GetSection("GlobalSettings"); + var loggingOptions = new LegacyFileLoggingOptions(); + globalSettingsSection.Bind(loggingOptions); + + if (string.IsNullOrWhiteSpace(loggingOptions.LogDirectory)) { - pathFormat = Path.Combine(globalSettings.LogDirectory, globalSettings.ProjectName, "{Date}.txt"); + return; + } + + var projectName = loggingOptions.ProjectName + ?? context.HostingEnvironment.ApplicationName; + + if (loggingOptions.LogRollBySizeLimit.HasValue) + { + var pathFormat = loggingOptions.LogDirectoryByProject + ? Path.Combine(loggingOptions.LogDirectory, projectName, "log.txt") + : Path.Combine(loggingOptions.LogDirectory, $"{projectName.ToLowerInvariant()}.log"); + + logging.AddFile( + pathFormat: pathFormat, + fileSizeLimitBytes: loggingOptions.LogRollBySizeLimit.Value + ); + } + else + { + var pathFormat = loggingOptions.LogDirectoryByProject + ? Path.Combine(loggingOptions.LogDirectory, projectName, "{Date}.txt") + : Path.Combine(loggingOptions.LogDirectory, $"{projectName.ToLowerInvariant()}_{{Date}}.log"); + + logging.AddFile( + pathFormat: pathFormat + ); } - config.WriteTo.RollingFile(pathFormat); } - config - .Enrich.FromLogContext() - .Enrich.WithProperty("Project", globalSettings.ProjectName); - } + }); + } - var serilog = config.CreateLogger(); - - if (logSentryWarning) - { - serilog.Warning("Sentry for logging has been deprecated. Read more: https://btwrdn.com/log-deprecation"); - } - - if (logSyslogWarning) - { - serilog.Warning("Syslog for logging has been deprecated. Read more: https://btwrdn.com/log-deprecation"); - } - - builder.AddSerilog(serilog); - - return builder; + /// + /// Our own proprietary options that we've always supported in `GlobalSettings` configuration section. + /// + private class LegacyFileLoggingOptions + { + public string? ProjectName { get; set; } + public string? LogDirectory { get; set; } = "/etc/bitwarden/logs"; + public bool LogDirectoryByProject { get; set; } = true; + public long? LogRollBySizeLimit { get; set; } } } diff --git a/src/Core/Utilities/LoggingExceptionHandlerFilterAttribute.cs b/src/Core/Utilities/LoggingExceptionHandlerFilterAttribute.cs index 6709bbb271..300c30641e 100644 --- a/src/Core/Utilities/LoggingExceptionHandlerFilterAttribute.cs +++ b/src/Core/Utilities/LoggingExceptionHandlerFilterAttribute.cs @@ -17,6 +17,6 @@ public class LoggingExceptionHandlerFilterAttribute : ExceptionFilterAttribute var logger = context.HttpContext.RequestServices .GetRequiredService>(); - logger.LogError(0, exception, exception.Message); + logger.LogError(0, exception, "Unhandled exception"); } } diff --git a/src/Core/Utilities/RequireLowerEnvironmentAttribute.cs b/src/Core/Utilities/RequireLowerEnvironmentAttribute.cs new file mode 100644 index 0000000000..a8208844a8 --- /dev/null +++ b/src/Core/Utilities/RequireLowerEnvironmentAttribute.cs @@ -0,0 +1,24 @@ +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.Filters; +using Microsoft.Extensions.Hosting; + +namespace Bit.Core.Utilities; + +/// +/// Authorization attribute that restricts controller/action access to Development and QA environments only. +/// Returns 404 Not Found in all other environments. +/// +public class RequireLowerEnvironmentAttribute() : TypeFilterAttribute(typeof(LowerEnvironmentFilter)) +{ + private class LowerEnvironmentFilter(IWebHostEnvironment environment) : IAuthorizationFilter + { + public void OnAuthorization(AuthorizationFilterContext context) + { + if (!environment.IsDevelopment() && !environment.IsEnvironment("QA")) + { + context.Result = new NotFoundResult(); + } + } + } +} diff --git a/src/Core/Utilities/StaticStore.cs b/src/Core/Utilities/StaticStore.cs index 1ddd926569..f0fbd80c38 100644 --- a/src/Core/Utilities/StaticStore.cs +++ b/src/Core/Utilities/StaticStore.cs @@ -1,13 +1,7 @@ // FIXME: Update this file to be null safe and then delete the line below #nullable disable -using System.Collections.Immutable; -using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Extensions; -using Bit.Core.Billing.Models.StaticStore.Plans; using Bit.Core.Enums; -using Bit.Core.Models.Data.Organizations.OrganizationUsers; -using Bit.Core.Models.StaticStore; namespace Bit.Core.Utilities; @@ -110,55 +104,7 @@ public static class StaticStore GlobalDomains.Add(GlobalEquivalentDomainsType.Atlassian, new List { "atlassian.com", "bitbucket.org", "trello.com", "statuspage.io", "atlassian.net", "jira.com" }); GlobalDomains.Add(GlobalEquivalentDomainsType.Pinterest, new List { "pinterest.com", "pinterest.com.au", "pinterest.cl", "pinterest.de", "pinterest.dk", "pinterest.es", "pinterest.fr", "pinterest.co.uk", "pinterest.jp", "pinterest.co.kr", "pinterest.nz", "pinterest.pt", "pinterest.se" }); #endregion - - Plans = new List - { - new EnterprisePlan(true), - new EnterprisePlan(false), - new TeamsStarterPlan(), - new TeamsPlan(true), - new TeamsPlan(false), - - new Enterprise2023Plan(true), - new Enterprise2023Plan(false), - new Enterprise2020Plan(true), - new Enterprise2020Plan(false), - new TeamsStarterPlan2023(), - new Teams2023Plan(true), - new Teams2023Plan(false), - new Teams2020Plan(true), - new Teams2020Plan(false), - new FamiliesPlan(), - new FreePlan(), - new CustomPlan(), - - new Enterprise2019Plan(true), - new Enterprise2019Plan(false), - new Teams2019Plan(true), - new Teams2019Plan(false), - new Families2019Plan(), - }.ToImmutableList(); } public static IDictionary> GlobalDomains { get; set; } - [Obsolete("Use PricingClient.ListPlans to retrieve all plans.")] - public static IEnumerable Plans { get; } - public static IEnumerable SponsoredPlans { get; set; } = new[] - { - new SponsoredPlan - { - PlanSponsorshipType = PlanSponsorshipType.FamiliesForEnterprise, - SponsoredProductTierType = ProductTierType.Families, - SponsoringProductTierType = ProductTierType.Enterprise, - StripePlanId = "2021-family-for-enterprise-annually", - UsersCanSponsor = (OrganizationUserOrganizationDetails org) => - org.PlanType.GetProductTier() == ProductTierType.Enterprise, - } - }; - - [Obsolete("Use PricingClient.GetPlan to retrieve a plan.")] - public static Plan GetPlan(PlanType planType) => Plans.SingleOrDefault(p => p.Type == planType); - - public static SponsoredPlan GetSponsoredPlan(PlanSponsorshipType planSponsorshipType) => - SponsoredPlans.FirstOrDefault(p => p.PlanSponsorshipType == planSponsorshipType); } diff --git a/src/Core/Vault/Authorization/Permissions/NormalCipherPermissions.cs b/src/Core/Vault/Authorization/Permissions/NormalCipherPermissions.cs index fbd553d772..bb3bafb230 100644 --- a/src/Core/Vault/Authorization/Permissions/NormalCipherPermissions.cs +++ b/src/Core/Vault/Authorization/Permissions/NormalCipherPermissions.cs @@ -14,7 +14,7 @@ public class NormalCipherPermissions throw new Exception("Cipher needs to belong to a user or an organization."); } - if (user.Id == cipherDetails.UserId) + if (cipherDetails.OrganizationId == null && user.Id == cipherDetails.UserId) { return true; } diff --git a/src/Core/Vault/Repositories/ISecurityTaskRepository.cs b/src/Core/Vault/Repositories/ISecurityTaskRepository.cs index 4b88f1c0e8..0be3bbd545 100644 --- a/src/Core/Vault/Repositories/ISecurityTaskRepository.cs +++ b/src/Core/Vault/Repositories/ISecurityTaskRepository.cs @@ -35,4 +35,10 @@ public interface ISecurityTaskRepository : IRepository /// The id of the organization /// A collection of security task metrics Task GetTaskMetricsAsync(Guid organizationId); + + /// + /// Marks all tasks associated with the respective ciphers as complete. + /// + /// Collection of cipher IDs + Task MarkAsCompleteByCipherIds(IEnumerable cipherIds); } diff --git a/src/Core/Vault/Services/ICipherService.cs b/src/Core/Vault/Services/ICipherService.cs index ffd79e9381..765dae30c1 100644 --- a/src/Core/Vault/Services/ICipherService.cs +++ b/src/Core/Vault/Services/ICipherService.cs @@ -13,9 +13,9 @@ public interface ICipherService Task SaveDetailsAsync(CipherDetails cipher, Guid savingUserId, DateTime? lastKnownRevisionDate, IEnumerable collectionIds = null, bool skipPermissionCheck = false); Task<(string attachmentId, string uploadUrl)> CreateAttachmentForDelayedUploadAsync(Cipher cipher, - string key, string fileName, long fileSize, bool adminRequest, Guid savingUserId); + string key, string fileName, long fileSize, bool adminRequest, Guid savingUserId, DateTime? lastKnownRevisionDate = null); Task CreateAttachmentAsync(Cipher cipher, Stream stream, string fileName, string key, - long requestLength, Guid savingUserId, bool orgAdmin = false); + long requestLength, Guid savingUserId, bool orgAdmin = false, DateTime? lastKnownRevisionDate = null); Task CreateAttachmentShareAsync(Cipher cipher, Stream stream, string fileName, string key, long requestLength, string attachmentId, Guid organizationShareId); Task DeleteAsync(CipherDetails cipherDetails, Guid deletingUserId, bool orgAdmin = false); diff --git a/src/Core/Vault/Services/Implementations/CipherService.cs b/src/Core/Vault/Services/Implementations/CipherService.cs index f132588e37..fa2cfbb209 100644 --- a/src/Core/Vault/Services/Implementations/CipherService.cs +++ b/src/Core/Vault/Services/Implementations/CipherService.cs @@ -2,6 +2,7 @@ #nullable disable using System.Text.Json; +using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; @@ -33,6 +34,7 @@ public class CipherService : ICipherService private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationUserRepository _organizationUserRepository; private readonly ICollectionCipherRepository _collectionCipherRepository; + private readonly ISecurityTaskRepository _securityTaskRepository; private readonly IPushNotificationService _pushService; private readonly IAttachmentStorageService _attachmentStorageService; private readonly IEventService _eventService; @@ -53,6 +55,7 @@ public class CipherService : ICipherService IOrganizationRepository organizationRepository, IOrganizationUserRepository organizationUserRepository, ICollectionCipherRepository collectionCipherRepository, + ISecurityTaskRepository securityTaskRepository, IPushNotificationService pushService, IAttachmentStorageService attachmentStorageService, IEventService eventService, @@ -71,6 +74,7 @@ public class CipherService : ICipherService _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; _collectionCipherRepository = collectionCipherRepository; + _securityTaskRepository = securityTaskRepository; _pushService = pushService; _attachmentStorageService = attachmentStorageService; _eventService = eventService; @@ -113,7 +117,7 @@ public class CipherService : ICipherService } else { - ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate); + ValidateCipherLastKnownRevisionDate(cipher, lastKnownRevisionDate); cipher.RevisionDate = DateTime.UtcNow; await _cipherRepository.ReplaceAsync(cipher); await _eventService.LogCipherEventAsync(cipher, Bit.Core.Enums.EventType.Cipher_Updated); @@ -168,7 +172,7 @@ public class CipherService : ICipherService } else { - ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate); + ValidateCipherLastKnownRevisionDate(cipher, lastKnownRevisionDate); cipher.RevisionDate = DateTime.UtcNow; await ValidateChangeInCollectionsAsync(cipher, collectionIds, savingUserId); await ValidateViewPasswordUserAsync(cipher); @@ -196,8 +200,9 @@ public class CipherService : ICipherService } public async Task<(string attachmentId, string uploadUrl)> CreateAttachmentForDelayedUploadAsync(Cipher cipher, - string key, string fileName, long fileSize, bool adminRequest, Guid savingUserId) + string key, string fileName, long fileSize, bool adminRequest, Guid savingUserId, DateTime? lastKnownRevisionDate = null) { + ValidateCipherLastKnownRevisionDate(cipher, lastKnownRevisionDate); await ValidateCipherEditForAttachmentAsync(cipher, savingUserId, adminRequest, fileSize); var attachmentId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false); @@ -232,8 +237,9 @@ public class CipherService : ICipherService } public async Task CreateAttachmentAsync(Cipher cipher, Stream stream, string fileName, string key, - long requestLength, Guid savingUserId, bool orgAdmin = false) + long requestLength, Guid savingUserId, bool orgAdmin = false, DateTime? lastKnownRevisionDate = null) { + ValidateCipherLastKnownRevisionDate(cipher, lastKnownRevisionDate); await ValidateCipherEditForAttachmentAsync(cipher, savingUserId, orgAdmin, requestLength); var attachmentId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false); @@ -713,13 +719,7 @@ public class CipherService : ICipherService cipherDetails.DeletedDate = cipherDetails.RevisionDate = DateTime.UtcNow; - if (cipherDetails.ArchivedDate.HasValue) - { - // If the cipher was archived, clear the archived date when soft deleting - // If a user were to restore an archived cipher, it should go back to the vault not the archive vault - cipherDetails.ArchivedDate = null; - } - + await _securityTaskRepository.MarkAsCompleteByCipherIds([cipherDetails.Id]); await _cipherRepository.UpsertAsync(cipherDetails); await _eventService.LogCipherEventAsync(cipherDetails, EventType.Cipher_SoftDeleted); @@ -746,6 +746,8 @@ public class CipherService : ICipherService await _cipherRepository.SoftDeleteAsync(deletingCiphers.Select(c => c.Id), deletingUserId); } + await _securityTaskRepository.MarkAsCompleteByCipherIds(deletingCiphers.Select(c => c.Id)); + var events = deletingCiphers.Select(c => new Tuple(c, EventType.Cipher_SoftDeleted, null)); foreach (var eventsBatch in events.Chunk(100)) @@ -859,7 +861,7 @@ public class CipherService : ICipherService return NormalCipherPermissions.CanRestore(user, cipher, organizationAbility); } - private void ValidateCipherLastKnownRevisionDateAsync(Cipher cipher, DateTime? lastKnownRevisionDate) + private void ValidateCipherLastKnownRevisionDate(Cipher cipher, DateTime? lastKnownRevisionDate) { if (cipher.Id == default || !lastKnownRevisionDate.HasValue) { @@ -982,11 +984,6 @@ public class CipherService : ICipherService throw new BadRequestException("One or more ciphers do not belong to you."); } - if (cipher.ArchivedDate.HasValue) - { - throw new BadRequestException("Cipher cannot be shared with organization because it is archived."); - } - var attachments = cipher.GetAttachments(); var hasAttachments = attachments?.Any() ?? false; var org = await _organizationRepository.GetByIdAsync(organizationId); @@ -996,18 +993,41 @@ public class CipherService : ICipherService throw new BadRequestException("Could not find organization."); } - if (hasAttachments && !org.MaxStorageGb.HasValue) + if (!await IgnoreStorageLimitsOnMigrationAsync(sharingUserId, org)) { - throw new BadRequestException("This organization cannot use attachments."); + if (hasAttachments && !org.MaxStorageGb.HasValue) + { + throw new BadRequestException("This organization cannot use attachments."); + } + + var storageAdjustment = attachments?.Sum(a => a.Value.Size) ?? 0; + if (org.StorageBytesRemaining() < storageAdjustment) + { + throw new BadRequestException("Not enough storage available for this organization."); + } } - var storageAdjustment = attachments?.Sum(a => a.Value.Size) ?? 0; - if (org.StorageBytesRemaining() < storageAdjustment) + ValidateCipherLastKnownRevisionDate(cipher, lastKnownRevisionDate); + } + + /// + /// Checks if the storage limit for the org should be ignored due to the Organization Data Ownership Policy + /// + private async Task IgnoreStorageLimitsOnMigrationAsync(Guid userId, Organization organization) + { + if (!_featureService.IsEnabled(FeatureFlagKeys.MigrateMyVaultToMyItems)) { - throw new BadRequestException("Not enough storage available for this organization."); + return false; } - ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate); + if (!organization.UsePolicies) + { + return false; + } + + var requirement = await _policyRequirementQuery.GetAsync(userId); + + return requirement.IgnoreStorageLimitsOnMigration(organization.Id); } private async Task ValidateViewPasswordUserAsync(Cipher cipher) @@ -1026,11 +1046,8 @@ public class CipherService : ICipherService var existingCipherData = DeserializeCipherData(existingCipher); var newCipherData = DeserializeCipherData(cipher); - // "hidden password" users may not add cipher key encryption - if (existingCipher.Key == null && cipher.Key != null) - { - throw new BadRequestException("You do not have permission to add cipher key encryption."); - } + // For hidden-password users, never allow Key to change at all. + cipher.Key = existingCipher.Key; // Keep only non-hidden fileds from the new cipher var nonHiddenFields = newCipherData.Fields?.Where(f => f.Type != FieldType.Hidden) ?? []; // Get hidden fields from the existing cipher diff --git a/src/Events/Controllers/CollectController.cs b/src/Events/Controllers/CollectController.cs index d7fbbbc595..3902522665 100644 --- a/src/Events/Controllers/CollectController.cs +++ b/src/Events/Controllers/CollectController.cs @@ -21,23 +21,21 @@ public class CollectController : Controller private readonly IEventService _eventService; private readonly ICipherRepository _cipherRepository; private readonly IOrganizationRepository _organizationRepository; - private readonly IFeatureService _featureService; - private readonly IApplicationCacheService _applicationCacheService; + private readonly IOrganizationUserRepository _organizationUserRepository; public CollectController( ICurrentContext currentContext, IEventService eventService, ICipherRepository cipherRepository, IOrganizationRepository organizationRepository, - IFeatureService featureService, - IApplicationCacheService applicationCacheService) + IOrganizationUserRepository organizationUserRepository + ) { _currentContext = currentContext; _eventService = eventService; _cipherRepository = cipherRepository; _organizationRepository = organizationRepository; - _featureService = featureService; - _applicationCacheService = applicationCacheService; + _organizationUserRepository = organizationUserRepository; } [HttpPost] @@ -47,8 +45,10 @@ public class CollectController : Controller { return new BadRequestResult(); } + var cipherEvents = new List>(); var ciphersCache = new Dictionary(); + foreach (var eventModel in model) { switch (eventModel.Type) @@ -57,6 +57,25 @@ public class CollectController : Controller case EventType.User_ClientExportedVault: await _eventService.LogUserEventAsync(_currentContext.UserId.Value, eventModel.Type, eventModel.Date); break; + + case EventType.Organization_ItemOrganization_Accepted: + case EventType.Organization_ItemOrganization_Declined: + if (!eventModel.OrganizationId.HasValue || !_currentContext.UserId.HasValue) + { + continue; + } + + var orgUser = await _organizationUserRepository.GetByOrganizationAsync(eventModel.OrganizationId.Value, _currentContext.UserId.Value); + + if (orgUser == null) + { + continue; + } + + await _eventService.LogOrganizationUserEventAsync(orgUser, eventModel.Type, eventModel.Date); + + continue; + // Cipher events case EventType.Cipher_ClientAutofilled: case EventType.Cipher_ClientCopiedHiddenField: @@ -71,7 +90,8 @@ public class CollectController : Controller { continue; } - Cipher cipher = null; + + Cipher cipher; if (ciphersCache.TryGetValue(eventModel.CipherId.Value, out var cachedCipher)) { cipher = cachedCipher; @@ -81,6 +101,7 @@ public class CollectController : Controller cipher = await _cipherRepository.GetByIdAsync(eventModel.CipherId.Value, _currentContext.UserId.Value); } + if (cipher == null) { // When the user cannot access the cipher directly, check if the organization allows for @@ -91,29 +112,44 @@ public class CollectController : Controller } cipher = await _cipherRepository.GetByIdAsync(eventModel.CipherId.Value); + if (cipher == null) + { + continue; + } + var cipherBelongsToOrg = cipher.OrganizationId == eventModel.OrganizationId; var org = _currentContext.GetOrganization(eventModel.OrganizationId.Value); - if (!cipherBelongsToOrg || org == null || cipher == null) + if (!cipherBelongsToOrg || org == null) { continue; } } + ciphersCache.TryAdd(eventModel.CipherId.Value, cipher); cipherEvents.Add(new Tuple(cipher, eventModel.Type, eventModel.Date)); break; + case EventType.Organization_ClientExportedVault: if (!eventModel.OrganizationId.HasValue) { continue; } + var organization = await _organizationRepository.GetByIdAsync(eventModel.OrganizationId.Value); + if (organization == null) + { + continue; + } + await _eventService.LogOrganizationEventAsync(organization, eventModel.Type, eventModel.Date); break; + default: continue; } } + if (cipherEvents.Any()) { foreach (var eventsBatch in cipherEvents.Chunk(50)) @@ -121,6 +157,7 @@ public class CollectController : Controller await _eventService.LogCipherEventsAsync(eventsBatch); } } + return new OkResult(); } } diff --git a/src/Events/Program.cs b/src/Events/Program.cs index 967e94ed83..1a00549005 100644 --- a/src/Events/Program.cs +++ b/src/Events/Program.cs @@ -12,26 +12,8 @@ public class Program .ConfigureWebHostDefaults(webBuilder => { webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, (e, globalSettings) => - { - var context = e.Properties["SourceContext"].ToString(); - if (context.Contains("Duende.IdentityServer.Validation.TokenValidator") || - context.Contains("Duende.IdentityServer.Validation.TokenRequestValidator")) - { - return e.Level >= globalSettings.MinLogLevel.EventsSettings.IdentityToken; - } - - if (e.Properties.TryGetValue("RequestPath", out var requestPath) && - !string.IsNullOrWhiteSpace(requestPath?.ToString()) && - (context.Contains(".Server.Kestrel") || context.Contains(".Core.IISHttpServer"))) - { - return false; - } - - return e.Level >= globalSettings.MinLogLevel.EventsSettings.Default; - })); }) + .AddSerilogFileLogging() .Build() .Run(); } diff --git a/src/Events/Startup.cs b/src/Events/Startup.cs index cfe177aa2c..75301cf08c 100644 --- a/src/Events/Startup.cs +++ b/src/Events/Startup.cs @@ -84,17 +84,16 @@ public class Startup services.AddHostedService(); } + // Add event integration services + services.AddDistributedCache(globalSettings); services.AddRabbitMqListeners(globalSettings); } public void Configure( IApplicationBuilder app, IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, GlobalSettings globalSettings) { - app.UseSerilog(env, appLifetime, globalSettings); - // Add general security headers app.UseMiddleware(); diff --git a/src/Events/appsettings.json b/src/Events/appsettings.json index e72b978f2f..41637c8549 100644 --- a/src/Events/appsettings.json +++ b/src/Events/appsettings.json @@ -14,9 +14,6 @@ "events": { "connectionString": "SECRET" }, - "sentry": { - "dsn": "SECRET" - }, "amazon": { "accessKeyId": "SECRET", "accessKeySecret": "SECRET", diff --git a/src/EventsProcessor/AzureQueueHostedService.cs b/src/EventsProcessor/AzureQueueHostedService.cs index c6f5afbfdd..c4c02e32d2 100644 --- a/src/EventsProcessor/AzureQueueHostedService.cs +++ b/src/EventsProcessor/AzureQueueHostedService.cs @@ -6,6 +6,7 @@ using Azure.Storage.Queues; using Bit.Core; using Bit.Core.Models.Data; using Bit.Core.Services; +using Bit.Core.Settings; using Bit.Core.Utilities; namespace Bit.EventsProcessor; @@ -13,7 +14,7 @@ namespace Bit.EventsProcessor; public class AzureQueueHostedService : IHostedService, IDisposable { private readonly ILogger _logger; - private readonly IConfiguration _configuration; + private readonly GlobalSettings _globalSettings; private Task _executingTask; private CancellationTokenSource _cts; @@ -22,10 +23,10 @@ public class AzureQueueHostedService : IHostedService, IDisposable public AzureQueueHostedService( ILogger logger, - IConfiguration configuration) + GlobalSettings globalSettings) { _logger = logger; - _configuration = configuration; + _globalSettings = globalSettings; } public Task StartAsync(CancellationToken cancellationToken) @@ -56,15 +57,18 @@ public class AzureQueueHostedService : IHostedService, IDisposable private async Task ExecuteAsync(CancellationToken cancellationToken) { - var storageConnectionString = _configuration["azureStorageConnectionString"]; - if (string.IsNullOrWhiteSpace(storageConnectionString)) + var storageConnectionString = _globalSettings.Events.ConnectionString; + var queueName = _globalSettings.Events.QueueName; + if (string.IsNullOrWhiteSpace(storageConnectionString) || + string.IsNullOrWhiteSpace(queueName)) { + _logger.LogInformation("Azure Queue Hosted Service is disabled. Missing connection string or queue name."); return; } var repo = new Core.Repositories.TableStorage.EventRepository(storageConnectionString); _eventWriteService = new RepositoryEventWriteService(repo); - _queueClient = new QueueClient(storageConnectionString, "event"); + _queueClient = new QueueClient(storageConnectionString, queueName); while (!cancellationToken.IsCancellationRequested) { diff --git a/src/EventsProcessor/Program.cs b/src/EventsProcessor/Program.cs index 9b7a31e6f4..e4f4ac90d1 100644 --- a/src/EventsProcessor/Program.cs +++ b/src/EventsProcessor/Program.cs @@ -11,9 +11,8 @@ public class Program .ConfigureWebHostDefaults(webBuilder => { webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, (e, globalSettings) => e.Level >= globalSettings.MinLogLevel.EventsProcessorSettings.Default)); }) + .AddSerilogFileLogging() .Build() .Run(); } diff --git a/src/EventsProcessor/Startup.cs b/src/EventsProcessor/Startup.cs index 67676a8afc..888dda43a1 100644 --- a/src/EventsProcessor/Startup.cs +++ b/src/EventsProcessor/Startup.cs @@ -1,5 +1,4 @@ using System.Globalization; -using Bit.Core.Settings; using Bit.Core.Utilities; using Bit.SharedWeb.Utilities; using Microsoft.IdentityModel.Logging; @@ -32,19 +31,15 @@ public class Startup // Repositories services.AddDatabaseRepositories(globalSettings); - // Hosted Services + // Add event integration services + services.AddDistributedCache(globalSettings); services.AddAzureServiceBusListeners(globalSettings); services.AddHostedService(); } - public void Configure( - IApplicationBuilder app, - IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, - GlobalSettings globalSettings) + public void Configure(IApplicationBuilder app) { IdentityModelEventSource.ShowPII = true; - app.UseSerilog(env, appLifetime, globalSettings); // Add general security headers app.UseMiddleware(); app.UseRouting(); diff --git a/src/Icons/Program.cs b/src/Icons/Program.cs index 237096b0b1..80c1b5728e 100644 --- a/src/Icons/Program.cs +++ b/src/Icons/Program.cs @@ -11,9 +11,8 @@ public class Program .ConfigureWebHostDefaults(webBuilder => { webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, (e, globalSettings) => e.Level >= globalSettings.MinLogLevel.IconsSettings.Default)); }) + .AddSerilogFileLogging() .Build() .Run(); } diff --git a/src/Icons/Startup.cs b/src/Icons/Startup.cs index 2602dd6264..5d9b5e5a30 100644 --- a/src/Icons/Startup.cs +++ b/src/Icons/Startup.cs @@ -60,11 +60,8 @@ public class Startup public void Configure( IApplicationBuilder app, IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, GlobalSettings globalSettings) { - app.UseSerilog(env, appLifetime, globalSettings); - // Add general security headers app.UseMiddleware(); diff --git a/src/Identity/Controllers/AccountsController.cs b/src/Identity/Controllers/AccountsController.cs index cc146800af..b7d4342c1b 100644 --- a/src/Identity/Controllers/AccountsController.cs +++ b/src/Identity/Controllers/AccountsController.cs @@ -109,8 +109,12 @@ public class AccountsController : Controller [HttpPost("register/send-verification-email")] public async Task PostRegisterSendVerificationEmail([FromBody] RegisterSendVerificationEmailRequestModel model) { + // Only pass fromMarketing if the feature flag is enabled + var isMarketingFeatureEnabled = _featureService.IsEnabled(FeatureFlagKeys.MarketingInitiatedPremiumFlow); + var fromMarketing = isMarketingFeatureEnabled ? model.FromMarketing : null; + var token = await _sendVerificationEmailForRegistrationCommand.Run(model.Email, model.Name, - model.ReceiveMarketingEmails); + model.ReceiveMarketingEmails, fromMarketing); if (token != null) { @@ -195,16 +199,35 @@ public class AccountsController : Controller throw new BadRequestException(ModelState); } - // Moved from API, If you modify this endpoint, please update API as well. Self hosted installs still use the API endpoints. [HttpPost("prelogin")] - public async Task PostPrelogin([FromBody] PreloginRequestModel model) + [Obsolete("Migrating to use a more descriptive endpoint that would support different types of prelogins. " + + "Use prelogin/password instead. This endpoint has no EOL at the time of writing.")] + public async Task PostPrelogin([FromBody] PasswordPreloginRequestModel model) + { + // Same as PostPasswordPrelogin to maintain compatibility. Do not make changes in this function body, + // only make changes in MakePasswordPreloginCall + return await MakePasswordPreloginCall(model); + } + + // There are two functions done this way because the open api docs that get generated in our build pipeline + // cannot handle two of the same post attributes on the same function call. That is why there is a + // PostPrelogin and the more appropriate PostPasswordPrelogin. + [HttpPost("prelogin/password")] + public async Task PostPasswordPrelogin([FromBody] PasswordPreloginRequestModel model) + { + // Same as PostPrelogin to maintain backwards compatibility. Do not make changes in this function body, + // only make changes in MakePasswordPreloginCall + return await MakePasswordPreloginCall(model); + } + + private async Task MakePasswordPreloginCall(PasswordPreloginRequestModel model) { var kdfInformation = await _userRepository.GetKdfInformationByEmailAsync(model.Email); if (kdfInformation == null) { kdfInformation = GetDefaultKdf(model.Email); } - return new PreloginResponseModel(kdfInformation); + return new PasswordPreloginResponseModel(kdfInformation, model.Email); } [HttpGet("webauthn/assertion-options")] @@ -228,19 +251,17 @@ public class AccountsController : Controller { return _defaultKdfResults[0]; } - else - { - // Compute the HMAC hash of the email - var hmacMessage = Encoding.UTF8.GetBytes(email.Trim().ToLowerInvariant()); - using var hmac = new System.Security.Cryptography.HMACSHA256(_defaultKdfHmacKey); - var hmacHash = hmac.ComputeHash(hmacMessage); - // Convert the hash to a number - var hashHex = BitConverter.ToString(hmacHash).Replace("-", string.Empty).ToLowerInvariant(); - var hashFirst8Bytes = hashHex.Substring(0, 16); - var hashNumber = long.Parse(hashFirst8Bytes, System.Globalization.NumberStyles.HexNumber); - // Find the default KDF value for this hash number - var hashIndex = (int)(Math.Abs(hashNumber) % _defaultKdfResults.Count); - return _defaultKdfResults[hashIndex]; - } + + // Compute the HMAC hash of the email + var hmacMessage = Encoding.UTF8.GetBytes(email.Trim().ToLowerInvariant()); + using var hmac = new System.Security.Cryptography.HMACSHA256(_defaultKdfHmacKey); + var hmacHash = hmac.ComputeHash(hmacMessage); + // Convert the hash to a number + var hashHex = BitConverter.ToString(hmacHash).Replace("-", string.Empty).ToLowerInvariant(); + var hashFirst8Bytes = hashHex.Substring(0, 16); + var hashNumber = long.Parse(hashFirst8Bytes, System.Globalization.NumberStyles.HexNumber); + // Find the default KDF value for this hash number + var hashIndex = (int)(Math.Abs(hashNumber) % _defaultKdfResults.Count); + return _defaultKdfResults[hashIndex]; } } diff --git a/src/Identity/IdentityServer/Constants/RequestValidationConstants.cs b/src/Identity/IdentityServer/Constants/RequestValidationConstants.cs new file mode 100644 index 0000000000..4787125045 --- /dev/null +++ b/src/Identity/IdentityServer/Constants/RequestValidationConstants.cs @@ -0,0 +1,30 @@ +namespace Bit.Identity.IdentityServer.RequestValidationConstants; + +public static class CustomResponseConstants +{ + public static class ResponseKeys + { + /// + /// Identifies the error model returned in the custom response when an error occurs. + /// + public static string ErrorModel => "ErrorModel"; + /// + /// This Key is used when a user is in a single organization that requires SSO authentication. The identifier + /// is used by the client to speed the redirection to the correct IdP for the user's organization. + /// + public static string SsoOrganizationIdentifier => "SsoOrganizationIdentifier"; + } +} + +public static class SsoConstants +{ + /// + /// These are messages and errors we return when SSO Validation is unsuccessful + /// + public static class RequestErrors + { + public static string SsoRequired => "sso_required"; + public static string SsoRequiredDescription => "Sso authentication is required."; + public static string SsoTwoFactorRecoveryDescription => "Two-factor recovery has been performed. SSO authentication is required."; + } +} diff --git a/src/Identity/IdentityServer/CustomValidatorRequestContext.cs b/src/Identity/IdentityServer/CustomValidatorRequestContext.cs index a709a47cb2..e16c8ad695 100644 --- a/src/Identity/IdentityServer/CustomValidatorRequestContext.cs +++ b/src/Identity/IdentityServer/CustomValidatorRequestContext.cs @@ -27,6 +27,12 @@ public class CustomValidatorRequestContext ///
    public bool TwoFactorRequired { get; set; } = false; /// + /// Whether the user has requested recovery of their 2FA methods using their one-time + /// recovery code. + /// + /// + public bool TwoFactorRecoveryRequested { get; set; } = false; + /// /// This communicates whether or not SSO is required for the user to authenticate. /// public bool SsoRequired { get; set; } = false; @@ -42,10 +48,13 @@ public class CustomValidatorRequestContext /// This will be null if the authentication request is successful. ///
    public Dictionary CustomResponse { get; set; } - /// /// A validated auth request /// /// public AuthRequest ValidatedAuthRequest { get; set; } + /// + /// Whether the user has requested a Remember Me token for their current device. + /// + public bool RememberMeRequested { get; set; } = false; } diff --git a/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs index e57ed1c85f..e07446d49f 100644 --- a/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs @@ -1,4 +1,5 @@ // FIXME: Update this file to be null safe and then delete the line below + #nullable disable using System.Security.Claims; @@ -14,6 +15,8 @@ using Bit.Core.Auth.Repositories; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; +using Bit.Core.KeyManagement.Models.Api.Response; +using Bit.Core.KeyManagement.Queries.Interfaces; using Bit.Core.Models.Api; using Bit.Core.Models.Api.Response; using Bit.Core.Repositories; @@ -31,6 +34,7 @@ public abstract class BaseRequestValidator where T : class private readonly IEventService _eventService; private readonly IDeviceValidator _deviceValidator; private readonly ITwoFactorAuthenticationValidator _twoFactorAuthenticationValidator; + private readonly ISsoRequestValidator _ssoRequestValidator; private readonly IOrganizationUserRepository _organizationUserRepository; private readonly ILogger _logger; private readonly GlobalSettings _globalSettings; @@ -40,11 +44,12 @@ public abstract class BaseRequestValidator where T : class protected ICurrentContext CurrentContext { get; } protected IPolicyService PolicyService { get; } - protected IFeatureService FeatureService { get; } + protected IFeatureService _featureService { get; } protected ISsoConfigRepository SsoConfigRepository { get; } protected IUserService _userService { get; } protected IUserDecryptionOptionsBuilder UserDecryptionOptionsBuilder { get; } protected IPolicyRequirementQuery PolicyRequirementQuery { get; } + protected IUserAccountKeysQuery _accountKeysQuery { get; } public BaseRequestValidator( UserManager userManager, @@ -52,6 +57,7 @@ public abstract class BaseRequestValidator where T : class IEventService eventService, IDeviceValidator deviceValidator, ITwoFactorAuthenticationValidator twoFactorAuthenticationValidator, + ISsoRequestValidator ssoRequestValidator, IOrganizationUserRepository organizationUserRepository, ILogger logger, ICurrentContext currentContext, @@ -63,150 +69,44 @@ public abstract class BaseRequestValidator where T : class IUserDecryptionOptionsBuilder userDecryptionOptionsBuilder, IPolicyRequirementQuery policyRequirementQuery, IAuthRequestRepository authRequestRepository, - IMailService mailService - ) + IMailService mailService, + IUserAccountKeysQuery userAccountKeysQuery + ) { _userManager = userManager; _userService = userService; _eventService = eventService; _deviceValidator = deviceValidator; _twoFactorAuthenticationValidator = twoFactorAuthenticationValidator; + _ssoRequestValidator = ssoRequestValidator; _organizationUserRepository = organizationUserRepository; _logger = logger; CurrentContext = currentContext; _globalSettings = globalSettings; PolicyService = policyService; _userRepository = userRepository; - FeatureService = featureService; + _featureService = featureService; SsoConfigRepository = ssoConfigRepository; UserDecryptionOptionsBuilder = userDecryptionOptionsBuilder; PolicyRequirementQuery = policyRequirementQuery; _authRequestRepository = authRequestRepository; _mailService = mailService; + _accountKeysQuery = userAccountKeysQuery; } protected async Task ValidateAsync(T context, ValidatedTokenRequest request, CustomValidatorRequestContext validatorContext) { - // 1. We need to check if the user's master password hash is correct. - var valid = await ValidateContextAsync(context, validatorContext); - var user = validatorContext.User; - if (!valid) + var validators = DetermineValidationOrder(context, request, validatorContext); + var allValidationSchemesSuccessful = await ProcessValidatorsAsync(validators); + if (!allValidationSchemesSuccessful) { - await UpdateFailedAuthDetailsAsync(user); - - await BuildErrorResultAsync("Username or password is incorrect. Try again.", false, context, user); + // Each validation task is responsible for setting its own non-success status, if applicable. return; } - // 2. Decide if this user belongs to an organization that requires SSO. - validatorContext.SsoRequired = await RequireSsoLoginAsync(user, request.GrantType); - if (validatorContext.SsoRequired) - { - SetSsoResult(context, - new Dictionary - { - { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") } - }); - return; - } - - // 3. Check if 2FA is required. - (validatorContext.TwoFactorRequired, var twoFactorOrganization) = - await _twoFactorAuthenticationValidator.RequiresTwoFactorAsync(user, request); - - // This flag is used to determine if the user wants a rememberMe token sent when - // authentication is successful. - var returnRememberMeToken = false; - - if (validatorContext.TwoFactorRequired) - { - var twoFactorToken = request.Raw["TwoFactorToken"]; - var twoFactorProvider = request.Raw["TwoFactorProvider"]; - var validTwoFactorRequest = !string.IsNullOrWhiteSpace(twoFactorToken) && - !string.IsNullOrWhiteSpace(twoFactorProvider); - - // 3a. Response for 2FA required and not provided state. - if (!validTwoFactorRequest || - !Enum.TryParse(twoFactorProvider, out TwoFactorProviderType twoFactorProviderType)) - { - var resultDict = await _twoFactorAuthenticationValidator - .BuildTwoFactorResultAsync(user, twoFactorOrganization); - if (resultDict == null) - { - await BuildErrorResultAsync("No two-step providers enabled.", false, context, user); - return; - } - - // Include Master Password Policy in 2FA response. - resultDict.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(user)); - SetTwoFactorResult(context, resultDict); - return; - } - - var twoFactorTokenValid = - await _twoFactorAuthenticationValidator - .VerifyTwoFactorAsync(user, twoFactorOrganization, twoFactorProviderType, twoFactorToken); - - // 3b. Response for 2FA required but request is not valid or remember token expired state. - if (!twoFactorTokenValid) - { - // The remember me token has expired. - if (twoFactorProviderType == TwoFactorProviderType.Remember) - { - var resultDict = await _twoFactorAuthenticationValidator - .BuildTwoFactorResultAsync(user, twoFactorOrganization); - - // Include Master Password Policy in 2FA response - resultDict.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(user)); - SetTwoFactorResult(context, resultDict); - } - else - { - await SendFailedTwoFactorEmail(user, twoFactorProviderType); - await UpdateFailedAuthDetailsAsync(user); - await BuildErrorResultAsync("Two-step token is invalid. Try again.", true, context, user); - } - return; - } - - // 3c. When the 2FA authentication is successful, we can check if the user wants a - // rememberMe token. - var twoFactorRemember = request.Raw["TwoFactorRemember"] == "1"; - // Check if the user wants a rememberMe token. - if (twoFactorRemember - // if the 2FA auth was rememberMe do not send another token. - && twoFactorProviderType != TwoFactorProviderType.Remember) - { - returnRememberMeToken = true; - } - } - - // 4. Check if the user is logging in from a new device. - var deviceValid = await _deviceValidator.ValidateRequestDeviceAsync(request, validatorContext); - if (!deviceValid) - { - SetValidationErrorResult(context, validatorContext); - await LogFailedLoginEvent(validatorContext.User, EventType.User_FailedLogIn); - return; - } - - // 5. Force legacy users to the web for migration. - if (UserService.IsLegacyUser(user) && request.ClientId != "web") - { - await FailAuthForLegacyUserAsync(user, context); - return; - } - - // TODO: PM-24324 - This should be its own validator at some point. - // 6. Auth request handling - if (validatorContext.ValidatedAuthRequest != null) - { - validatorContext.ValidatedAuthRequest.AuthenticationDate = DateTime.UtcNow; - await _authRequestRepository.ReplaceAsync(validatorContext.ValidatedAuthRequest); - } - - await BuildSuccessResultAsync(user, context, validatorContext.Device, returnRememberMeToken); + await BuildSuccessResultAsync(validatorContext.User, context, validatorContext.Device, + validatorContext.RememberMeRequested); } protected async Task FailAuthForLegacyUserAsync(User user, T context) @@ -218,6 +118,322 @@ public abstract class BaseRequestValidator where T : class protected abstract Task ValidateContextAsync(T context, CustomValidatorRequestContext validatorContext); + /// + /// Composer for validation schemes. + /// + /// The current request context. + /// + /// + /// A composed array of validation scheme delegates to evaluate in order. + private Func>[] DetermineValidationOrder(T context, ValidatedTokenRequest request, + CustomValidatorRequestContext validatorContext) + { + if (RecoveryCodeRequestForSsoRequiredUserScenario()) + { + // Support valid requests to recover 2FA (with account code) for users who require SSO + // by organization membership. + // This requires an evaluation of 2FA validity in front of SSO, and an opportunity for the 2FA + // validation to perform the recovery as part of scheme validation based on the request. + return + [ + () => ValidateMasterPasswordAsync(context, validatorContext), + () => ValidateTwoFactorAsync(context, request, validatorContext), + () => ValidateSsoAsync(context, request, validatorContext), + () => ValidateNewDeviceAsync(context, request, validatorContext), + () => ValidateLegacyMigrationAsync(context, request, validatorContext), + () => ValidateAuthRequestAsync(validatorContext) + ]; + } + else + { + // The typical validation scenario. + return + [ + () => ValidateMasterPasswordAsync(context, validatorContext), + () => ValidateSsoAsync(context, request, validatorContext), + () => ValidateTwoFactorAsync(context, request, validatorContext), + () => ValidateNewDeviceAsync(context, request, validatorContext), + () => ValidateLegacyMigrationAsync(context, request, validatorContext), + () => ValidateAuthRequestAsync(validatorContext) + ]; + } + + bool RecoveryCodeRequestForSsoRequiredUserScenario() + { + var twoFactorProvider = request.Raw["TwoFactorProvider"]; + var twoFactorToken = request.Raw["TwoFactorToken"]; + + // Both provider and token must be present; + // Validity of the token for a given provider will be evaluated by the TwoFactorAuthenticationValidator. + if (string.IsNullOrWhiteSpace(twoFactorProvider) || string.IsNullOrWhiteSpace(twoFactorToken)) + { + return false; + } + + if (!int.TryParse(twoFactorProvider, out var providerValue)) + { + return false; + } + + return providerValue == (int)TwoFactorProviderType.RecoveryCode; + } + } + + /// + /// Processes the validation schemes sequentially. + /// Each validator is responsible for setting error context responses on failure and adding itself to the + /// validatorContext's CompletedValidationSchemes (only) on success. + /// Failure of any scheme to validate will short-circuit the collection, causing the validation error to be + /// returned and further schemes to not be evaluated. + /// + /// The collection of validation schemes as composed in + /// true if all schemes validated successfully, false if any failed. + private static async Task ProcessValidatorsAsync(params Func>[] validators) + { + foreach (var validator in validators) + { + if (!await validator()) + { + return false; + } + } + + return true; + } + + /// + /// Validates the user's Master Password hash. + /// + /// The current request context. + /// + /// true if the scheme successfully passed validation, otherwise false. + private async Task ValidateMasterPasswordAsync(T context, CustomValidatorRequestContext validatorContext) + { + var valid = await ValidateContextAsync(context, validatorContext); + var user = validatorContext.User; + if (valid) + { + return true; + } + + await UpdateFailedAuthDetailsAsync(user); + + await BuildErrorResultAsync("Username or password is incorrect. Try again.", false, context, user); + return false; + } + + /// + /// Validates the user's organization-enforced Single Sign-on (SSO) requirement. + /// + /// The current request context. + /// + /// + /// true if the scheme successfully passed validation, otherwise false. + /// + private async Task ValidateSsoAsync(T context, ValidatedTokenRequest request, + CustomValidatorRequestContext validatorContext) + { + // TODO: Clean up Feature Flag: Remove this if block: PM-28281 + if (!_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired)) + { + validatorContext.SsoRequired = await RequireSsoLoginAsync(validatorContext.User, request.GrantType); + if (!validatorContext.SsoRequired) + { + return true; + } + + // Users without SSO requirement requesting 2FA recovery will be fast-forwarded through login and are + // presented with their 2FA management area as a reminder to re-evaluate their 2FA posture after recovery and + // review their new recovery token if desired. + // SSO users cannot be assumed to be authenticated, and must prove authentication with their IdP after recovery. + // As described in validation order determination, if TwoFactorRequired, the 2FA validation scheme will have been + // evaluated, and recovery will have been performed if requested. + // We will send a descriptive message in these cases so clients can give the appropriate feedback and redirect + // to /login. + if (validatorContext.TwoFactorRequired && + validatorContext.TwoFactorRecoveryRequested) + { + SetSsoResult(context, + new Dictionary + { + { + "ErrorModel", + new ErrorResponseModel( + "Two-factor recovery has been performed. SSO authentication is required.") + } + }); + return false; + } + + SetSsoResult(context, + new Dictionary + { + { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") } + }); + return false; + } + else + { + var ssoValid = await _ssoRequestValidator.ValidateAsync(validatorContext.User, request, validatorContext); + if (ssoValid) + { + return true; + } + + SetValidationErrorResult(context, validatorContext); + return ssoValid; + } + } + + /// + /// Validates the user's Multi-Factor Authentication (2FA) scheme. + /// + /// The current request context. + /// + /// + /// true if the scheme successfully passed validation, otherwise false. + private async Task ValidateTwoFactorAsync(T context, ValidatedTokenRequest request, + CustomValidatorRequestContext validatorContext) + { + (validatorContext.TwoFactorRequired, var twoFactorOrganization) = + await _twoFactorAuthenticationValidator.RequiresTwoFactorAsync(validatorContext.User, request); + + if (!validatorContext.TwoFactorRequired) + { + return true; + } + + var twoFactorToken = request.Raw["TwoFactorToken"]; + var twoFactorProvider = request.Raw["TwoFactorProvider"]; + var validTwoFactorRequest = !string.IsNullOrWhiteSpace(twoFactorToken) && + !string.IsNullOrWhiteSpace(twoFactorProvider); + + // 3a. Response for 2FA required and not provided state. + if (!validTwoFactorRequest || + !Enum.TryParse(twoFactorProvider, out TwoFactorProviderType twoFactorProviderType)) + { + var resultDict = await _twoFactorAuthenticationValidator + .BuildTwoFactorResultAsync(validatorContext.User, twoFactorOrganization); + if (resultDict == null) + { + await BuildErrorResultAsync("No two-step providers enabled.", false, context, validatorContext.User); + return false; + } + + // Include Master Password Policy in 2FA response. + resultDict.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(validatorContext.User)); + SetTwoFactorResult(context, resultDict); + return false; + } + + var twoFactorTokenValid = + await _twoFactorAuthenticationValidator + .VerifyTwoFactorAsync(validatorContext.User, twoFactorOrganization, twoFactorProviderType, + twoFactorToken); + + // 3b. Response for 2FA required but request is not valid or remember token expired state. + if (!twoFactorTokenValid) + { + // The remember me token has expired. + if (twoFactorProviderType == TwoFactorProviderType.Remember) + { + var resultDict = await _twoFactorAuthenticationValidator + .BuildTwoFactorResultAsync(validatorContext.User, twoFactorOrganization); + + // Include Master Password Policy in 2FA response + resultDict.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(validatorContext.User)); + SetTwoFactorResult(context, resultDict); + } + else + { + await SendFailedTwoFactorEmail(validatorContext.User, twoFactorProviderType); + await UpdateFailedAuthDetailsAsync(validatorContext.User); + await BuildErrorResultAsync("Two-step token is invalid. Try again.", true, context, + validatorContext.User); + } + + return false; + } + + // 3c. Given a valid token and a successful two-factor verification, if the provider type is Recovery Code, + // recovery will have been performed as part of 2FA validation. This will be relevant for, e.g., SSO users + // who are requesting recovery, but who will still need to log in after 2FA recovery. + if (twoFactorProviderType == TwoFactorProviderType.RecoveryCode) + { + validatorContext.TwoFactorRecoveryRequested = true; + } + + // 3d. When the 2FA authentication is successful, we can check if the user wants a + // rememberMe token. + var twoFactorRemember = request.Raw["TwoFactorRemember"] == "1"; + // Check if the user wants a rememberMe token. + if (twoFactorRemember + // if the 2FA auth was rememberMe do not send another token. + && twoFactorProviderType != TwoFactorProviderType.Remember) + { + validatorContext.RememberMeRequested = true; + } + + return true; + } + + /// + /// Validates whether the user is logging in from a known device. + /// + /// The current request context. + /// + /// + /// true if the scheme successfully passed validation, otherwise false. + private async Task ValidateNewDeviceAsync(T context, ValidatedTokenRequest request, + CustomValidatorRequestContext validatorContext) + { + var deviceValid = await _deviceValidator.ValidateRequestDeviceAsync(request, validatorContext); + if (deviceValid) + { + return true; + } + + SetValidationErrorResult(context, validatorContext); + await LogFailedLoginEvent(validatorContext.User, EventType.User_FailedLogIn); + return false; + } + + /// + /// Validates whether the user should be denied access on a given non-Web client and sent to the Web client + /// for Legacy migration. + /// + /// The current request context. + /// + /// + /// true if the scheme successfully passed validation, otherwise false. + private async Task ValidateLegacyMigrationAsync(T context, ValidatedTokenRequest request, + CustomValidatorRequestContext validatorContext) + { + if (!UserService.IsLegacyUser(validatorContext.User) || request.ClientId == "web") + { + return true; + } + + await FailAuthForLegacyUserAsync(validatorContext.User, context); + return false; + } + + /// + /// Validates and updates the auth request's timestamp. + /// + /// + /// true on evaluation and/or completed update of the AuthRequest. + private async Task ValidateAuthRequestAsync(CustomValidatorRequestContext validatorContext) + { + // TODO: PM-24324 - This should be its own validator at some point. + if (validatorContext.ValidatedAuthRequest != null) + { + validatorContext.ValidatedAuthRequest.AuthenticationDate = DateTime.UtcNow; + await _authRequestRepository.ReplaceAsync(validatorContext.ValidatedAuthRequest); + } + + return true; + } /// /// Responsible for building the response to the client when the user has successfully authenticated. @@ -251,7 +467,7 @@ public abstract class BaseRequestValidator where T : class /// used to associate the failed login with a user /// void [Obsolete("Consider using SetValidationErrorResult to set the validation result, and LogFailedLoginEvent " + - "to log the failure.")] + "to log the failure.")] protected async Task BuildErrorResultAsync(string message, bool twoFactorRequest, T context, User user) { if (user != null) @@ -263,8 +479,8 @@ public abstract class BaseRequestValidator where T : class if (_globalSettings.SelfHosted) { _logger.LogWarning(Constants.BypassFiltersEventId, - string.Format("Failed login attempt{0}{1}", twoFactorRequest ? ", 2FA invalid." : ".", - $" {CurrentContext.IpAddress}")); + "Failed login attempt. Is2FARequest: {Is2FARequest} IpAddress: {IpAddress}", twoFactorRequest, + CurrentContext.IpAddress); } await Task.Delay(2000); // Delay for brute force. @@ -288,21 +504,26 @@ public abstract class BaseRequestValidator where T : class formattedMessage = string.Format("Failed login attempt. {0}", $" {CurrentContext.IpAddress}"); break; case EventType.User_FailedLogIn2fa: - formattedMessage = string.Format("Failed login attempt, 2FA invalid.{0}", $" {CurrentContext.IpAddress}"); + formattedMessage = string.Format("Failed login attempt, 2FA invalid.{0}", + $" {CurrentContext.IpAddress}"); break; default: formattedMessage = "Failed login attempt."; break; } - _logger.LogWarning(Constants.BypassFiltersEventId, formattedMessage); + + _logger.LogWarning(Constants.BypassFiltersEventId, "{FailedLoginMessage}", formattedMessage); } + await Task.Delay(2000); // Delay for brute force. } [Obsolete("Consider using SetValidationErrorResult instead.")] protected abstract void SetTwoFactorResult(T context, Dictionary customResponse); + [Obsolete("Consider using SetValidationErrorResult instead.")] protected abstract void SetSsoResult(T context, Dictionary customResponse); + [Obsolete("Consider using SetValidationErrorResult instead.")] protected abstract void SetErrorResult(T context, Dictionary customResponse); @@ -313,6 +534,7 @@ public abstract class BaseRequestValidator where T : class /// The current grant or token context /// The modified request context containing material used to build the response object protected abstract void SetValidationErrorResult(T context, CustomValidatorRequestContext requestContext); + protected abstract Task SetSuccessResult(T context, User user, List claims, Dictionary customResponse); @@ -327,6 +549,8 @@ public abstract class BaseRequestValidator where T : class /// user trying to login /// magic string identifying the grant type requested /// true if sso required; false if not required or already in process + [Obsolete( + "This method is deprecated and will be removed in future versions, PM-28281. Please use the SsoRequestValidator scheme instead.")] private async Task RequireSsoLoginAsync(User user, string grantType) { if (grantType == "authorization_code" || grantType == "client_credentials") @@ -337,9 +561,9 @@ public abstract class BaseRequestValidator where T : class } // Check if user belongs to any organization with an active SSO policy - var ssoRequired = FeatureService.IsEnabled(FeatureFlagKeys.PolicyRequirements) + var ssoRequired = _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements) ? (await PolicyRequirementQuery.GetAsync(user.Id)) - .SsoRequired + .SsoRequired : await PolicyService.AnyPoliciesApplicableToUserAsync( user.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed); if (ssoRequired) @@ -379,10 +603,8 @@ public abstract class BaseRequestValidator where T : class private async Task SendFailedTwoFactorEmail(User user, TwoFactorProviderType failedAttemptType) { - if (FeatureService.IsEnabled(FeatureFlagKeys.FailedTwoFactorEmail)) - { - await _mailService.SendFailedTwoFactorAttemptEmailAsync(user.Email, failedAttemptType, DateTime.UtcNow, CurrentContext.IpAddress); - } + await _mailService.SendFailedTwoFactorAttemptEmailAsync(user.Email, failedAttemptType, DateTime.UtcNow, + CurrentContext.IpAddress); } private async Task GetMasterPasswordPolicyAsync(User user) @@ -412,16 +634,14 @@ public abstract class BaseRequestValidator where T : class // We need this because we check for changes in the stamp to determine if we need to invalidate token refresh requests, // in the `ProfileService.IsActiveAsync` method. // If we don't store the security stamp in the persisted grant, we won't have the previous value to compare against. - var claims = new List - { - new Claim(Claims.SecurityStamp, user.SecurityStamp) - }; + var claims = new List { new Claim(Claims.SecurityStamp, user.SecurityStamp) }; if (device != null) { claims.Add(new Claim(Claims.Device, device.Identifier)); claims.Add(new Claim(Claims.DeviceType, device.Type.ToString())); } + return claims; } @@ -433,16 +653,21 @@ public abstract class BaseRequestValidator where T : class /// The current request context. /// The device used for authentication. /// Whether to send a 2FA remember token. - private async Task> BuildCustomResponse(User user, T context, Device device, bool sendRememberToken) + private async Task> BuildCustomResponse(User user, T context, Device device, + bool sendRememberToken) { var customResponse = new Dictionary(); if (!string.IsNullOrWhiteSpace(user.PrivateKey)) { + // PrivateKey usage is now deprecated in favor of AccountKeys customResponse.Add("PrivateKey", user.PrivateKey); + var accountKeys = await _accountKeysQuery.Run(user); + customResponse.Add("AccountKeys", new PrivateKeysResponseModel(accountKeys)); } if (!string.IsNullOrWhiteSpace(user.Key)) { + // Key is deprecated in favor of UserDecryptionOptions.MasterPasswordUnlock.MasterKeyEncryptedUserKey customResponse.Add("Key", user.Key); } @@ -453,7 +678,8 @@ public abstract class BaseRequestValidator where T : class customResponse.Add("KdfIterations", user.KdfIterations); customResponse.Add("KdfMemory", user.KdfMemory); customResponse.Add("KdfParallelism", user.KdfParallelism); - customResponse.Add("UserDecryptionOptions", await CreateUserDecryptionOptionsAsync(user, device, GetSubject(context))); + customResponse.Add("UserDecryptionOptions", + await CreateUserDecryptionOptionsAsync(user, device, GetSubject(context))); if (sendRememberToken) { @@ -461,6 +687,7 @@ public abstract class BaseRequestValidator where T : class CoreHelpers.CustomProviderName(TwoFactorProviderType.Remember)); customResponse.Add("TwoFactorToken", token); } + return customResponse; } @@ -468,7 +695,8 @@ public abstract class BaseRequestValidator where T : class /// /// Used to create a list of all possible ways the newly authenticated user can decrypt their vault contents /// - private async Task CreateUserDecryptionOptionsAsync(User user, Device device, ClaimsPrincipal subject) + private async Task CreateUserDecryptionOptionsAsync(User user, Device device, + ClaimsPrincipal subject) { var ssoConfig = await GetSsoConfigurationDataAsync(subject); return await UserDecryptionOptionsBuilder diff --git a/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs index 1495973b80..38a4813ecd 100644 --- a/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs @@ -8,6 +8,7 @@ using Bit.Core.Auth.Models.Api.Response; using Bit.Core.Auth.Repositories; using Bit.Core.Context; using Bit.Core.Entities; +using Bit.Core.KeyManagement.Queries.Interfaces; using Bit.Core.Platform.Installations; using Bit.Core.Repositories; using Bit.Core.Services; @@ -35,6 +36,7 @@ public class CustomTokenRequestValidator : BaseRequestValidator logger, ICurrentContext currentContext, @@ -47,13 +49,15 @@ public class CustomTokenRequestValidator : BaseRequestValidator +/// Validates whether a user is required to authenticate via SSO based on organization policies. +/// +public interface ISsoRequestValidator +{ + /// + /// Validates the SSO requirement for a user attempting to authenticate. Sets the error state in the if SSO is required. + /// + /// The user attempting to authenticate. + /// The token request containing grant type and other authentication details. + /// The validator context to be updated with SSO requirement status and error results if applicable. + /// true if the user can proceed with authentication; false if SSO is required and the user must be redirected to SSO flow. + Task ValidateAsync(User user, ValidatedTokenRequest request, CustomValidatorRequestContext context); +} diff --git a/src/Identity/IdentityServer/RequestValidators/ResourceOwnerPasswordValidator.cs b/src/Identity/IdentityServer/RequestValidators/ResourceOwnerPasswordValidator.cs index 17592cc0c1..ea2c021f63 100644 --- a/src/Identity/IdentityServer/RequestValidators/ResourceOwnerPasswordValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/ResourceOwnerPasswordValidator.cs @@ -8,6 +8,7 @@ using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.Repositories; using Bit.Core.Context; using Bit.Core.Entities; +using Bit.Core.KeyManagement.Queries.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; @@ -30,6 +31,7 @@ public class ResourceOwnerPasswordValidator : BaseRequestValidator logger, ICurrentContext currentContext, @@ -41,13 +43,15 @@ public class ResourceOwnerPasswordValidator : BaseRequestValidator otpTokenProvider, IMailService mailService) : ISendAuthenticationMethodValidator { @@ -60,11 +62,20 @@ public class SendEmailOtpRequestValidator( { return BuildErrorResult(SendAccessConstants.EmailOtpValidatorResults.OtpGenerationFailed); } - - await mailService.SendSendEmailOtpEmailAsync( - email, - token, - string.Format(SendAccessConstants.OtpEmail.Subject, token)); + if (featureService.IsEnabled(FeatureFlagKeys.MJMLBasedEmailTemplates)) + { + await mailService.SendSendEmailOtpEmailv2Async( + email, + token, + string.Format(SendAccessConstants.OtpEmail.Subject, token)); + } + else + { + await mailService.SendSendEmailOtpEmailAsync( + email, + token, + string.Format(SendAccessConstants.OtpEmail.Subject, token)); + } return BuildErrorResult(SendAccessConstants.EmailOtpValidatorResults.EmailOtpSent); } diff --git a/src/Identity/IdentityServer/RequestValidators/SendAccess/readme.md b/src/Identity/IdentityServer/RequestValidators/SendAccess/readme.md index afab13a156..2a6ea66857 100644 --- a/src/Identity/IdentityServer/RequestValidators/SendAccess/readme.md +++ b/src/Identity/IdentityServer/RequestValidators/SendAccess/readme.md @@ -1,17 +1,15 @@ -Send Access Request Validation -=== +# Send Access Request Validation This feature supports the ability of Tools to require specific claims for access to sends. In order to access Send data a user must meet the requirements laid out in these request validators. -# ***Important: String Constants*** - -The string constants contained herein are used in conjunction with the Auth module in the SDK. Any change to these string values _must_ be intentional and _must_ have a corresponding change in the SDK. +> [!IMPORTANT] +> The string constants contained herein are used in conjunction with the Auth module in the SDK. Any change to these string values _must_ be intentional and _must_ have a corresponding change in the SDK. There is snapshot testing that will fail if the strings change to help detect unintended changes to the string constants. -# Custom Claims +## Custom Claims Send access tokens contain custom claims specific to the Send the Send grant type. @@ -19,41 +17,41 @@ Send access tokens contain custom claims specific to the Send the Send grant typ 1. `send_email` - only set when the Send requires `EmailOtp` authentication type. 1. `type` - this will always be `Send` -# Authentication methods +## Authentication methods -## `NeverAuthenticate` +### `NeverAuthenticate` For a Send to be in this state two things can be true: 1. The Send has been modified and no longer allows access. 2. The Send does not exist. -## `NotAuthenticated` +### `NotAuthenticated` In this scenario the Send is not protected by any added authentication or authorization and the access token is issued to the requesting user. -## `ResourcePassword` +### `ResourcePassword` In this scenario the Send is password protected and a user must supply the correct password hash to be issued an access token. -## `EmailOtp` +### `EmailOtp` In this scenario the Send is only accessible to owners of specific email addresses. The user must submit a correct email. Once the email has been entered then ownership of the email must be established via OTP. The Otp is sent to the aforementioned email and must be supplied, along with the email, to be issued an access token. -# Send Access Request Validation +## Send Access Request Validation -## Required Parameters +### Required Parameters -### All Requests +#### All Requests - `send_id` - Base64 URL-encoded GUID of the send being accessed -### Password Protected Sends +#### Password Protected Sends - `password_hash_b64` - client hashed Base64-encoded password. -### Email OTP Protected Sends +#### Email OTP Protected Sends - `email` - Email address associated with the send - `otp` - One-time password (optional - if missing, OTP is generated and sent) -## Error Responses +### Error Responses All errors include a custom response field: ```json @@ -62,5 +60,4 @@ All errors include a custom response field: "error_description": "Human readable description", "send_access_error_type": "specific_error_code" } -``` - +``` \ No newline at end of file diff --git a/src/Identity/IdentityServer/RequestValidators/SsoRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/SsoRequestValidator.cs new file mode 100644 index 0000000000..145ecc8737 --- /dev/null +++ b/src/Identity/IdentityServer/RequestValidators/SsoRequestValidator.cs @@ -0,0 +1,124 @@ +using Bit.Core; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.Services; +using Bit.Core.Auth.Sso; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Models.Api; +using Bit.Core.Services; +using Bit.Identity.IdentityServer.RequestValidationConstants; +using Duende.IdentityModel; +using Duende.IdentityServer.Validation; + +namespace Bit.Identity.IdentityServer.RequestValidators; + +/// +/// Validates whether a user is required to authenticate via SSO based on organization policies. +/// +public class SsoRequestValidator( + IPolicyService _policyService, + IFeatureService _featureService, + IUserSsoOrganizationIdentifierQuery _userSsoOrganizationIdentifierQuery, + IPolicyRequirementQuery _policyRequirementQuery) : ISsoRequestValidator +{ + /// + /// Validates the SSO requirement for a user attempting to authenticate. + /// Sets context.SsoRequired to indicate whether SSO is required. + /// If SSO is required, sets the validation error result and custom response in the context. + /// + /// The user attempting to authenticate. + /// The token request containing grant type and other authentication details. + /// The validator context to be updated with SSO requirement status and error results if applicable. + /// true if the user can proceed with authentication; false if SSO is required and the user must be redirected to SSO flow. + public async Task ValidateAsync(User user, ValidatedTokenRequest request, CustomValidatorRequestContext context) + { + context.SsoRequired = await RequireSsoAuthenticationAsync(user, request.GrantType); + + if (!context.SsoRequired) + { + return true; + } + + // Users without SSO requirement requesting 2FA recovery will be fast-forwarded through login and are + // presented with their 2FA management area as a reminder to re-evaluate their 2FA posture after recovery and + // review their new recovery token if desired. + // SSO users cannot be assumed to be authenticated, and must prove authentication with their IdP after recovery. + // As described in validation order determination, if TwoFactorRequired, the 2FA validation scheme will have been + // evaluated, and recovery will have been performed if requested. + // We will send a descriptive message in these cases so clients can give the appropriate feedback and redirect + // to /login. + if (context.TwoFactorRequired && context.TwoFactorRecoveryRequested) + { + await SetContextCustomResponseSsoErrorAsync(context, SsoConstants.RequestErrors.SsoTwoFactorRecoveryDescription); + return false; + } + + await SetContextCustomResponseSsoErrorAsync(context, SsoConstants.RequestErrors.SsoRequiredDescription); + return false; + } + + /// + /// Check if the user is required to authenticate via SSO. If the user requires SSO, but they are + /// logging in using an API Key (client_credentials) then they are allowed to bypass the SSO requirement. + /// If the GrantType is authorization_code or client_credentials we know the user is trying to log in + /// using the SSO flow so they are allowed to continue. + /// + /// user trying to log in + /// magic string identifying the grant type requested + /// true if sso required; false if not required or already in process + private async Task RequireSsoAuthenticationAsync(User user, string grantType) + { + if (grantType == OidcConstants.GrantTypes.AuthorizationCode || + grantType == OidcConstants.GrantTypes.ClientCredentials) + { + // SSO is not required for users already using SSO to authenticate which uses the authorization_code grant type, + // or logging-in via API key which is the client_credentials grant type. + // Allow user to continue request validation + return false; + } + + // Check if user belongs to any organization with an active SSO policy + var ssoRequired = _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements) + ? (await _policyRequirementQuery.GetAsync(user.Id)) + .SsoRequired + : await _policyService.AnyPoliciesApplicableToUserAsync( + user.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed); + + if (ssoRequired) + { + return true; + } + + // Default - SSO is not required + return false; + } + + /// + /// Sets the customResponse in the context with the error result for the SSO validation failure. + /// + /// The validator context to update with error details. + /// The error message to return to the client. + private async Task SetContextCustomResponseSsoErrorAsync(CustomValidatorRequestContext context, string errorMessage) + { + var ssoOrganizationIdentifier = await _userSsoOrganizationIdentifierQuery.GetSsoOrganizationIdentifierAsync(context.User.Id); + + context.ValidationErrorResult = new ValidationResult + { + IsError = true, + Error = OidcConstants.TokenErrors.InvalidGrant, + ErrorDescription = errorMessage + }; + + context.CustomResponse = new Dictionary + { + { CustomResponseConstants.ResponseKeys.ErrorModel, new ErrorResponseModel(errorMessage) } + }; + + // Include organization identifier in the response if available + if (!string.IsNullOrEmpty(ssoOrganizationIdentifier)) + { + context.CustomResponse[CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier] = ssoOrganizationIdentifier; + } + } +} diff --git a/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs b/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs index e679c48433..e4cd60827e 100644 --- a/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs @@ -12,6 +12,7 @@ using Bit.Core.Auth.Repositories; using Bit.Core.Auth.UserFeatures.WebAuthnLogin; using Bit.Core.Context; using Bit.Core.Entities; +using Bit.Core.KeyManagement.Queries.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; @@ -37,6 +38,7 @@ public class WebAuthnGrantValidator : BaseRequestValidator logger, ICurrentContext currentContext, @@ -50,13 +52,15 @@ public class WebAuthnGrantValidator : BaseRequestValidator { webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, (e, globalSettings) => - { - var context = e.Properties["SourceContext"].ToString(); - if (context.Contains(typeof(IpRateLimitMiddleware).FullName)) - { - return e.Level >= globalSettings.MinLogLevel.IdentitySettings.IpRateLimit; - } - - if (context.Contains("Duende.IdentityServer.Validation.TokenValidator") || - context.Contains("Duende.IdentityServer.Validation.TokenRequestValidator")) - { - return e.Level >= globalSettings.MinLogLevel.IdentitySettings.IdentityToken; - } - - return e.Level >= globalSettings.MinLogLevel.IdentitySettings.Default; - })); - }); + }) + .AddSerilogFileLogging(); } } diff --git a/src/Identity/Startup.cs b/src/Identity/Startup.cs index 8da31d87d6..5dc443a73c 100644 --- a/src/Identity/Startup.cs +++ b/src/Identity/Startup.cs @@ -170,14 +170,11 @@ public class Startup public void Configure( IApplicationBuilder app, IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, GlobalSettings globalSettings, ILogger logger) { IdentityModelEventSource.ShowPII = true; - app.UseSerilog(env, appLifetime, globalSettings); - // Add general security headers app.UseMiddleware(); @@ -240,6 +237,6 @@ public class Startup app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); // Log startup - logger.LogInformation(Constants.BypassFiltersEventId, globalSettings.ProjectName + " started."); + logger.LogInformation(Constants.BypassFiltersEventId, "{Project} started.", globalSettings.ProjectName); } } diff --git a/src/Identity/Utilities/ServiceCollectionExtensions.cs b/src/Identity/Utilities/ServiceCollectionExtensions.cs index e9056d030e..7e64975c95 100644 --- a/src/Identity/Utilities/ServiceCollectionExtensions.cs +++ b/src/Identity/Utilities/ServiceCollectionExtensions.cs @@ -26,6 +26,7 @@ public static class ServiceCollectionExtensions services.AddTransient(); services.AddTransient(); services.AddTransient(); + services.AddTransient(); services.AddTransient(); services.AddTransient, SendPasswordRequestValidator>(); services.AddTransient, SendEmailOtpRequestValidator>(); diff --git a/src/Identity/appsettings.json b/src/Identity/appsettings.json index 16c3efe46b..c21d2dff3b 100644 --- a/src/Identity/appsettings.json +++ b/src/Identity/appsettings.json @@ -27,9 +27,6 @@ "events": { "connectionString": "SECRET" }, - "sentry": { - "dsn": "SECRET" - }, "notificationHub": { "connectionString": "SECRET", "hubName": "SECRET" diff --git a/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs b/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs index 5a743ba028..2be33e8846 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs +++ b/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs @@ -218,6 +218,8 @@ public static class BulkResourceCreationService ciphersTable.Columns.Add(revisionDateColumn); var deletedDateColumn = new DataColumn(nameof(c.DeletedDate), typeof(DateTime)); ciphersTable.Columns.Add(deletedDateColumn); + var archivedDateColumn = new DataColumn(nameof(c.ArchivedDate), typeof(DateTime)); + ciphersTable.Columns.Add(archivedDateColumn); var repromptColumn = new DataColumn(nameof(c.Reprompt), typeof(short)); ciphersTable.Columns.Add(repromptColumn); var keyColummn = new DataColumn(nameof(c.Key), typeof(string)); @@ -247,6 +249,7 @@ public static class BulkResourceCreationService row[creationDateColumn] = cipher.CreationDate; row[revisionDateColumn] = cipher.RevisionDate; row[deletedDateColumn] = cipher.DeletedDate.HasValue ? (object)cipher.DeletedDate : DBNull.Value; + row[archivedDateColumn] = cipher.ArchivedDate.HasValue ? cipher.ArchivedDate : DBNull.Value; row[repromptColumn] = cipher.Reprompt.HasValue ? cipher.Reprompt.Value : DBNull.Value; row[keyColummn] = cipher.Key; diff --git a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs index 5f389ae56d..bd670347a9 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs +++ b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs @@ -2,6 +2,7 @@ using System.Text.Json; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.OrganizationUsers; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models; using Bit.Core.AdminConsole.Utilities.DebuggingInstruments; using Bit.Core.Entities; @@ -15,8 +16,6 @@ using Dapper; using Microsoft.Data.SqlClient; using Microsoft.Extensions.Logging; -#nullable enable - namespace Bit.Infrastructure.Dapper.Repositories; public class OrganizationUserRepository : Repository, IOrganizationUserRepository @@ -626,7 +625,11 @@ public class OrganizationUserRepository : Repository, IO await connection.ExecuteAsync( "[dbo].[OrganizationUser_SetStatusForUsersByGuidIdArray]", - new { OrganizationUserIds = organizationUserIds.ToGuidIdArrayTVP(), Status = OrganizationUserStatusType.Revoked }, + new + { + OrganizationUserIds = organizationUserIds.ToGuidIdArrayTVP(), + Status = OrganizationUserStatusType.Revoked + }, commandType: CommandType.StoredProcedure); } @@ -672,4 +675,38 @@ public class OrganizationUserRepository : Repository, IO }, commandType: CommandType.StoredProcedure); } + + public async Task ConfirmOrganizationUserAsync(AcceptedOrganizationUserToConfirm organizationUserToConfirm) + { + await using var connection = new SqlConnection(_marsConnectionString); + + var rowCount = await connection.ExecuteScalarAsync( + $"[{Schema}].[OrganizationUser_ConfirmById]", + new + { + Id = organizationUserToConfirm.OrganizationUserId, + UserId = organizationUserToConfirm.UserId, + RevisionDate = DateTime.UtcNow.Date, + Key = organizationUserToConfirm.Key + }); + + return rowCount > 0; + } + + public async Task GetDetailsByOrganizationIdUserIdAsync(Guid organizationId, Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var result = await connection.QuerySingleOrDefaultAsync( + "[dbo].[OrganizationUserUserDetails_ReadByOrganizationIdUserId]", + new + { + OrganizationId = organizationId, + UserId = userId + }, + commandType: CommandType.StoredProcedure); + + return result; + } + } } diff --git a/src/Infrastructure.Dapper/AdminConsole/Repositories/PolicyRepository.cs b/src/Infrastructure.Dapper/AdminConsole/Repositories/PolicyRepository.cs index 83d5ef6a70..865c4f8e5c 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Repositories/PolicyRepository.cs +++ b/src/Infrastructure.Dapper/AdminConsole/Repositories/PolicyRepository.cs @@ -61,19 +61,6 @@ public class PolicyRepository : Repository, IPolicyRepository } } - public async Task> GetPolicyDetailsByUserId(Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[PolicyDetails_ReadByUserId]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); - - return results.ToList(); - } - } - public async Task> GetPolicyDetailsByUserIdsAndPolicyType(IEnumerable userIds, PolicyType type) { await using var connection = new SqlConnection(ConnectionString); diff --git a/src/Infrastructure.Dapper/AdminConsole/Repositories/ProviderUserRepository.cs b/src/Infrastructure.Dapper/AdminConsole/Repositories/ProviderUserRepository.cs index 467857612f..c05ff040e5 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Repositories/ProviderUserRepository.cs +++ b/src/Infrastructure.Dapper/AdminConsole/Repositories/ProviderUserRepository.cs @@ -61,6 +61,18 @@ public class ProviderUserRepository : Repository, IProviderU } } + public async Task> GetManyByManyUsersAsync(IEnumerable userIds) + { + await using var connection = new SqlConnection(ConnectionString); + + var results = await connection.QueryAsync( + "[dbo].[ProviderUser_ReadManyByManyUserIds]", + new { UserIds = userIds.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + public async Task GetByProviderUserAsync(Guid providerId, Guid userId) { using (var connection = new SqlConnection(ConnectionString)) diff --git a/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs b/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs index 35fc094973..e3ee82270f 100644 --- a/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs +++ b/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs @@ -15,6 +15,7 @@ using Bit.Infrastructure.Dapper.AdminConsole.Repositories; using Bit.Infrastructure.Dapper.Auth.Repositories; using Bit.Infrastructure.Dapper.Billing.Repositories; using Bit.Infrastructure.Dapper.Dirt; +using Bit.Infrastructure.Dapper.Dirt.Repositories; using Bit.Infrastructure.Dapper.KeyManagement.Repositories; using Bit.Infrastructure.Dapper.NotificationCenter.Repositories; using Bit.Infrastructure.Dapper.Platform; @@ -71,6 +72,7 @@ public static class DapperServiceCollectionExtensions services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); diff --git a/src/Infrastructure.Dapper/Dirt/OrganizationReportRepository.cs b/src/Infrastructure.Dapper/Dirt/OrganizationReportRepository.cs index 3d001cce92..c704a208d1 100644 --- a/src/Infrastructure.Dapper/Dirt/OrganizationReportRepository.cs +++ b/src/Infrastructure.Dapper/Dirt/OrganizationReportRepository.cs @@ -4,6 +4,7 @@ using System.Data; using Bit.Core.Dirt.Entities; using Bit.Core.Dirt.Models.Data; +using Bit.Core.Dirt.Reports.Models.Data; using Bit.Core.Dirt.Repositories; using Bit.Core.Settings; using Bit.Infrastructure.Dapper.Repositories; @@ -173,4 +174,31 @@ public class OrganizationReportRepository : Repository commandType: CommandType.StoredProcedure); } } + + public async Task UpdateMetricsAsync(Guid reportId, OrganizationReportMetricsData metrics) + { + using var connection = new SqlConnection(ConnectionString); + var parameters = new + { + Id = reportId, + ApplicationCount = metrics.ApplicationCount, + ApplicationAtRiskCount = metrics.ApplicationAtRiskCount, + CriticalApplicationCount = metrics.CriticalApplicationCount, + CriticalApplicationAtRiskCount = metrics.CriticalApplicationAtRiskCount, + MemberCount = metrics.MemberCount, + MemberAtRiskCount = metrics.MemberAtRiskCount, + CriticalMemberCount = metrics.CriticalMemberCount, + CriticalMemberAtRiskCount = metrics.CriticalMemberAtRiskCount, + PasswordCount = metrics.PasswordCount, + PasswordAtRiskCount = metrics.PasswordAtRiskCount, + CriticalPasswordCount = metrics.CriticalPasswordCount, + CriticalPasswordAtRiskCount = metrics.CriticalPasswordAtRiskCount, + RevisionDate = DateTime.UtcNow + }; + + await connection.ExecuteAsync( + $"[{Schema}].[OrganizationReport_UpdateMetrics]", + parameters, + commandType: CommandType.StoredProcedure); + } } diff --git a/src/Infrastructure.Dapper/AdminConsole/Repositories/EventRepository.cs b/src/Infrastructure.Dapper/Dirt/Repositories/EventRepository.cs similarity index 100% rename from src/Infrastructure.Dapper/AdminConsole/Repositories/EventRepository.cs rename to src/Infrastructure.Dapper/Dirt/Repositories/EventRepository.cs diff --git a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationIntegrationConfigurationRepository.cs b/src/Infrastructure.Dapper/Dirt/Repositories/OrganizationIntegrationConfigurationRepository.cs similarity index 87% rename from src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationIntegrationConfigurationRepository.cs rename to src/Infrastructure.Dapper/Dirt/Repositories/OrganizationIntegrationConfigurationRepository.cs index 005e93c6aa..2b6b45f3c8 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationIntegrationConfigurationRepository.cs +++ b/src/Infrastructure.Dapper/Dirt/Repositories/OrganizationIntegrationConfigurationRepository.cs @@ -1,14 +1,15 @@ using System.Data; -using Bit.Core.AdminConsole.Entities; +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Repositories; using Bit.Core.Enums; -using Bit.Core.Models.Data.Organizations; -using Bit.Core.Repositories; using Bit.Core.Settings; using Bit.Infrastructure.Dapper.Repositories; using Dapper; using Microsoft.Data.SqlClient; -namespace Bit.Infrastructure.Dapper.AdminConsole.Repositories; +namespace Bit.Infrastructure.Dapper.Dirt.Repositories; public class OrganizationIntegrationConfigurationRepository : Repository, IOrganizationIntegrationConfigurationRepository { @@ -20,10 +21,9 @@ public class OrganizationIntegrationConfigurationRepository : Repository> GetConfigurationDetailsAsync( - Guid organizationId, - IntegrationType integrationType, - EventType eventType) + public async Task> + GetManyByEventTypeOrganizationIdIntegrationType(EventType eventType, Guid organizationId, + IntegrationType integrationType) { using (var connection = new SqlConnection(ConnectionString)) { diff --git a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationIntegrationRepository.cs b/src/Infrastructure.Dapper/Dirt/Repositories/OrganizationIntegrationRepository.cs similarity index 60% rename from src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationIntegrationRepository.cs rename to src/Infrastructure.Dapper/Dirt/Repositories/OrganizationIntegrationRepository.cs index ece9697a31..a094bbc669 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationIntegrationRepository.cs +++ b/src/Infrastructure.Dapper/Dirt/Repositories/OrganizationIntegrationRepository.cs @@ -1,11 +1,12 @@ using System.Data; -using Bit.Core.AdminConsole.Entities; -using Bit.Core.Repositories; +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Repositories; using Bit.Core.Settings; +using Bit.Infrastructure.Dapper.Repositories; using Dapper; using Microsoft.Data.SqlClient; -namespace Bit.Infrastructure.Dapper.Repositories; +namespace Bit.Infrastructure.Dapper.Dirt.Repositories; public class OrganizationIntegrationRepository : Repository, IOrganizationIntegrationRepository { @@ -29,4 +30,17 @@ public class OrganizationIntegrationRepository : Repository GetByTeamsConfigurationTenantIdTeamId(string tenantId, string teamId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var result = await connection.QuerySingleOrDefaultAsync( + "[dbo].[OrganizationIntegration_ReadByTeamsConfigurationTenantIdTeamId]", + new { TenantId = tenantId, TeamId = teamId }, + commandType: CommandType.StoredProcedure); + + return result; + } + } } diff --git a/src/Infrastructure.Dapper/KeyManagement/Repositories/UserSignatureKeyPairRepository.cs b/src/Infrastructure.Dapper/KeyManagement/Repositories/UserSignatureKeyPairRepository.cs new file mode 100644 index 0000000000..5dcc2943b8 --- /dev/null +++ b/src/Infrastructure.Dapper/KeyManagement/Repositories/UserSignatureKeyPairRepository.cs @@ -0,0 +1,79 @@ +using System.Data; +using Bit.Core.KeyManagement.Entities; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Repositories; +using Bit.Core.KeyManagement.UserKey; +using Bit.Core.Settings; +using Bit.Core.Utilities; +using Bit.Infrastructure.Dapper.Repositories; +using Dapper; +using Microsoft.Data.SqlClient; + +namespace Bit.Infrastructure.Dapper.KeyManagement.Repositories; + +public class UserSignatureKeyPairRepository : Repository, IUserSignatureKeyPairRepository +{ + public UserSignatureKeyPairRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { + } + + public UserSignatureKeyPairRepository(string connectionString, string readOnlyConnectionString) : base( + connectionString, readOnlyConnectionString) + { + } + + public async Task GetByUserIdAsync(Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + return (await connection.QuerySingleOrDefaultAsync( + "[dbo].[UserSignatureKeyPair_ReadByUserId]", + new + { + UserId = userId + }, + commandType: CommandType.StoredProcedure))?.ToSignatureKeyPairData(); + } + } + + public UpdateEncryptedDataForKeyRotation SetUserSignatureKeyPair(Guid userId, SignatureKeyPairData signingKeys) + { + return async (SqlConnection connection, SqlTransaction transaction) => + { + await connection.QueryAsync( + "[dbo].[UserSignatureKeyPair_SetForRotation]", + new + { + Id = CoreHelpers.GenerateComb(), + UserId = userId, + SignatureAlgorithm = (byte)signingKeys.SignatureAlgorithm, + SigningKey = signingKeys.WrappedSigningKey, + VerifyingKey = signingKeys.VerifyingKey, + CreationDate = DateTime.UtcNow, + RevisionDate = DateTime.UtcNow + }, + commandType: CommandType.StoredProcedure, + transaction: transaction); + }; + } + + public UpdateEncryptedDataForKeyRotation UpdateForKeyRotation(Guid grantorId, SignatureKeyPairData signingKeys) + { + return async (SqlConnection connection, SqlTransaction transaction) => + { + await connection.QueryAsync( + "[dbo].[UserSignatureKeyPair_UpdateForRotation]", + new + { + UserId = grantorId, + SignatureAlgorithm = (byte)signingKeys.SignatureAlgorithm, + SigningKey = signingKeys.WrappedSigningKey, + VerifyingKey = signingKeys.VerifyingKey, + RevisionDate = DateTime.UtcNow + }, + commandType: CommandType.StoredProcedure, + transaction: transaction); + }; + } +} diff --git a/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs b/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs index c2a59f75aa..9985b41d56 100644 --- a/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs @@ -226,7 +226,6 @@ public class CollectionRepository : Repository, ICollectionRep { obj.SetNewId(); - var objWithGroupsAndUsers = JsonSerializer.Deserialize(JsonSerializer.Serialize(obj))!; objWithGroupsAndUsers.Groups = groups != null ? groups.ToArrayTVP() : Enumerable.Empty().ToArrayTVP(); @@ -243,18 +242,52 @@ public class CollectionRepository : Repository, ICollectionRep public async Task ReplaceAsync(Collection obj, IEnumerable? groups, IEnumerable? users) { - var objWithGroupsAndUsers = JsonSerializer.Deserialize(JsonSerializer.Serialize(obj))!; - - objWithGroupsAndUsers.Groups = groups != null ? groups.ToArrayTVP() : Enumerable.Empty().ToArrayTVP(); - objWithGroupsAndUsers.Users = users != null ? users.ToArrayTVP() : Enumerable.Empty().ToArrayTVP(); - - using (var connection = new SqlConnection(ConnectionString)) + await using var connection = new SqlConnection(ConnectionString); + await connection.OpenAsync(); + await using var transaction = await connection.BeginTransactionAsync(); + try { - var results = await connection.ExecuteAsync( - $"[{Schema}].[Collection_UpdateWithGroupsAndUsers]", - objWithGroupsAndUsers, - commandType: CommandType.StoredProcedure); + if (groups == null && users == null) + { + await connection.ExecuteAsync( + $"[{Schema}].[Collection_Update]", + obj, + commandType: CommandType.StoredProcedure, + transaction: transaction); + } + else if (groups != null && users == null) + { + await connection.ExecuteAsync( + $"[{Schema}].[Collection_UpdateWithGroups]", + new CollectionWithGroups(obj, groups), + commandType: CommandType.StoredProcedure, + transaction: transaction); + } + else if (groups == null && users != null) + { + await connection.ExecuteAsync( + $"[{Schema}].[Collection_UpdateWithUsers]", + new CollectionWithUsers(obj, users), + commandType: CommandType.StoredProcedure, + transaction: transaction); + } + else if (groups != null && users != null) + { + await connection.ExecuteAsync( + $"[{Schema}].[Collection_UpdateWithGroupsAndUsers]", + new CollectionWithGroupsAndUsers(obj, groups, users), + commandType: CommandType.StoredProcedure, + transaction: transaction); + } + + await transaction.CommitAsync(); } + catch + { + await transaction.RollbackAsync(); + throw; + } + } public async Task DeleteManyAsync(IEnumerable collectionIds) @@ -424,9 +457,70 @@ public class CollectionRepository : Repository, ICollectionRep public class CollectionWithGroupsAndUsers : Collection { + public CollectionWithGroupsAndUsers() { } + + public CollectionWithGroupsAndUsers(Collection collection, + IEnumerable groups, + IEnumerable users) + { + Id = collection.Id; + Name = collection.Name; + OrganizationId = collection.OrganizationId; + CreationDate = collection.CreationDate; + RevisionDate = collection.RevisionDate; + Type = collection.Type; + ExternalId = collection.ExternalId; + DefaultUserCollectionEmail = collection.DefaultUserCollectionEmail; + Groups = groups.ToArrayTVP(); + Users = users.ToArrayTVP(); + } + [DisallowNull] public DataTable? Groups { get; set; } [DisallowNull] public DataTable? Users { get; set; } } + + public class CollectionWithGroups : Collection + { + public CollectionWithGroups() { } + + public CollectionWithGroups(Collection collection, IEnumerable groups) + { + Id = collection.Id; + Name = collection.Name; + OrganizationId = collection.OrganizationId; + CreationDate = collection.CreationDate; + RevisionDate = collection.RevisionDate; + Type = collection.Type; + ExternalId = collection.ExternalId; + DefaultUserCollectionEmail = collection.DefaultUserCollectionEmail; + Groups = groups.ToArrayTVP(); + } + + [DisallowNull] + public DataTable? Groups { get; set; } + } + + public class CollectionWithUsers : Collection + { + public CollectionWithUsers() { } + + public CollectionWithUsers(Collection collection, IEnumerable users) + { + + Id = collection.Id; + Name = collection.Name; + OrganizationId = collection.OrganizationId; + CreationDate = collection.CreationDate; + RevisionDate = collection.RevisionDate; + Type = collection.Type; + ExternalId = collection.ExternalId; + DefaultUserCollectionEmail = collection.DefaultUserCollectionEmail; + Users = users.ToArrayTVP(); + } + + [DisallowNull] + public DataTable? Users { get; set; } + } } diff --git a/src/Infrastructure.Dapper/Repositories/OrganizationDomainRepository.cs b/src/Infrastructure.Dapper/Repositories/OrganizationDomainRepository.cs index 91cbc40ff6..a8171c286b 100644 --- a/src/Infrastructure.Dapper/Repositories/OrganizationDomainRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/OrganizationDomainRepository.cs @@ -148,4 +148,16 @@ public class OrganizationDomainRepository : Repository commandType: CommandType.StoredProcedure) > 0; } } + + public async Task HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(string domainName, Guid? excludeOrganizationId = null) + { + await using var connection = new SqlConnection(ConnectionString); + + var result = await connection.QueryFirstOrDefaultAsync( + $"[{Schema}].[OrganizationDomain_HasVerifiedDomainWithBlockPolicy]", + new { DomainName = domainName, ExcludeOrganizationId = excludeOrganizationId }, + commandType: CommandType.StoredProcedure); + + return result; + } } diff --git a/src/Infrastructure.Dapper/Repositories/UserRepository.cs b/src/Infrastructure.Dapper/Repositories/UserRepository.cs index 6b11d64cda..571319e4c7 100644 --- a/src/Infrastructure.Dapper/Repositories/UserRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/UserRepository.cs @@ -1,17 +1,19 @@ using System.Data; using System.Text.Json; using Bit.Core; +using Bit.Core.Billing.Premium.Models; using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.KeyManagement.Models.Data; using Bit.Core.KeyManagement.UserKey; using Bit.Core.Models.Data; using Bit.Core.Repositories; using Bit.Core.Settings; +using Bit.Core.Utilities; using Dapper; using Microsoft.AspNetCore.DataProtection; using Microsoft.Data.SqlClient; -#nullable enable - namespace Bit.Infrastructure.Dapper.Repositories; public class UserRepository : Repository, IUserRepository @@ -288,6 +290,63 @@ public class UserRepository : Repository, IUserRepository UnprotectData(user); } + public async Task SetV2AccountCryptographicStateAsync( + Guid userId, + UserAccountKeysData accountKeysData, + IEnumerable? updateUserDataActions = null) + { + if (!accountKeysData.IsV2Encryption()) + { + throw new ArgumentException("Provided account keys data is not valid V2 encryption data.", nameof(accountKeysData)); + } + + var timestamp = DateTime.UtcNow; + var signatureKeyPairId = CoreHelpers.GenerateComb(); + + await using var connection = new SqlConnection(ConnectionString); + await connection.OpenAsync(); + + await using var transaction = connection.BeginTransaction(); + try + { + await connection.ExecuteAsync( + "[dbo].[User_UpdateAccountCryptographicState]", + new + { + Id = userId, + PublicKey = accountKeysData.PublicKeyEncryptionKeyPairData.PublicKey, + PrivateKey = accountKeysData.PublicKeyEncryptionKeyPairData.WrappedPrivateKey, + SignedPublicKey = accountKeysData.PublicKeyEncryptionKeyPairData.SignedPublicKey, + SecurityState = accountKeysData.SecurityStateData!.SecurityState, + SecurityVersion = accountKeysData.SecurityStateData!.SecurityVersion, + SignatureKeyPairId = signatureKeyPairId, + SignatureAlgorithm = accountKeysData.SignatureKeyPairData!.SignatureAlgorithm, + SigningKey = accountKeysData.SignatureKeyPairData!.WrappedSigningKey, + VerifyingKey = accountKeysData.SignatureKeyPairData!.VerifyingKey, + RevisionDate = timestamp, + AccountRevisionDate = timestamp + }, + transaction: transaction, + commandType: CommandType.StoredProcedure); + + // Update user data that depends on cryptographic state + if (updateUserDataActions != null) + { + foreach (var action in updateUserDataActions) + { + await action(connection, transaction); + } + } + + await transaction.CommitAsync(); + } + catch + { + await transaction.RollbackAsync(); + throw; + } + } + public async Task> GetManyAsync(IEnumerable ids) { using (var connection = new SqlConnection(ReadOnlyConnectionString)) @@ -324,6 +383,51 @@ public class UserRepository : Repository, IUserRepository return result.SingleOrDefault(); } + public async Task> GetPremiumAccessByIdsAsync(IEnumerable ids) + { + using (var connection = new SqlConnection(ReadOnlyConnectionString)) + { + var results = await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadPremiumAccessByIds]", + new { Ids = ids.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); + + return results.ToList(); + } + } + + public async Task GetPremiumAccessAsync(Guid userId) + { + var result = await GetPremiumAccessByIdsAsync([userId]); + return result.SingleOrDefault(); + } + + public UpdateUserData SetKeyConnectorUserKey(Guid userId, string keyConnectorWrappedUserKey) + { + return async (connection, transaction) => + { + var timestamp = DateTime.UtcNow; + + await connection!.ExecuteAsync( + "[dbo].[User_UpdateKeyConnectorUserKey]", + new + { + Id = userId, + Key = keyConnectorWrappedUserKey, + // Key Connector does not use KDF, so we set some defaults + Kdf = KdfType.Argon2id, + KdfIterations = AuthConstants.ARGON2_ITERATIONS.Default, + KdfMemory = AuthConstants.ARGON2_MEMORY.Default, + KdfParallelism = AuthConstants.ARGON2_PARALLELISM.Default, + UsesKeyConnector = true, + RevisionDate = timestamp, + AccountRevisionDate = timestamp + }, + transaction: transaction, + commandType: CommandType.StoredProcedure); + }; + } + private async Task ProtectDataAndSaveAsync(User user, Func saveTask) { if (user == null) diff --git a/src/Infrastructure.Dapper/Vault/Repositories/SecurityTaskRepository.cs b/src/Infrastructure.Dapper/Vault/Repositories/SecurityTaskRepository.cs index 292e99d6ad..869321f280 100644 --- a/src/Infrastructure.Dapper/Vault/Repositories/SecurityTaskRepository.cs +++ b/src/Infrastructure.Dapper/Vault/Repositories/SecurityTaskRepository.cs @@ -85,4 +85,19 @@ public class SecurityTaskRepository : Repository, ISecurityT return tasksList; } + + /// + public async Task MarkAsCompleteByCipherIds(IEnumerable cipherIds) + { + if (!cipherIds.Any()) + { + return; + } + + await using var connection = new SqlConnection(ConnectionString); + await connection.ExecuteAsync( + $"[{Schema}].[SecurityTask_MarkCompleteByCipherIds]", + new { CipherIds = cipherIds.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); + } } diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Configurations/OrganizationEntityTypeConfiguration.cs b/src/Infrastructure.EntityFramework/AdminConsole/Configurations/OrganizationEntityTypeConfiguration.cs index 47369f5e3d..93d8fe2d7d 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Configurations/OrganizationEntityTypeConfiguration.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Configurations/OrganizationEntityTypeConfiguration.cs @@ -18,7 +18,7 @@ public class OrganizationEntityTypeConfiguration : IEntityTypeConfiguration new { o.Id, o.Enabled }), - o => o.UseTotp); + o => new { o.UseTotp, o.UsersGetPremium }); builder.ToTable(nameof(Organization)); } diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Configurations/OrganizationIntegrationConfigurationEntityTypeConfiguration.cs b/src/Infrastructure.EntityFramework/AdminConsole/Configurations/OrganizationIntegrationConfigurationEntityTypeConfiguration.cs index 935473deaa..bc57c8ed15 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Configurations/OrganizationIntegrationConfigurationEntityTypeConfiguration.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Configurations/OrganizationIntegrationConfigurationEntityTypeConfiguration.cs @@ -1,4 +1,4 @@ -using Bit.Infrastructure.EntityFramework.AdminConsole.Models; +using Bit.Infrastructure.EntityFramework.Dirt.Models; using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore.Metadata.Builders; diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Configurations/OrganizationIntegrationEntityTypeConfiguration.cs b/src/Infrastructure.EntityFramework/AdminConsole/Configurations/OrganizationIntegrationEntityTypeConfiguration.cs index 3434d735d0..b14c156832 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Configurations/OrganizationIntegrationEntityTypeConfiguration.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Configurations/OrganizationIntegrationEntityTypeConfiguration.cs @@ -1,4 +1,4 @@ -using Bit.Infrastructure.EntityFramework.AdminConsole.Models; +using Bit.Infrastructure.EntityFramework.Dirt.Models; using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore.Metadata.Builders; diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Models/OrganizationIntegration.cs b/src/Infrastructure.EntityFramework/AdminConsole/Models/OrganizationIntegration.cs deleted file mode 100644 index 0f47d5947b..0000000000 --- a/src/Infrastructure.EntityFramework/AdminConsole/Models/OrganizationIntegration.cs +++ /dev/null @@ -1,16 +0,0 @@ -using AutoMapper; - -namespace Bit.Infrastructure.EntityFramework.AdminConsole.Models; - -public class OrganizationIntegration : Core.AdminConsole.Entities.OrganizationIntegration -{ - public virtual required Organization Organization { get; set; } -} - -public class OrganizationIntegrationMapperProfile : Profile -{ - public OrganizationIntegrationMapperProfile() - { - CreateMap().ReverseMap(); - } -} diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Models/OrganizationIntegrationConfiguration.cs b/src/Infrastructure.EntityFramework/AdminConsole/Models/OrganizationIntegrationConfiguration.cs deleted file mode 100644 index 21b282f767..0000000000 --- a/src/Infrastructure.EntityFramework/AdminConsole/Models/OrganizationIntegrationConfiguration.cs +++ /dev/null @@ -1,16 +0,0 @@ -using AutoMapper; - -namespace Bit.Infrastructure.EntityFramework.AdminConsole.Models; - -public class OrganizationIntegrationConfiguration : Core.AdminConsole.Entities.OrganizationIntegrationConfiguration -{ - public virtual required OrganizationIntegration OrganizationIntegration { get; set; } -} - -public class OrganizationIntegrationConfigurationMapperProfile : Profile -{ - public OrganizationIntegrationConfigurationMapperProfile() - { - CreateMap().ReverseMap(); - } -} diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationIntegrationRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationIntegrationRepository.cs deleted file mode 100644 index 5670b2ae9b..0000000000 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationIntegrationRepository.cs +++ /dev/null @@ -1,29 +0,0 @@ -using AutoMapper; -using Bit.Core.Repositories; -using Bit.Infrastructure.EntityFramework.AdminConsole.Models; -using Bit.Infrastructure.EntityFramework.AdminConsole.Repositories.Queries; -using Bit.Infrastructure.EntityFramework.Repositories; -using Microsoft.EntityFrameworkCore; -using Microsoft.Extensions.DependencyInjection; - -namespace Bit.Infrastructure.EntityFramework.AdminConsole.Repositories; - -public class OrganizationIntegrationRepository : - Repository, - IOrganizationIntegrationRepository -{ - public OrganizationIntegrationRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) - : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.OrganizationIntegrations) - { - } - - public async Task> GetManyByOrganizationAsync(Guid organizationId) - { - using (var scope = ServiceScopeFactory.CreateScope()) - { - var dbContext = GetDatabaseContext(scope); - var query = new OrganizationIntegrationReadManyByOrganizationIdQuery(organizationId); - return await query.Run(dbContext).ToListAsync(); - } - } -} diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs index 200c4aa308..f2da58a1dd 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs @@ -112,7 +112,9 @@ public class OrganizationRepository : Repository GetOccupiedSeatCountByOrganizationIdAsync(Guid organizationId) { using (var scope = ServiceScopeFactory.CreateScope()) diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationUserRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationUserRepository.cs index fae0598c1c..ae55099775 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationUserRepository.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationUserRepository.cs @@ -3,6 +3,7 @@ using AutoMapper; using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.OrganizationUsers; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -942,4 +943,42 @@ public class OrganizationUserRepository : Repository ConfirmOrganizationUserAsync(AcceptedOrganizationUserToConfirm organizationUserToConfirm) + { + using var scope = ServiceScopeFactory.CreateScope(); + await using var dbContext = GetDatabaseContext(scope); + + var result = await dbContext.OrganizationUsers + .Where(ou => ou.Id == organizationUserToConfirm.OrganizationUserId + && ou.Status == OrganizationUserStatusType.Accepted) + .ExecuteUpdateAsync(x => x + .SetProperty(y => y.Status, OrganizationUserStatusType.Confirmed) + .SetProperty(y => y.Key, organizationUserToConfirm.Key)); + + if (result <= 0) + { + return false; + } + + await dbContext.UserBumpAccountRevisionDateByOrganizationUserIdAsync(organizationUserToConfirm.OrganizationUserId); + return true; + + } + +#nullable enable + + public async Task GetDetailsByOrganizationIdUserIdAsync(Guid organizationId, Guid userId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var view = new OrganizationUserUserDetailsViewQuery(); + var entity = await view.Run(dbContext).SingleOrDefaultAsync(ou => ou.OrganizationId == organizationId && ou.UserId == userId); + return entity; + } + } +#nullable disable + + } diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/PolicyRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/PolicyRepository.cs index 72c277f1d7..894fb255be 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/PolicyRepository.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/PolicyRepository.cs @@ -56,45 +56,6 @@ public class PolicyRepository : Repository> GetPolicyDetailsByUserId(Guid userId) - { - using var scope = ServiceScopeFactory.CreateScope(); - var dbContext = GetDatabaseContext(scope); - - var providerOrganizations = from pu in dbContext.ProviderUsers - where pu.UserId == userId - join po in dbContext.ProviderOrganizations - on pu.ProviderId equals po.ProviderId - select po; - - var query = from p in dbContext.Policies - join ou in dbContext.OrganizationUsers - on p.OrganizationId equals ou.OrganizationId - join o in dbContext.Organizations - on p.OrganizationId equals o.Id - where - p.Enabled && - o.Enabled && - o.UsePolicies && - ( - (ou.Status != OrganizationUserStatusType.Invited && ou.UserId == userId) || - // Invited orgUsers do not have a UserId associated with them, so we have to match up their email - (ou.Status == OrganizationUserStatusType.Invited && ou.Email == dbContext.Users.Find(userId).Email) - ) - select new PolicyDetails - { - OrganizationUserId = ou.Id, - OrganizationId = p.OrganizationId, - PolicyType = p.Type, - PolicyData = p.Data, - OrganizationUserType = ou.Type, - OrganizationUserStatus = ou.Status, - OrganizationUserPermissionsData = ou.Permissions, - IsProvider = providerOrganizations.Any(po => po.OrganizationId == p.OrganizationId) - }; - return await query.ToListAsync(); - } - public async Task> GetPolicyDetailsByOrganizationIdAsync(Guid organizationId, PolicyType policyType) { using var scope = ServiceScopeFactory.CreateScope(); @@ -256,7 +217,7 @@ public class PolicyRepository : Repository new OrganizationPolicyDetails { diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/ProviderUserRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/ProviderUserRepository.cs index 5474e3e217..8f9a38f9b6 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/ProviderUserRepository.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/ProviderUserRepository.cs @@ -96,6 +96,20 @@ public class ProviderUserRepository : return await query.ToArrayAsync(); } } + + public async Task> GetManyByManyUsersAsync(IEnumerable userIds) + { + await using var scope = ServiceScopeFactory.CreateAsyncScope(); + + var dbContext = GetDatabaseContext(scope); + + var query = from pu in dbContext.ProviderUsers + where pu.UserId != null && userIds.Contains(pu.UserId.Value) + select pu; + + return await query.ToArrayAsync(); + } + public async Task GetByProviderUserAsync(Guid providerId, Guid userId) { using (var scope = ServiceScopeFactory.CreateScope()) diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationUserOrganizationDetailsViewQuery.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationUserOrganizationDetailsViewQuery.cs index 26d3a128fc..f433e9096b 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationUserOrganizationDetailsViewQuery.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationUserOrganizationDetailsViewQuery.cs @@ -73,7 +73,9 @@ public class OrganizationUserOrganizationDetailsViewQuery : IQuery new ProviderUserOrganizationDetails { OrganizationId = x.po.OrganizationId, @@ -29,6 +31,9 @@ public class ProviderUserOrganizationDetailsViewQuery : IQuery().ReverseMap(); + } +} diff --git a/src/Infrastructure.EntityFramework/Dirt/Models/OrganizationIntegrationConfiguration.cs b/src/Infrastructure.EntityFramework/Dirt/Models/OrganizationIntegrationConfiguration.cs new file mode 100644 index 0000000000..11632d6530 --- /dev/null +++ b/src/Infrastructure.EntityFramework/Dirt/Models/OrganizationIntegrationConfiguration.cs @@ -0,0 +1,16 @@ +using AutoMapper; + +namespace Bit.Infrastructure.EntityFramework.Dirt.Models; + +public class OrganizationIntegrationConfiguration : Core.Dirt.Entities.OrganizationIntegrationConfiguration +{ + public virtual required OrganizationIntegration OrganizationIntegration { get; set; } +} + +public class OrganizationIntegrationConfigurationMapperProfile : Profile +{ + public OrganizationIntegrationConfigurationMapperProfile() + { + CreateMap().ReverseMap(); + } +} diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/EventRepository.cs b/src/Infrastructure.EntityFramework/Dirt/Repositories/EventRepository.cs similarity index 100% rename from src/Infrastructure.EntityFramework/AdminConsole/Repositories/EventRepository.cs rename to src/Infrastructure.EntityFramework/Dirt/Repositories/EventRepository.cs diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationIntegrationConfigurationRepository.cs b/src/Infrastructure.EntityFramework/Dirt/Repositories/OrganizationIntegrationConfigurationRepository.cs similarity index 67% rename from src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationIntegrationConfigurationRepository.cs rename to src/Infrastructure.EntityFramework/Dirt/Repositories/OrganizationIntegrationConfigurationRepository.cs index fc391b958c..b0d545d3c3 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationIntegrationConfigurationRepository.cs +++ b/src/Infrastructure.EntityFramework/Dirt/Repositories/OrganizationIntegrationConfigurationRepository.cs @@ -1,32 +1,33 @@ using AutoMapper; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Repositories; using Bit.Core.Enums; -using Bit.Core.Models.Data.Organizations; -using Bit.Core.Repositories; -using Bit.Infrastructure.EntityFramework.AdminConsole.Models; -using Bit.Infrastructure.EntityFramework.AdminConsole.Repositories.Queries; +using Bit.Infrastructure.EntityFramework.Dirt.Repositories.Queries; using Bit.Infrastructure.EntityFramework.Repositories; -using Bit.Infrastructure.EntityFramework.Repositories.Queries; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; +using OrganizationIntegrationConfiguration = Bit.Core.Dirt.Entities.OrganizationIntegrationConfiguration; -namespace Bit.Infrastructure.EntityFramework.AdminConsole.Repositories; +namespace Bit.Infrastructure.EntityFramework.Dirt.Repositories; -public class OrganizationIntegrationConfigurationRepository : Repository, IOrganizationIntegrationConfigurationRepository +public class OrganizationIntegrationConfigurationRepository : Repository, IOrganizationIntegrationConfigurationRepository { public OrganizationIntegrationConfigurationRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) : base(serviceScopeFactory, mapper, context => context.OrganizationIntegrationConfigurations) { } - public async Task> GetConfigurationDetailsAsync( - Guid organizationId, - IntegrationType integrationType, - EventType eventType) + public async Task> + GetManyByEventTypeOrganizationIdIntegrationType(EventType eventType, Guid organizationId, + IntegrationType integrationType) { using (var scope = ServiceScopeFactory.CreateScope()) { var dbContext = GetDatabaseContext(scope); var query = new OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery( - organizationId, eventType, integrationType + organizationId, + eventType, + integrationType ); return await query.Run(dbContext).ToListAsync(); } @@ -42,7 +43,7 @@ public class OrganizationIntegrationConfigurationRepository : Repository> GetManyByIntegrationAsync( + public async Task> GetManyByIntegrationAsync( Guid organizationIntegrationId) { using (var scope = ServiceScopeFactory.CreateScope()) diff --git a/src/Infrastructure.EntityFramework/Dirt/Repositories/OrganizationIntegrationRepository.cs b/src/Infrastructure.EntityFramework/Dirt/Repositories/OrganizationIntegrationRepository.cs new file mode 100644 index 0000000000..cbcd574854 --- /dev/null +++ b/src/Infrastructure.EntityFramework/Dirt/Repositories/OrganizationIntegrationRepository.cs @@ -0,0 +1,41 @@ +using AutoMapper; +using Bit.Core.Dirt.Repositories; +using Bit.Infrastructure.EntityFramework.Dirt.Repositories.Queries; +using Bit.Infrastructure.EntityFramework.Repositories; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; +using OrganizationIntegration = Bit.Core.Dirt.Entities.OrganizationIntegration; + +namespace Bit.Infrastructure.EntityFramework.Dirt.Repositories; + +public class OrganizationIntegrationRepository : + Repository, + IOrganizationIntegrationRepository +{ + public OrganizationIntegrationRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) + : base(serviceScopeFactory, mapper, (DatabaseContext context) => context.OrganizationIntegrations) + { + } + + public async Task> GetManyByOrganizationAsync(Guid organizationId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = new OrganizationIntegrationReadManyByOrganizationIdQuery(organizationId); + return await query.Run(dbContext).ToListAsync(); + } + } + + public async Task GetByTeamsConfigurationTenantIdTeamId( + string tenantId, + string teamId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = new OrganizationIntegrationReadByTeamsConfigurationTenantIdTeamIdQuery(tenantId: tenantId, teamId: teamId); + return await query.Run(dbContext).SingleOrDefaultAsync(); + } + } +} diff --git a/src/Infrastructure.EntityFramework/Dirt/Repositories/OrganizationReportRepository.cs b/src/Infrastructure.EntityFramework/Dirt/Repositories/OrganizationReportRepository.cs index 525c5a479d..d08e70c353 100644 --- a/src/Infrastructure.EntityFramework/Dirt/Repositories/OrganizationReportRepository.cs +++ b/src/Infrastructure.EntityFramework/Dirt/Repositories/OrganizationReportRepository.cs @@ -4,6 +4,7 @@ using AutoMapper; using Bit.Core.Dirt.Entities; using Bit.Core.Dirt.Models.Data; +using Bit.Core.Dirt.Reports.Models.Data; using Bit.Core.Dirt.Repositories; using Bit.Infrastructure.EntityFramework.Repositories; using LinqToDB; @@ -184,4 +185,31 @@ public class OrganizationReportRepository : return Mapper.Map(updatedReport); } } + + public Task UpdateMetricsAsync(Guid reportId, OrganizationReportMetricsData metrics) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + + return dbContext.OrganizationReports + .Where(p => p.Id == reportId) + .UpdateAsync(p => new Models.OrganizationReport + { + ApplicationCount = metrics.ApplicationCount, + ApplicationAtRiskCount = metrics.ApplicationAtRiskCount, + CriticalApplicationCount = metrics.CriticalApplicationCount, + CriticalApplicationAtRiskCount = metrics.CriticalApplicationAtRiskCount, + MemberCount = metrics.MemberCount, + MemberAtRiskCount = metrics.MemberAtRiskCount, + CriticalMemberCount = metrics.CriticalMemberCount, + CriticalMemberAtRiskCount = metrics.CriticalMemberAtRiskCount, + PasswordCount = metrics.PasswordCount, + PasswordAtRiskCount = metrics.PasswordAtRiskCount, + CriticalPasswordCount = metrics.CriticalPasswordCount, + CriticalPasswordAtRiskCount = metrics.CriticalPasswordAtRiskCount, + RevisionDate = DateTime.UtcNow + }); + } + } } diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByCipherIdQuery.cs b/src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/EventReadPageByCipherIdQuery.cs similarity index 100% rename from src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByCipherIdQuery.cs rename to src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/EventReadPageByCipherIdQuery.cs diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByOrganizationIdActingUserIdQuery.cs b/src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/EventReadPageByOrganizationIdActingUserIdQuery.cs similarity index 100% rename from src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByOrganizationIdActingUserIdQuery.cs rename to src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/EventReadPageByOrganizationIdActingUserIdQuery.cs diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByOrganizationIdQuery.cs b/src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/EventReadPageByOrganizationIdQuery.cs similarity index 100% rename from src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByOrganizationIdQuery.cs rename to src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/EventReadPageByOrganizationIdQuery.cs diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByOrganizationIdServiceAccountIdQuery.cs b/src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/EventReadPageByOrganizationIdServiceAccountIdQuery.cs similarity index 100% rename from src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByOrganizationIdServiceAccountIdQuery.cs rename to src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/EventReadPageByOrganizationIdServiceAccountIdQuery.cs diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByProjectIdQuery.cs b/src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/EventReadPageByProjectIdQuery.cs similarity index 100% rename from src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByProjectIdQuery.cs rename to src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/EventReadPageByProjectIdQuery.cs diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByProviderIdActingUserIdQuery.cs b/src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/EventReadPageByProviderIdActingUserIdQuery.cs similarity index 100% rename from src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByProviderIdActingUserIdQuery.cs rename to src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/EventReadPageByProviderIdActingUserIdQuery.cs diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByProviderIdQuery.cs b/src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/EventReadPageByProviderIdQuery.cs similarity index 100% rename from src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByProviderIdQuery.cs rename to src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/EventReadPageByProviderIdQuery.cs diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageBySecretIdQuery.cs b/src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/EventReadPageBySecretIdQuery.cs similarity index 100% rename from src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageBySecretIdQuery.cs rename to src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/EventReadPageBySecretIdQuery.cs diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByServiceAccountIdQuery.cs b/src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/EventReadPageByServiceAccountIdQuery.cs similarity index 100% rename from src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByServiceAccountIdQuery.cs rename to src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/EventReadPageByServiceAccountIdQuery.cs diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByUserIdQuery.cs b/src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/EventReadPageByUserIdQuery.cs similarity index 100% rename from src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByUserIdQuery.cs rename to src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/EventReadPageByUserIdQuery.cs diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery.cs b/src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery.cs similarity index 50% rename from src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery.cs rename to src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery.cs index b4441c5084..25fd06c04d 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery.cs +++ b/src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery.cs @@ -1,31 +1,24 @@ -#nullable enable - +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.Models.Data.EventIntegrations; using Bit.Core.Enums; -using Bit.Core.Models.Data.Organizations; +using Bit.Infrastructure.EntityFramework.Repositories; +using Bit.Infrastructure.EntityFramework.Repositories.Queries; -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; +namespace Bit.Infrastructure.EntityFramework.Dirt.Repositories.Queries; -public class OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery : IQuery +public class OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery( + Guid organizationId, + EventType eventType, + IntegrationType integrationType) + : IQuery { - private readonly Guid _organizationId; - private readonly EventType _eventType; - private readonly IntegrationType _integrationType; - - public OrganizationIntegrationConfigurationDetailsReadManyByEventTypeOrganizationIdIntegrationTypeQuery(Guid organizationId, EventType eventType, IntegrationType integrationType) - { - _organizationId = organizationId; - _eventType = eventType; - _integrationType = integrationType; - } - public IQueryable Run(DatabaseContext dbContext) { var query = from oic in dbContext.OrganizationIntegrationConfigurations - join oi in dbContext.OrganizationIntegrations on oic.OrganizationIntegrationId equals oi.Id into oioic - from oi in dbContext.OrganizationIntegrations - where oi.OrganizationId == _organizationId && - oi.Type == _integrationType && - oic.EventType == _eventType + join oi in dbContext.OrganizationIntegrations on oic.OrganizationIntegrationId equals oi.Id + where oi.OrganizationId == organizationId && + oi.Type == integrationType && + (oic.EventType == eventType || oic.EventType == null) select new OrganizationIntegrationConfigurationDetails() { Id = oic.Id, diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationIntegrationConfigurationDetailsReadManyQuery.cs b/src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/OrganizationIntegrationConfigurationDetailsReadManyQuery.cs similarity index 82% rename from src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationIntegrationConfigurationDetailsReadManyQuery.cs rename to src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/OrganizationIntegrationConfigurationDetailsReadManyQuery.cs index 8141292c81..4d5be520d2 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationIntegrationConfigurationDetailsReadManyQuery.cs +++ b/src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/OrganizationIntegrationConfigurationDetailsReadManyQuery.cs @@ -1,8 +1,8 @@ -#nullable enable +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Infrastructure.EntityFramework.Repositories; +using Bit.Infrastructure.EntityFramework.Repositories.Queries; -using Bit.Core.Models.Data.Organizations; - -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; +namespace Bit.Infrastructure.EntityFramework.Dirt.Repositories.Queries; public class OrganizationIntegrationConfigurationDetailsReadManyQuery : IQuery { diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationIntegrationConfigurationReadManyByOrganizationIntegrationIdQuery.cs b/src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/OrganizationIntegrationConfigurationReadManyByOrganizationIntegrationIdQuery.cs similarity index 91% rename from src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationIntegrationConfigurationReadManyByOrganizationIntegrationIdQuery.cs rename to src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/OrganizationIntegrationConfigurationReadManyByOrganizationIntegrationIdQuery.cs index 3ed3a48723..3ae2f5f66d 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationIntegrationConfigurationReadManyByOrganizationIntegrationIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/OrganizationIntegrationConfigurationReadManyByOrganizationIntegrationIdQuery.cs @@ -1,8 +1,8 @@ -using Bit.Core.AdminConsole.Entities; +using Bit.Core.Dirt.Entities; using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Infrastructure.EntityFramework.Repositories.Queries; -namespace Bit.Infrastructure.EntityFramework.AdminConsole.Repositories.Queries; +namespace Bit.Infrastructure.EntityFramework.Dirt.Repositories.Queries; public class OrganizationIntegrationConfigurationReadManyByOrganizationIntegrationIdQuery : IQuery { diff --git a/src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/OrganizationIntegrationReadByTeamsConfigurationTenantIdTeamIdQuery.cs b/src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/OrganizationIntegrationReadByTeamsConfigurationTenantIdTeamIdQuery.cs new file mode 100644 index 0000000000..fd06c6d296 --- /dev/null +++ b/src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/OrganizationIntegrationReadByTeamsConfigurationTenantIdTeamIdQuery.cs @@ -0,0 +1,36 @@ +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Enums; +using Bit.Infrastructure.EntityFramework.Repositories; +using Bit.Infrastructure.EntityFramework.Repositories.Queries; + +namespace Bit.Infrastructure.EntityFramework.Dirt.Repositories.Queries; + +public class OrganizationIntegrationReadByTeamsConfigurationTenantIdTeamIdQuery : IQuery +{ + private readonly string _tenantId; + private readonly string _teamId; + + public OrganizationIntegrationReadByTeamsConfigurationTenantIdTeamIdQuery(string tenantId, string teamId) + { + _tenantId = tenantId; + _teamId = teamId; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var query = + from oi in dbContext.OrganizationIntegrations + where oi.Type == IntegrationType.Teams && + oi.Configuration != null && + oi.Configuration.Contains($"\"TenantId\":\"{_tenantId}\"") && + oi.Configuration.Contains($"\"id\":\"{_teamId}\"") + select new OrganizationIntegration() + { + Id = oi.Id, + OrganizationId = oi.OrganizationId, + Type = oi.Type, + Configuration = oi.Configuration, + }; + return query; + } +} diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationIntegrationReadManyByOrganizationIdQuery.cs b/src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/OrganizationIntegrationReadManyByOrganizationIdQuery.cs similarity index 88% rename from src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationIntegrationReadManyByOrganizationIdQuery.cs rename to src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/OrganizationIntegrationReadManyByOrganizationIdQuery.cs index df87ad0bc1..477983ebab 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationIntegrationReadManyByOrganizationIdQuery.cs +++ b/src/Infrastructure.EntityFramework/Dirt/Repositories/Queries/OrganizationIntegrationReadManyByOrganizationIdQuery.cs @@ -1,8 +1,8 @@ -using Bit.Core.AdminConsole.Entities; +using Bit.Core.Dirt.Entities; using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Infrastructure.EntityFramework.Repositories.Queries; -namespace Bit.Infrastructure.EntityFramework.AdminConsole.Repositories.Queries; +namespace Bit.Infrastructure.EntityFramework.Dirt.Repositories.Queries; public class OrganizationIntegrationReadManyByOrganizationIdQuery : IQuery { diff --git a/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs b/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs index 7a6507230e..3c35df2a82 100644 --- a/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs +++ b/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs @@ -108,6 +108,7 @@ public static class EntityFrameworkServiceCollectionExtensions services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); diff --git a/src/Infrastructure.EntityFramework/KeyManagement/Configurations/UserSignatureKeyPairEntityTypeConfiguration.cs b/src/Infrastructure.EntityFramework/KeyManagement/Configurations/UserSignatureKeyPairEntityTypeConfiguration.cs new file mode 100644 index 0000000000..aa10a73a88 --- /dev/null +++ b/src/Infrastructure.EntityFramework/KeyManagement/Configurations/UserSignatureKeyPairEntityTypeConfiguration.cs @@ -0,0 +1,22 @@ +using Bit.Infrastructure.EntityFramework.Models; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Metadata.Builders; + +namespace Bit.Infrastructure.EntityFramework.Configurations; + +public class UserSignatureKeyPairEntityTypeConfiguration : IEntityTypeConfiguration +{ + public void Configure(EntityTypeBuilder builder) + { + builder + .Property(s => s.Id) + .ValueGeneratedNever(); + + builder + .HasIndex(s => s.UserId) + .IsUnique() + .IsClustered(false); + + builder.ToTable(nameof(UserSignatureKeyPair)); + } +} diff --git a/src/Infrastructure.EntityFramework/KeyManagement/Models/UserSignatureKeyPair.cs b/src/Infrastructure.EntityFramework/KeyManagement/Models/UserSignatureKeyPair.cs new file mode 100644 index 0000000000..b2bd8a1345 --- /dev/null +++ b/src/Infrastructure.EntityFramework/KeyManagement/Models/UserSignatureKeyPair.cs @@ -0,0 +1,19 @@ +// FIXME: Update this file to be null safe and then delete the line below +#nullable disable + +using AutoMapper; + +namespace Bit.Infrastructure.EntityFramework.Models; + +public class UserSignatureKeyPair : Core.KeyManagement.Entities.UserSignatureKeyPair +{ + public virtual User User { get; set; } +} + +public class UserSignatureKeyPairMapperProfile : Profile +{ + public UserSignatureKeyPairMapperProfile() + { + CreateMap().ReverseMap(); + } +} diff --git a/src/Infrastructure.EntityFramework/KeyManagement/Repositories/UserSignatureKeyPairRepository.cs b/src/Infrastructure.EntityFramework/KeyManagement/Repositories/UserSignatureKeyPairRepository.cs new file mode 100644 index 0000000000..04f055501d --- /dev/null +++ b/src/Infrastructure.EntityFramework/KeyManagement/Repositories/UserSignatureKeyPairRepository.cs @@ -0,0 +1,66 @@ + +using AutoMapper; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Repositories; +using Bit.Core.KeyManagement.UserKey; +using Bit.Core.Utilities; +using Bit.Infrastructure.EntityFramework.Repositories; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; + +namespace Bit.Infrastructure.EntityFramework.KeyManagement.Repositories; + +public class UserSignatureKeyPairRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) : Repository(serviceScopeFactory, mapper, context => context.UserSignatureKeyPairs), IUserSignatureKeyPairRepository +{ + public async Task GetByUserIdAsync(Guid userId) + { + await using var scope = ServiceScopeFactory.CreateAsyncScope(); + var dbContext = GetDatabaseContext(scope); + var signingKeys = await dbContext.UserSignatureKeyPairs.FirstOrDefaultAsync(x => x.UserId == userId); + if (signingKeys == null) + { + return null; + } + + return signingKeys.ToSignatureKeyPairData(); + } + + public UpdateEncryptedDataForKeyRotation SetUserSignatureKeyPair(Guid userId, SignatureKeyPairData signingKeys) + { + return async (_, _) => + { + await using var scope = ServiceScopeFactory.CreateAsyncScope(); + var dbContext = GetDatabaseContext(scope); + var entity = new Models.UserSignatureKeyPair + { + Id = CoreHelpers.GenerateComb(), + UserId = userId, + SignatureAlgorithm = signingKeys.SignatureAlgorithm, + SigningKey = signingKeys.WrappedSigningKey, + VerifyingKey = signingKeys.VerifyingKey, + CreationDate = DateTime.UtcNow, + RevisionDate = DateTime.UtcNow, + }; + await dbContext.UserSignatureKeyPairs.AddAsync(entity); + await dbContext.SaveChangesAsync(); + }; + } + + public UpdateEncryptedDataForKeyRotation UpdateForKeyRotation(Guid grantorId, SignatureKeyPairData signingKeys) + { + return async (_, _) => + { + await using var scope = ServiceScopeFactory.CreateAsyncScope(); + var dbContext = GetDatabaseContext(scope); + var entity = await dbContext.UserSignatureKeyPairs.FirstOrDefaultAsync(x => x.UserId == grantorId); + if (entity != null) + { + entity.SignatureAlgorithm = signingKeys.SignatureAlgorithm; + entity.SigningKey = signingKeys.WrappedSigningKey; + entity.VerifyingKey = signingKeys.VerifyingKey; + entity.RevisionDate = DateTime.UtcNow; + await dbContext.SaveChangesAsync(); + } + }; + } +} diff --git a/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs b/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs index 7446abdd97..b748a26db2 100644 --- a/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs +++ b/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs @@ -63,6 +63,7 @@ public class DatabaseContext : DbContext public DbSet Policies { get; set; } public DbSet Providers { get; set; } public DbSet Secret { get; set; } + public DbSet SecretVersion { get; set; } public DbSet ServiceAccount { get; set; } public DbSet Project { get; set; } public DbSet ProviderUsers { get; set; } @@ -73,6 +74,7 @@ public class DatabaseContext : DbContext public DbSet TaxRates { get; set; } public DbSet Transactions { get; set; } public DbSet Users { get; set; } + public DbSet UserSignatureKeyPairs { get; set; } public DbSet AuthRequests { get; set; } public DbSet OrganizationDomains { get; set; } public DbSet WebAuthnCredentials { get; set; } diff --git a/src/Infrastructure.EntityFramework/Repositories/OrganizationDomainRepository.cs b/src/Infrastructure.EntityFramework/Repositories/OrganizationDomainRepository.cs index 0ddf80130e..d337a5e856 100644 --- a/src/Infrastructure.EntityFramework/Repositories/OrganizationDomainRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/OrganizationDomainRepository.cs @@ -177,5 +177,25 @@ public class OrganizationDomainRepository : Repository>(verifiedDomains); } + public async Task HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(string domainName, Guid? excludeOrganizationId = null) + { + using var scope = ServiceScopeFactory.CreateScope(); + var dbContext = GetDatabaseContext(scope); + + var query = from od in dbContext.OrganizationDomains + join o in dbContext.Organizations on od.OrganizationId equals o.Id + join p in dbContext.Policies on o.Id equals p.OrganizationId + where od.DomainName == domainName + && od.VerifiedDate != null + && o.Enabled + && o.UsePolicies + && o.UseOrganizationDomains + && (!excludeOrganizationId.HasValue || o.Id != excludeOrganizationId.Value) + && p.Type == Core.AdminConsole.Enums.PolicyType.BlockClaimedDomainAccountCreation + && p.Enabled + select od; + + return await query.AnyAsync(); + } } diff --git a/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs b/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs index 809704edb7..56d64094d0 100644 --- a/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs @@ -1,4 +1,8 @@ using AutoMapper; +using Bit.Core; +using Bit.Core.Billing.Premium.Models; +using Bit.Core.Enums; +using Bit.Core.KeyManagement.Models.Data; using Bit.Core.KeyManagement.UserKey; using Bit.Core.Models.Data; using Bit.Core.Repositories; @@ -241,6 +245,80 @@ public class UserRepository : Repository, IUserR await transaction.CommitAsync(); } + public async Task SetV2AccountCryptographicStateAsync( + Guid userId, + UserAccountKeysData accountKeysData, + IEnumerable? updateUserDataActions = null) + { + if (!accountKeysData.IsV2Encryption()) + { + throw new ArgumentException("Provided account keys data is not valid V2 encryption data.", nameof(accountKeysData)); + } + + using var scope = ServiceScopeFactory.CreateScope(); + var dbContext = GetDatabaseContext(scope); + + await using var transaction = await dbContext.Database.BeginTransactionAsync(); + + // Update user + var userEntity = await dbContext.Users.FindAsync(userId); + if (userEntity == null) + { + throw new ArgumentException("User not found", nameof(userId)); + } + + // Update public key encryption key pair + var timestamp = DateTime.UtcNow; + + userEntity.RevisionDate = timestamp; + userEntity.AccountRevisionDate = timestamp; + + // V1 + V2 user crypto changes + userEntity.PublicKey = accountKeysData.PublicKeyEncryptionKeyPairData.PublicKey; + userEntity.PrivateKey = accountKeysData.PublicKeyEncryptionKeyPairData.WrappedPrivateKey; + + userEntity.SecurityState = accountKeysData.SecurityStateData!.SecurityState; + userEntity.SecurityVersion = accountKeysData.SecurityStateData.SecurityVersion; + userEntity.SignedPublicKey = accountKeysData.PublicKeyEncryptionKeyPairData.SignedPublicKey; + + // Replace existing keypair if it exists + var existingKeyPair = await dbContext.UserSignatureKeyPairs + .FirstOrDefaultAsync(x => x.UserId == userId); + if (existingKeyPair != null) + { + existingKeyPair.SignatureAlgorithm = accountKeysData.SignatureKeyPairData!.SignatureAlgorithm; + existingKeyPair.SigningKey = accountKeysData.SignatureKeyPairData.WrappedSigningKey; + existingKeyPair.VerifyingKey = accountKeysData.SignatureKeyPairData.VerifyingKey; + existingKeyPair.RevisionDate = timestamp; + } + else + { + var newKeyPair = new UserSignatureKeyPair + { + UserId = userId, + SignatureAlgorithm = accountKeysData.SignatureKeyPairData!.SignatureAlgorithm, + SigningKey = accountKeysData.SignatureKeyPairData.WrappedSigningKey, + VerifyingKey = accountKeysData.SignatureKeyPairData.VerifyingKey, + CreationDate = timestamp, + RevisionDate = timestamp + }; + newKeyPair.SetNewId(); + await dbContext.UserSignatureKeyPairs.AddAsync(newKeyPair); + } + + await dbContext.SaveChangesAsync(); + + // Update additional user data within the same transaction + if (updateUserDataActions != null) + { + foreach (var action in updateUserDataActions) + { + await action(); + } + } + await transaction.CommitAsync(); + } + public async Task> GetManyAsync(IEnumerable ids) { using (var scope = ServiceScopeFactory.CreateScope()) @@ -275,6 +353,36 @@ public class UserRepository : Repository, IUserR return result.FirstOrDefault(); } + public async Task> GetPremiumAccessByIdsAsync(IEnumerable ids) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + + var users = await dbContext.Users + .Where(x => ids.Contains(x.Id)) + .Include(u => u.OrganizationUsers) + .ThenInclude(ou => ou.Organization) + .ToListAsync(); + + return users.Select(user => new UserPremiumAccess + { + Id = user.Id, + PersonalPremium = user.Premium, + OrganizationPremium = user.OrganizationUsers + .Any(ou => ou.Organization != null && + ou.Organization.Enabled == true && + ou.Organization.UsersGetPremium == true) + }).ToList(); + } + } + + public async Task GetPremiumAccessAsync(Guid userId) + { + var result = await GetPremiumAccessByIdsAsync([userId]); + return result.FirstOrDefault(); + } + public override async Task DeleteAsync(Core.Entities.User user) { using (var scope = ServiceScopeFactory.CreateScope()) @@ -373,6 +481,35 @@ public class UserRepository : Repository, IUserR } } + public UpdateUserData SetKeyConnectorUserKey(Guid userId, string keyConnectorWrappedUserKey) + { + return async (_, _) => + { + using var scope = ServiceScopeFactory.CreateScope(); + var dbContext = GetDatabaseContext(scope); + + var userEntity = await dbContext.Users.FindAsync(userId); + if (userEntity == null) + { + throw new ArgumentException("User not found", nameof(userId)); + } + + var timestamp = DateTime.UtcNow; + + userEntity.Key = keyConnectorWrappedUserKey; + // Key Connector does not use KDF, so we set some defaults + userEntity.Kdf = KdfType.Argon2id; + userEntity.KdfIterations = AuthConstants.ARGON2_ITERATIONS.Default; + userEntity.KdfMemory = AuthConstants.ARGON2_MEMORY.Default; + userEntity.KdfParallelism = AuthConstants.ARGON2_PARALLELISM.Default; + userEntity.UsesKeyConnector = true; + userEntity.RevisionDate = timestamp; + userEntity.AccountRevisionDate = timestamp; + + await dbContext.SaveChangesAsync(); + }; + } + private static void MigrateDefaultUserCollectionsToShared(DatabaseContext dbContext, IEnumerable userIds) { var defaultCollections = (from c in dbContext.Collections diff --git a/src/Infrastructure.EntityFramework/SecretsManager/Configurations/SecretVersionEntityTypeConfiguration.cs b/src/Infrastructure.EntityFramework/SecretsManager/Configurations/SecretVersionEntityTypeConfiguration.cs new file mode 100644 index 0000000000..069c7e2450 --- /dev/null +++ b/src/Infrastructure.EntityFramework/SecretsManager/Configurations/SecretVersionEntityTypeConfiguration.cs @@ -0,0 +1,42 @@ +using Bit.Infrastructure.EntityFramework.SecretsManager.Models; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Metadata.Builders; + +namespace Bit.Infrastructure.EntityFramework.SecretsManager.Configurations; + +public class SecretVersionEntityTypeConfiguration : IEntityTypeConfiguration +{ + public void Configure(EntityTypeBuilder builder) + { + builder.Property(sv => sv.Id) + .ValueGeneratedNever(); + + builder.HasKey(sv => sv.Id) + .IsClustered(); + + builder.Property(sv => sv.Value) + .IsRequired(); + + builder.Property(sv => sv.VersionDate) + .IsRequired(); + + builder.HasOne(sv => sv.EditorServiceAccount) + .WithMany() + .HasForeignKey(sv => sv.EditorServiceAccountId) + .OnDelete(DeleteBehavior.SetNull); + + builder.HasOne(sv => sv.EditorOrganizationUser) + .WithMany() + .HasForeignKey(sv => sv.EditorOrganizationUserId) + .OnDelete(DeleteBehavior.SetNull); + + builder.HasIndex(sv => sv.SecretId) + .HasDatabaseName("IX_SecretVersion_SecretId"); + + builder.HasIndex(sv => sv.EditorServiceAccountId) + .HasDatabaseName("IX_SecretVersion_EditorServiceAccountId"); + + builder.HasIndex(sv => sv.EditorOrganizationUserId) + .HasDatabaseName("IX_SecretVersion_EditorOrganizationUserId"); + } +} diff --git a/src/Infrastructure.EntityFramework/SecretsManager/Models/Secret.cs b/src/Infrastructure.EntityFramework/SecretsManager/Models/Secret.cs index 5992f32135..09d8c389df 100644 --- a/src/Infrastructure.EntityFramework/SecretsManager/Models/Secret.cs +++ b/src/Infrastructure.EntityFramework/SecretsManager/Models/Secret.cs @@ -13,6 +13,7 @@ public class Secret : Core.SecretsManager.Entities.Secret public virtual ICollection UserAccessPolicies { get; set; } public virtual ICollection GroupAccessPolicies { get; set; } public virtual ICollection ServiceAccountAccessPolicies { get; set; } + public virtual ICollection SecretVersions { get; set; } } public class SecretMapperProfile : Profile diff --git a/src/Infrastructure.EntityFramework/SecretsManager/Models/SecretVersion.cs b/src/Infrastructure.EntityFramework/SecretsManager/Models/SecretVersion.cs new file mode 100644 index 0000000000..d4a364ab0f --- /dev/null +++ b/src/Infrastructure.EntityFramework/SecretsManager/Models/SecretVersion.cs @@ -0,0 +1,24 @@ +#nullable enable + +using AutoMapper; + +namespace Bit.Infrastructure.EntityFramework.SecretsManager.Models; + +public class SecretVersion : Core.SecretsManager.Entities.SecretVersion +{ + public Secret? Secret { get; set; } + + public ServiceAccount? EditorServiceAccount { get; set; } + + public Bit.Infrastructure.EntityFramework.Models.OrganizationUser? EditorOrganizationUser { get; set; } +} + +public class SecretVersionMapperProfile : Profile +{ + public SecretVersionMapperProfile() + { + CreateMap() + .PreserveReferences() + .ReverseMap(); + } +} diff --git a/src/Infrastructure.EntityFramework/Vault/Repositories/CipherRepository.cs b/src/Infrastructure.EntityFramework/Vault/Repositories/CipherRepository.cs index 3c45afe530..ebe39852f4 100644 --- a/src/Infrastructure.EntityFramework/Vault/Repositories/CipherRepository.cs +++ b/src/Infrastructure.EntityFramework/Vault/Repositories/CipherRepository.cs @@ -704,6 +704,9 @@ public class CipherRepository : Repository + public async Task MarkAsCompleteByCipherIds(IEnumerable cipherIds) + { + if (!cipherIds.Any()) + { + return; + } + + using var scope = ServiceScopeFactory.CreateScope(); + var dbContext = GetDatabaseContext(scope); + + var cipherIdsList = cipherIds.ToList(); + + await dbContext.SecurityTasks + .Where(st => st.CipherId.HasValue && cipherIdsList.Contains(st.CipherId.Value) && st.Status != SecurityTaskStatus.Completed) + .ExecuteUpdateAsync(st => st + .SetProperty(s => s.Status, SecurityTaskStatus.Completed) + .SetProperty(s => s.RevisionDate, DateTime.UtcNow)); + } } diff --git a/src/Notifications/AzureQueueHostedService.cs b/src/Notifications/AzureQueueHostedService.cs index 94aa14eaf6..40dd8d22d4 100644 --- a/src/Notifications/AzureQueueHostedService.cs +++ b/src/Notifications/AzureQueueHostedService.cs @@ -1,34 +1,26 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Azure.Storage.Queues; +using Azure.Storage.Queues; using Bit.Core.Settings; using Bit.Core.Utilities; -using Microsoft.AspNetCore.SignalR; namespace Bit.Notifications; public class AzureQueueHostedService : IHostedService, IDisposable { private readonly ILogger _logger; - private readonly IHubContext _hubContext; - private readonly IHubContext _anonymousHubContext; + private readonly HubHelpers _hubHelpers; private readonly GlobalSettings _globalSettings; - private Task _executingTask; - private CancellationTokenSource _cts; - private QueueClient _queueClient; + private Task? _executingTask; + private CancellationTokenSource? _cts; public AzureQueueHostedService( ILogger logger, - IHubContext hubContext, - IHubContext anonymousHubContext, + HubHelpers hubHelpers, GlobalSettings globalSettings) { _logger = logger; - _hubContext = hubContext; + _hubHelpers = hubHelpers; _globalSettings = globalSettings; - _anonymousHubContext = anonymousHubContext; } public Task StartAsync(CancellationToken cancellationToken) @@ -44,32 +36,39 @@ public class AzureQueueHostedService : IHostedService, IDisposable { return; } + _logger.LogWarning("Stopping service."); - _cts.Cancel(); + _cts?.Cancel(); await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); cancellationToken.ThrowIfCancellationRequested(); } public void Dispose() - { } + { + } private async Task ExecuteAsync(CancellationToken cancellationToken) { - _queueClient = new QueueClient(_globalSettings.Notifications.ConnectionString, "notifications"); + var queueClient = new QueueClient(_globalSettings.Notifications.ConnectionString, "notifications"); while (!cancellationToken.IsCancellationRequested) { try { - var messages = await _queueClient.ReceiveMessagesAsync(32); + var messages = await queueClient.ReceiveMessagesAsync(32, cancellationToken: cancellationToken); if (messages.Value?.Any() ?? false) { foreach (var message in messages.Value) { try { - await HubHelpers.SendNotificationToHubAsync( - message.DecodeMessageText(), _hubContext, _anonymousHubContext, _logger, cancellationToken); - await _queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); + var decodedMessage = message.DecodeMessageText(); + if (!string.IsNullOrWhiteSpace(decodedMessage)) + { + await _hubHelpers.SendNotificationToHubAsync(decodedMessage, cancellationToken); + } + + await queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt, + cancellationToken); } catch (Exception e) { @@ -77,7 +76,8 @@ public class AzureQueueHostedService : IHostedService, IDisposable message.MessageId, message.DequeueCount); if (message.DequeueCount > 2) { - await _queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); + await queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt, + cancellationToken); } } } diff --git a/src/Notifications/Controllers/SendController.cs b/src/Notifications/Controllers/SendController.cs index 7debd51df7..c663102b56 100644 --- a/src/Notifications/Controllers/SendController.cs +++ b/src/Notifications/Controllers/SendController.cs @@ -1,36 +1,30 @@ -using System.Text; +#nullable enable +using System.Text; using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -using Microsoft.AspNetCore.SignalR; -namespace Bit.Notifications; +namespace Bit.Notifications.Controllers; [Authorize("Internal")] public class SendController : Controller { - private readonly IHubContext _hubContext; - private readonly IHubContext _anonymousHubContext; - private readonly ILogger _logger; + private readonly HubHelpers _hubHelpers; - public SendController(IHubContext hubContext, IHubContext anonymousHubContext, ILogger logger) + public SendController(HubHelpers hubHelpers) { - _hubContext = hubContext; - _anonymousHubContext = anonymousHubContext; - _logger = logger; + _hubHelpers = hubHelpers; } [HttpPost("~/send")] [SelfHosted(SelfHostedOnly = true)] - public async Task PostSend() + public async Task PostSendAsync() { - using (var reader = new StreamReader(Request.Body, Encoding.UTF8)) + using var reader = new StreamReader(Request.Body, Encoding.UTF8); + var notificationJson = await reader.ReadToEndAsync(); + if (!string.IsNullOrWhiteSpace(notificationJson)) { - var notificationJson = await reader.ReadToEndAsync(); - if (!string.IsNullOrWhiteSpace(notificationJson)) - { - await HubHelpers.SendNotificationToHubAsync(notificationJson, _hubContext, _anonymousHubContext, _logger); - } + await _hubHelpers.SendNotificationToHubAsync(notificationJson); } } } diff --git a/src/Notifications/HubHelpers.cs b/src/Notifications/HubHelpers.cs index 69d5bdc958..bc03bb46df 100644 --- a/src/Notifications/HubHelpers.cs +++ b/src/Notifications/HubHelpers.cs @@ -1,31 +1,39 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Text.Json; +using System.Text.Json; using Bit.Core.Enums; using Bit.Core.Models; using Microsoft.AspNetCore.SignalR; namespace Bit.Notifications; -public static class HubHelpers +public class HubHelpers { - private static JsonSerializerOptions _deserializerOptions = - new JsonSerializerOptions { PropertyNameCaseInsensitive = true }; + private static readonly JsonSerializerOptions _deserializerOptions = new() { PropertyNameCaseInsensitive = true }; private static readonly string _receiveMessageMethod = "ReceiveMessage"; - public static async Task SendNotificationToHubAsync( - string notificationJson, - IHubContext hubContext, + private readonly IHubContext _hubContext; + private readonly IHubContext _anonymousHubContext; + private readonly ILogger _logger; + + public HubHelpers(IHubContext hubContext, IHubContext anonymousHubContext, - ILogger logger, - CancellationToken cancellationToken = default(CancellationToken) - ) + ILogger logger) + { + _hubContext = hubContext; + _anonymousHubContext = anonymousHubContext; + _logger = logger; + } + + public async Task SendNotificationToHubAsync(string notificationJson, CancellationToken cancellationToken = default) { var notification = JsonSerializer.Deserialize>(notificationJson, _deserializerOptions); - logger.LogInformation("Sending notification: {NotificationType}", notification.Type); + if (notification is null) + { + return; + } + + _logger.LogInformation("Sending notification: {NotificationType}", notification.Type); switch (notification.Type) { case PushType.SyncCipherUpdate: @@ -35,14 +43,19 @@ public static class HubHelpers var cipherNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); + if (cipherNotification is null) + { + break; + } + if (cipherNotification.Payload.UserId.HasValue) { - await hubContext.Clients.User(cipherNotification.Payload.UserId.ToString()) + await _hubContext.Clients.User(cipherNotification.Payload.UserId.Value.ToString()) .SendAsync(_receiveMessageMethod, cipherNotification, cancellationToken); } else if (cipherNotification.Payload.OrganizationId.HasValue) { - await hubContext.Clients + await _hubContext.Clients .Group(NotificationsHub.GetOrganizationGroup(cipherNotification.Payload.OrganizationId.Value)) .SendAsync(_receiveMessageMethod, cipherNotification, cancellationToken); } @@ -54,7 +67,12 @@ public static class HubHelpers var folderNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.User(folderNotification.Payload.UserId.ToString()) + if (folderNotification is null) + { + break; + } + + await _hubContext.Clients.User(folderNotification.Payload.UserId.ToString()) .SendAsync(_receiveMessageMethod, folderNotification, cancellationToken); break; case PushType.SyncCiphers: @@ -64,9 +82,14 @@ public static class HubHelpers case PushType.SyncSettings: case PushType.LogOut: var userNotification = - JsonSerializer.Deserialize>( + JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.User(userNotification.Payload.UserId.ToString()) + if (userNotification is null) + { + break; + } + + await _hubContext.Clients.User(userNotification.Payload.UserId.ToString()) .SendAsync(_receiveMessageMethod, userNotification, cancellationToken); break; case PushType.SyncSendCreate: @@ -75,58 +98,102 @@ public static class HubHelpers var sendNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.User(sendNotification.Payload.UserId.ToString()) + if (sendNotification is null) + { + break; + } + + await _hubContext.Clients.User(sendNotification.Payload.UserId.ToString()) .SendAsync(_receiveMessageMethod, sendNotification, cancellationToken); break; case PushType.AuthRequestResponse: var authRequestResponseNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await anonymousHubContext.Clients.Group(authRequestResponseNotification.Payload.Id.ToString()) + if (authRequestResponseNotification is null) + { + break; + } + + await _anonymousHubContext.Clients.Group(authRequestResponseNotification.Payload.Id.ToString()) .SendAsync("AuthRequestResponseRecieved", authRequestResponseNotification, cancellationToken); break; case PushType.AuthRequest: var authRequestNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.User(authRequestNotification.Payload.UserId.ToString()) + if (authRequestNotification is null) + { + break; + } + + await _hubContext.Clients.User(authRequestNotification.Payload.UserId.ToString()) .SendAsync(_receiveMessageMethod, authRequestNotification, cancellationToken); break; case PushType.SyncOrganizationStatusChanged: var orgStatusNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.Group(NotificationsHub.GetOrganizationGroup(orgStatusNotification.Payload.OrganizationId)) + if (orgStatusNotification is null) + { + break; + } + + await _hubContext.Clients + .Group(NotificationsHub.GetOrganizationGroup(orgStatusNotification.Payload.OrganizationId)) .SendAsync(_receiveMessageMethod, orgStatusNotification, cancellationToken); break; case PushType.SyncOrganizationCollectionSettingChanged: var organizationCollectionSettingsChangedNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.Group(NotificationsHub.GetOrganizationGroup(organizationCollectionSettingsChangedNotification.Payload.OrganizationId)) - .SendAsync(_receiveMessageMethod, organizationCollectionSettingsChangedNotification, cancellationToken); + if (organizationCollectionSettingsChangedNotification is null) + { + break; + } + + await _hubContext.Clients + .Group(NotificationsHub.GetOrganizationGroup(organizationCollectionSettingsChangedNotification + .Payload.OrganizationId)) + .SendAsync(_receiveMessageMethod, organizationCollectionSettingsChangedNotification, + cancellationToken); break; case PushType.OrganizationBankAccountVerified: var organizationBankAccountVerifiedNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.Group(NotificationsHub.GetOrganizationGroup(organizationBankAccountVerifiedNotification.Payload.OrganizationId)) + if (organizationBankAccountVerifiedNotification is null) + { + break; + } + + await _hubContext.Clients.Group(NotificationsHub.GetOrganizationGroup(organizationBankAccountVerifiedNotification.Payload.OrganizationId)) .SendAsync(_receiveMessageMethod, organizationBankAccountVerifiedNotification, cancellationToken); break; case PushType.ProviderBankAccountVerified: var providerBankAccountVerifiedNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.User(providerBankAccountVerifiedNotification.Payload.AdminId.ToString()) + if (providerBankAccountVerifiedNotification is null) + { + break; + } + + await _hubContext.Clients.User(providerBankAccountVerifiedNotification.Payload.AdminId.ToString()) .SendAsync(_receiveMessageMethod, providerBankAccountVerifiedNotification, cancellationToken); break; case PushType.Notification: case PushType.NotificationStatus: var notificationData = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); + if (notificationData is null) + { + break; + } + if (notificationData.Payload.InstallationId.HasValue) { - await hubContext.Clients.Group(NotificationsHub.GetInstallationGroup( + await _hubContext.Clients.Group(NotificationsHub.GetInstallationGroup( notificationData.Payload.InstallationId.Value, notificationData.Payload.ClientType)) .SendAsync(_receiveMessageMethod, notificationData, cancellationToken); } @@ -134,32 +201,56 @@ public static class HubHelpers { if (notificationData.Payload.ClientType == ClientType.All) { - await hubContext.Clients.User(notificationData.Payload.UserId.ToString()) + await _hubContext.Clients.User(notificationData.Payload.UserId.Value.ToString()) .SendAsync(_receiveMessageMethod, notificationData, cancellationToken); } else { - await hubContext.Clients.Group(NotificationsHub.GetUserGroup( + await _hubContext.Clients.Group(NotificationsHub.GetUserGroup( notificationData.Payload.UserId.Value, notificationData.Payload.ClientType)) .SendAsync(_receiveMessageMethod, notificationData, cancellationToken); } } else if (notificationData.Payload.OrganizationId.HasValue) { - await hubContext.Clients.Group(NotificationsHub.GetOrganizationGroup( + await _hubContext.Clients.Group(NotificationsHub.GetOrganizationGroup( notificationData.Payload.OrganizationId.Value, notificationData.Payload.ClientType)) .SendAsync(_receiveMessageMethod, notificationData, cancellationToken); } break; case PushType.RefreshSecurityTasks: - var pendingTasksData = JsonSerializer.Deserialize>(notificationJson, _deserializerOptions); - await hubContext.Clients.User(pendingTasksData.Payload.UserId.ToString()) + var pendingTasksData = + JsonSerializer.Deserialize>(notificationJson, + _deserializerOptions); + if (pendingTasksData is null) + { + break; + } + + await _hubContext.Clients.User(pendingTasksData.Payload.UserId.ToString()) .SendAsync(_receiveMessageMethod, pendingTasksData, cancellationToken); break; + case PushType.PolicyChanged: + await policyChangedNotificationHandler(notificationJson, cancellationToken); + break; default: - logger.LogWarning("Notification type '{NotificationType}' has not been registered in HubHelpers and will not be pushed as as result", notification.Type); + _logger.LogWarning("Notification type '{NotificationType}' has not been registered in HubHelpers and will not be pushed as as result", notification.Type); break; } } + + private async Task policyChangedNotificationHandler(string notificationJson, CancellationToken cancellationToken) + { + var policyData = JsonSerializer.Deserialize>(notificationJson, _deserializerOptions); + if (policyData is null) + { + return; + } + + await _hubContext.Clients + .Group(NotificationsHub.GetOrganizationGroup(policyData.Payload.OrganizationId)) + .SendAsync(_receiveMessageMethod, policyData, cancellationToken); + + } } diff --git a/src/Notifications/Program.cs b/src/Notifications/Program.cs index 072c2404c4..2792391729 100644 --- a/src/Notifications/Program.cs +++ b/src/Notifications/Program.cs @@ -1,5 +1,4 @@ using Bit.Core.Utilities; -using Serilog.Events; namespace Bit.Notifications; @@ -13,37 +12,8 @@ public class Program .ConfigureWebHostDefaults(webBuilder => { webBuilder.UseStartup(); - webBuilder.ConfigureLogging((hostingContext, logging) => - logging.AddSerilog(hostingContext, (e, globalSettings) => - { - var context = e.Properties["SourceContext"].ToString(); - if (context.Contains("Duende.IdentityServer.Validation.TokenValidator") || - context.Contains("Duende.IdentityServer.Validation.TokenRequestValidator")) - { - return e.Level >= globalSettings.MinLogLevel.NotificationsSettings.IdentityToken; - } - - if (e.Level == LogEventLevel.Error && - e.MessageTemplate.Text == "Failed connection handshake.") - { - return false; - } - - if (e.Level == LogEventLevel.Error && - e.MessageTemplate.Text.StartsWith("Failed writing message.")) - { - return false; - } - - if (e.Level == LogEventLevel.Warning && - e.MessageTemplate.Text.StartsWith("Heartbeat took longer")) - { - return false; - } - - return e.Level >= globalSettings.MinLogLevel.NotificationsSettings.Default; - })); }) + .AddSerilogFileLogging() .Build() .Run(); } diff --git a/src/Notifications/Startup.cs b/src/Notifications/Startup.cs index eb3c3f8682..65904ea698 100644 --- a/src/Notifications/Startup.cs +++ b/src/Notifications/Startup.cs @@ -61,6 +61,7 @@ public class Startup } services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); // Mvc services.AddMvc(); @@ -81,11 +82,9 @@ public class Startup public void Configure( IApplicationBuilder app, IWebHostEnvironment env, - IHostApplicationLifetime appLifetime, GlobalSettings globalSettings) { IdentityModelEventSource.ShowPII = true; - app.UseSerilog(env, appLifetime, globalSettings); // Add general security headers app.UseMiddleware(); diff --git a/src/Notifications/appsettings.json b/src/Notifications/appsettings.json index 020d98cbd6..e36ec02dad 100644 --- a/src/Notifications/appsettings.json +++ b/src/Notifications/appsettings.json @@ -18,9 +18,6 @@ "connectionString": "SECRET", "applicationCacheTopicName": "SECRET" }, - "sentry": { - "dsn": "SECRET" - }, "amazon": { "accessKeyId": "SECRET", "accessKeySecret": "SECRET", diff --git a/src/SharedWeb/SharedWeb.csproj b/src/SharedWeb/SharedWeb.csproj index 8bffa285fc..d8dc61178d 100644 --- a/src/SharedWeb/SharedWeb.csproj +++ b/src/SharedWeb/SharedWeb.csproj @@ -7,6 +7,7 @@ + diff --git a/src/SharedWeb/Utilities/ExceptionHandlerFilterAttribute.cs b/src/SharedWeb/Utilities/ExceptionHandlerFilterAttribute.cs index 332aa6838c..aba1a6a8dc 100644 --- a/src/SharedWeb/Utilities/ExceptionHandlerFilterAttribute.cs +++ b/src/SharedWeb/Utilities/ExceptionHandlerFilterAttribute.cs @@ -75,7 +75,7 @@ public class ExceptionHandlerFilterAttribute : ExceptionFilterAttribute else { var logger = context.HttpContext.RequestServices.GetRequiredService>(); - logger.LogError(0, exception, exception.Message); + logger.LogError(0, exception, "Unhandled exception"); errorMessage = "An unhandled server error has occurred."; context.HttpContext.Response.StatusCode = 500; } diff --git a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs index 58ce0466c3..91047d98bc 100644 --- a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs +++ b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs @@ -6,11 +6,9 @@ using System.Reflection; using System.Security.Claims; using System.Security.Cryptography.X509Certificates; using AspNetCoreRateLimit; -using Azure.Messaging.ServiceBus; using Bit.Core; using Bit.Core.AdminConsole.AbilitiesCache; using Bit.Core.AdminConsole.Models.Business.Tokenables; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.Services; using Bit.Core.AdminConsole.Services.Implementations; @@ -19,7 +17,6 @@ using Bit.Core.Auth.Enums; using Bit.Core.Auth.Identity; using Bit.Core.Auth.Identity.TokenProviders; using Bit.Core.Auth.IdentityServer; -using Bit.Core.Auth.LoginFeatures; using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Auth.Repositories; using Bit.Core.Auth.Services; @@ -37,6 +34,9 @@ using Bit.Core.KeyManagement; using Bit.Core.NotificationCenter; using Bit.Core.OrganizationFeatures; using Bit.Core.Platform; +using Bit.Core.Platform.Mail.Delivery; +using Bit.Core.Platform.Mail.Enqueuing; +using Bit.Core.Platform.Mail.Mailer; using Bit.Core.Platform.Push; using Bit.Core.Platform.PushRegistration.Internal; using Bit.Core.Repositories; @@ -45,6 +45,7 @@ using Bit.Core.SecretsManager.Repositories; using Bit.Core.SecretsManager.Repositories.Noop; using Bit.Core.Services; using Bit.Core.Services.Implementations; +using Bit.Core.Services.Mail; using Bit.Core.Settings; using Bit.Core.Tokens; using Bit.Core.Tools.ImportFeatures; @@ -77,7 +78,9 @@ using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using Microsoft.OpenApi.Models; using StackExchange.Redis; +using Swashbuckle.AspNetCore.SwaggerGen; using NoopRepos = Bit.Core.Repositories.Noop; using Role = Bit.Core.Entities.Role; using TableStorageRepos = Bit.Core.Repositories.TableStorage; @@ -129,7 +132,6 @@ public static class ServiceCollectionExtensions services.AddScoped(); services.AddScoped(); services.AddScoped(); - services.AddLoginServices(); services.AddScoped(); services.AddVaultServices(); services.AddReportingServices(); @@ -236,11 +238,14 @@ public static class ServiceCollectionExtensions PrivateKey = globalSettings.Braintree.PrivateKey }; }); - services.AddScoped(); + services.AddScoped(); services.AddScoped(); services.AddScoped(); + // Legacy mailer service services.AddSingleton(); services.AddSingleton(); + // Modern mailers + services.AddMailer(); services.AddSingleton(); services.AddSingleton(_ => { @@ -330,6 +335,7 @@ public static class ServiceCollectionExtensions services.AddScoped(); services.AddScoped(); services.AddScoped(); + services.AddScoped(); services.AddScoped(); } @@ -512,98 +518,6 @@ public static class ServiceCollectionExtensions return globalSettings; } - public static IServiceCollection AddEventWriteServices(this IServiceCollection services, GlobalSettings globalSettings) - { - if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Events.ConnectionString)) - { - services.TryAddKeyedSingleton("storage"); - - if (CoreHelpers.SettingHasValue(globalSettings.EventLogging.AzureServiceBus.ConnectionString) && - CoreHelpers.SettingHasValue(globalSettings.EventLogging.AzureServiceBus.EventTopicName)) - { - services.TryAddSingleton(); - services.TryAddKeyedSingleton("broadcast"); - } - else - { - services.TryAddKeyedSingleton("broadcast"); - } - } - else if (globalSettings.SelfHosted) - { - services.TryAddKeyedSingleton("storage"); - - if (IsRabbitMqEnabled(globalSettings)) - { - services.TryAddSingleton(); - services.TryAddKeyedSingleton("broadcast"); - } - else - { - services.TryAddKeyedSingleton("broadcast"); - } - } - else - { - services.TryAddKeyedSingleton("storage"); - services.TryAddKeyedSingleton("broadcast"); - } - - services.TryAddScoped(); - return services; - } - - public static IServiceCollection AddAzureServiceBusListeners(this IServiceCollection services, GlobalSettings globalSettings) - { - if (!IsAzureServiceBusEnabled(globalSettings)) - { - return services; - } - - services.TryAddSingleton(); - services.TryAddSingleton(); - services.TryAddSingleton(); - services.TryAddKeyedSingleton("persistent"); - services.TryAddSingleton(); - - services.AddEventIntegrationServices(globalSettings); - - return services; - } - - public static IServiceCollection AddRabbitMqListeners(this IServiceCollection services, GlobalSettings globalSettings) - { - if (!IsRabbitMqEnabled(globalSettings)) - { - return services; - } - - services.TryAddSingleton(); - services.TryAddSingleton(); - services.TryAddSingleton(); - - services.AddEventIntegrationServices(globalSettings); - - return services; - } - - public static IServiceCollection AddSlackService(this IServiceCollection services, GlobalSettings globalSettings) - { - if (CoreHelpers.SettingHasValue(globalSettings.Slack.ClientId) && - CoreHelpers.SettingHasValue(globalSettings.Slack.ClientSecret) && - CoreHelpers.SettingHasValue(globalSettings.Slack.Scopes)) - { - services.AddHttpClient(SlackService.HttpClientName); - services.TryAddSingleton(); - } - else - { - services.TryAddSingleton(); - } - - return services; - } - public static void UseDefaultMiddleware(this IApplicationBuilder app, IWebHostEnvironment env, GlobalSettings globalSettings) { @@ -617,7 +531,7 @@ public static class ServiceCollectionExtensions ForwardedHeaders = ForwardedHeaders.XForwardedFor | ForwardedHeaders.XForwardedProto }; - if (!globalSettings.UnifiedDeployment) + if (!globalSettings.LiteDeployment) { // Trust the X-Forwarded-Host header of the nginx docker container try @@ -850,180 +764,60 @@ public static class ServiceCollectionExtensions return (provider, connectionString); } - private static IServiceCollection AddAzureServiceBusIntegration(this IServiceCollection services, - TListenerConfig listenerConfiguration) - where TConfig : class - where TListenerConfig : IIntegrationListenerConfiguration + /// + /// Adds a server with its corresponding OAuth2 client credentials security definition and requirement. + /// + /// The SwaggerGen configuration + /// Unique identifier for this server (e.g., "us-server", "eu-server") + /// The API server URL + /// The identity server token URL + /// Human-readable description for the server + public static void AddSwaggerServerWithSecurity( + this SwaggerGenOptions config, + string serverId, + string serverUrl, + string identityTokenUrl, + string serverDescription) { - services.TryAddKeyedSingleton(serviceKey: listenerConfiguration.RoutingKey, implementationFactory: (provider, _) => - new EventIntegrationHandler( - integrationType: listenerConfiguration.IntegrationType, - eventIntegrationPublisher: provider.GetRequiredService(), - integrationFilterService: provider.GetRequiredService(), - configurationCache: provider.GetRequiredService(), - userRepository: provider.GetRequiredService(), - organizationRepository: provider.GetRequiredService(), - logger: provider.GetRequiredService>>() - ) - ); - services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => - new AzureServiceBusEventListenerService( - configuration: listenerConfiguration, - handler: provider.GetRequiredKeyedService(serviceKey: listenerConfiguration.RoutingKey), - serviceBusService: provider.GetRequiredService(), - serviceBusOptions: new ServiceBusProcessorOptions() - { - PrefetchCount = listenerConfiguration.EventPrefetchCount, - MaxConcurrentCalls = listenerConfiguration.EventMaxConcurrentCalls - }, - loggerFactory: provider.GetRequiredService() - ) - ) - ); - services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => - new AzureServiceBusIntegrationListenerService( - configuration: listenerConfiguration, - handler: provider.GetRequiredService>(), - serviceBusService: provider.GetRequiredService(), - serviceBusOptions: new ServiceBusProcessorOptions() - { - PrefetchCount = listenerConfiguration.IntegrationPrefetchCount, - MaxConcurrentCalls = listenerConfiguration.IntegrationMaxConcurrentCalls - }, - loggerFactory: provider.GetRequiredService() - ) - ) - ); - - return services; - } - - private static IServiceCollection AddEventIntegrationServices(this IServiceCollection services, - GlobalSettings globalSettings) - { - // Add common services - services.TryAddSingleton(); - services.TryAddSingleton(provider => - provider.GetRequiredService()); - services.AddHostedService(provider => provider.GetRequiredService()); - services.TryAddSingleton(); - services.TryAddKeyedSingleton("persistent"); - - // Add services in support of handlers - services.AddSlackService(globalSettings); - services.TryAddSingleton(TimeProvider.System); - services.AddHttpClient(WebhookIntegrationHandler.HttpClientName); - services.AddHttpClient(DatadogIntegrationHandler.HttpClientName); - - // Add integration handlers - services.TryAddSingleton, SlackIntegrationHandler>(); - services.TryAddSingleton, WebhookIntegrationHandler>(); - services.TryAddSingleton, DatadogIntegrationHandler>(); - - var repositoryConfiguration = new RepositoryListenerConfiguration(globalSettings); - var slackConfiguration = new SlackListenerConfiguration(globalSettings); - var webhookConfiguration = new WebhookListenerConfiguration(globalSettings); - var hecConfiguration = new HecListenerConfiguration(globalSettings); - var datadogConfiguration = new DatadogListenerConfiguration(globalSettings); - - if (IsRabbitMqEnabled(globalSettings)) + // Add server + config.AddServer(new OpenApiServer { - services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => - new RabbitMqEventListenerService( - handler: provider.GetRequiredService(), - configuration: repositoryConfiguration, - rabbitMqService: provider.GetRequiredService(), - loggerFactory: provider.GetRequiredService() - ) - ) - ); - services.AddRabbitMqIntegration(slackConfiguration); - services.AddRabbitMqIntegration(webhookConfiguration); - services.AddRabbitMqIntegration(hecConfiguration); - services.AddRabbitMqIntegration(datadogConfiguration); - } + Url = serverUrl, + Description = serverDescription + }); - if (IsAzureServiceBusEnabled(globalSettings)) + // Add security definition + config.AddSecurityDefinition(serverId, new OpenApiSecurityScheme { - services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => - new AzureServiceBusEventListenerService( - configuration: repositoryConfiguration, - handler: provider.GetRequiredService(), - serviceBusService: provider.GetRequiredService(), - serviceBusOptions: new ServiceBusProcessorOptions() - { - PrefetchCount = repositoryConfiguration.EventPrefetchCount, - MaxConcurrentCalls = repositoryConfiguration.EventMaxConcurrentCalls - }, - loggerFactory: provider.GetRequiredService() - ) - ) - ); - services.AddAzureServiceBusIntegration(slackConfiguration); - services.AddAzureServiceBusIntegration(webhookConfiguration); - services.AddAzureServiceBusIntegration(hecConfiguration); - services.AddAzureServiceBusIntegration(datadogConfiguration); - } + Type = SecuritySchemeType.OAuth2, + Description = $"**Use this option if you've selected the {serverDescription}**", + Flows = new OpenApiOAuthFlows + { + ClientCredentials = new OpenApiOAuthFlow + { + TokenUrl = new Uri(identityTokenUrl), + Scopes = new Dictionary + { + { ApiScopes.ApiOrganization, $"Organization APIs ({serverDescription})" }, + }, + } + }, + }); - return services; - } - - private static IServiceCollection AddRabbitMqIntegration(this IServiceCollection services, - TListenerConfig listenerConfiguration) - where TConfig : class - where TListenerConfig : IIntegrationListenerConfiguration - { - services.TryAddKeyedSingleton(serviceKey: listenerConfiguration.RoutingKey, implementationFactory: (provider, _) => - new EventIntegrationHandler( - integrationType: listenerConfiguration.IntegrationType, - eventIntegrationPublisher: provider.GetRequiredService(), - integrationFilterService: provider.GetRequiredService(), - configurationCache: provider.GetRequiredService(), - userRepository: provider.GetRequiredService(), - organizationRepository: provider.GetRequiredService(), - logger: provider.GetRequiredService>>() - ) - ); - services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => - new RabbitMqEventListenerService( - handler: provider.GetRequiredKeyedService(serviceKey: listenerConfiguration.RoutingKey), - configuration: listenerConfiguration, - rabbitMqService: provider.GetRequiredService(), - loggerFactory: provider.GetRequiredService() - ) - ) - ); - services.TryAddEnumerable(ServiceDescriptor.Singleton>(provider => - new RabbitMqIntegrationListenerService( - handler: provider.GetRequiredService>(), - configuration: listenerConfiguration, - rabbitMqService: provider.GetRequiredService(), - loggerFactory: provider.GetRequiredService(), - timeProvider: provider.GetRequiredService() - ) - ) - ); - - return services; - } - - private static bool IsAzureServiceBusEnabled(GlobalSettings settings) - { - return CoreHelpers.SettingHasValue(settings.EventLogging.AzureServiceBus.ConnectionString) && - CoreHelpers.SettingHasValue(settings.EventLogging.AzureServiceBus.EventTopicName); - } - - private static bool IsRabbitMqEnabled(GlobalSettings settings) - { - return CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.HostName) && - CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.Username) && - CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.Password) && - CoreHelpers.SettingHasValue(settings.EventLogging.RabbitMq.EventExchangeName); + // Add security requirement + config.AddSecurityRequirement(new OpenApiSecurityRequirement + { + { + new OpenApiSecurityScheme + { + Reference = new OpenApiReference + { + Type = ReferenceType.SecurityScheme, + Id = serverId + }, + }, + [ApiScopes.ApiOrganization] + } + }); } } diff --git a/src/Sql/Sql.sqlproj b/src/Sql/Sql.sqlproj index 1a7530321e..0622c5cbb2 100644 --- a/src/Sql/Sql.sqlproj +++ b/src/Sql/Sql.sqlproj @@ -17,7 +17,4 @@ 71502 - - - diff --git a/src/Sql/dbo/Stored Procedures/Event_Create.sql b/src/Sql/dbo/Dirt/Stored Procedures/Event_Create.sql similarity index 100% rename from src/Sql/dbo/Stored Procedures/Event_Create.sql rename to src/Sql/dbo/Dirt/Stored Procedures/Event_Create.sql diff --git a/src/Sql/dbo/Stored Procedures/Event_ReadById.sql b/src/Sql/dbo/Dirt/Stored Procedures/Event_ReadById.sql similarity index 100% rename from src/Sql/dbo/Stored Procedures/Event_ReadById.sql rename to src/Sql/dbo/Dirt/Stored Procedures/Event_ReadById.sql diff --git a/src/Sql/dbo/Stored Procedures/Event_ReadPageByCipherId.sql b/src/Sql/dbo/Dirt/Stored Procedures/Event_ReadPageByCipherId.sql similarity index 100% rename from src/Sql/dbo/Stored Procedures/Event_ReadPageByCipherId.sql rename to src/Sql/dbo/Dirt/Stored Procedures/Event_ReadPageByCipherId.sql diff --git a/src/Sql/dbo/Stored Procedures/Event_ReadPageByOrganizationId.sql b/src/Sql/dbo/Dirt/Stored Procedures/Event_ReadPageByOrganizationId.sql similarity index 100% rename from src/Sql/dbo/Stored Procedures/Event_ReadPageByOrganizationId.sql rename to src/Sql/dbo/Dirt/Stored Procedures/Event_ReadPageByOrganizationId.sql diff --git a/src/Sql/dbo/Stored Procedures/Event_ReadPageByOrganizationIdActingUserId.sql b/src/Sql/dbo/Dirt/Stored Procedures/Event_ReadPageByOrganizationIdActingUserId.sql similarity index 100% rename from src/Sql/dbo/Stored Procedures/Event_ReadPageByOrganizationIdActingUserId.sql rename to src/Sql/dbo/Dirt/Stored Procedures/Event_ReadPageByOrganizationIdActingUserId.sql diff --git a/src/Sql/dbo/Stored Procedures/Event_ReadPageByProviderId.sql b/src/Sql/dbo/Dirt/Stored Procedures/Event_ReadPageByProviderId.sql similarity index 100% rename from src/Sql/dbo/Stored Procedures/Event_ReadPageByProviderId.sql rename to src/Sql/dbo/Dirt/Stored Procedures/Event_ReadPageByProviderId.sql diff --git a/src/Sql/dbo/Stored Procedures/Event_ReadPageByProviderIdActingUserId.sql b/src/Sql/dbo/Dirt/Stored Procedures/Event_ReadPageByProviderIdActingUserId.sql similarity index 100% rename from src/Sql/dbo/Stored Procedures/Event_ReadPageByProviderIdActingUserId.sql rename to src/Sql/dbo/Dirt/Stored Procedures/Event_ReadPageByProviderIdActingUserId.sql diff --git a/src/Sql/dbo/Stored Procedures/Event_ReadPageByUserId.sql b/src/Sql/dbo/Dirt/Stored Procedures/Event_ReadPageByUserId.sql similarity index 100% rename from src/Sql/dbo/Stored Procedures/Event_ReadPageByUserId.sql rename to src/Sql/dbo/Dirt/Stored Procedures/Event_ReadPageByUserId.sql diff --git a/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_Create.sql b/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_Create.sql index d6cd206558..397911549c 100644 --- a/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_Create.sql +++ b/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_Create.sql @@ -6,7 +6,19 @@ CREATE PROCEDURE [dbo].[OrganizationReport_Create] @ContentEncryptionKey VARCHAR(MAX), @SummaryData NVARCHAR(MAX), @ApplicationData NVARCHAR(MAX), - @RevisionDate DATETIME2(7) + @RevisionDate DATETIME2(7), + @ApplicationCount INT = NULL, + @ApplicationAtRiskCount INT = NULL, + @CriticalApplicationCount INT = NULL, + @CriticalApplicationAtRiskCount INT = NULL, + @MemberCount INT = NULL, + @MemberAtRiskCount INT = NULL, + @CriticalMemberCount INT = NULL, + @CriticalMemberAtRiskCount INT = NULL, + @PasswordCount INT = NULL, + @PasswordAtRiskCount INT = NULL, + @CriticalPasswordCount INT = NULL, + @CriticalPasswordAtRiskCount INT = NULL AS BEGIN SET NOCOUNT ON; @@ -20,7 +32,19 @@ INSERT INTO [dbo].[OrganizationReport]( [ContentEncryptionKey], [SummaryData], [ApplicationData], - [RevisionDate] + [RevisionDate], + [ApplicationCount], + [ApplicationAtRiskCount], + [CriticalApplicationCount], + [CriticalApplicationAtRiskCount], + [MemberCount], + [MemberAtRiskCount], + [CriticalMemberCount], + [CriticalMemberAtRiskCount], + [PasswordCount], + [PasswordAtRiskCount], + [CriticalPasswordCount], + [CriticalPasswordAtRiskCount] ) VALUES ( @Id, @@ -30,6 +54,18 @@ VALUES ( @ContentEncryptionKey, @SummaryData, @ApplicationData, - @RevisionDate + @RevisionDate, + @ApplicationCount, + @ApplicationAtRiskCount, + @CriticalApplicationCount, + @CriticalApplicationAtRiskCount, + @MemberCount, + @MemberAtRiskCount, + @CriticalMemberCount, + @CriticalMemberAtRiskCount, + @PasswordCount, + @PasswordAtRiskCount, + @CriticalPasswordCount, + @CriticalPasswordAtRiskCount ); END diff --git a/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_GetLatestByOrganizationId.sql b/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_GetLatestByOrganizationId.sql index 1312369fa8..fca8788ce6 100644 --- a/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_GetLatestByOrganizationId.sql +++ b/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_GetLatestByOrganizationId.sql @@ -5,14 +5,7 @@ BEGIN SET NOCOUNT ON SELECT TOP 1 - [Id], - [OrganizationId], - [ReportData], - [CreationDate], - [ContentEncryptionKey], - [SummaryData], - [ApplicationData], - [RevisionDate] + * FROM [dbo].[OrganizationReportView] WHERE [OrganizationId] = @OrganizationId ORDER BY [RevisionDate] DESC diff --git a/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_Update.sql b/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_Update.sql index 4732fb8ef4..e78d25267d 100644 --- a/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_Update.sql +++ b/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_Update.sql @@ -6,7 +6,19 @@ CREATE PROCEDURE [dbo].[OrganizationReport_Update] @ContentEncryptionKey VARCHAR(MAX), @SummaryData NVARCHAR(MAX), @ApplicationData NVARCHAR(MAX), - @RevisionDate DATETIME2(7) + @RevisionDate DATETIME2(7), + @ApplicationCount INT = NULL, + @ApplicationAtRiskCount INT = NULL, + @CriticalApplicationCount INT = NULL, + @CriticalApplicationAtRiskCount INT = NULL, + @MemberCount INT = NULL, + @MemberAtRiskCount INT = NULL, + @CriticalMemberCount INT = NULL, + @CriticalMemberAtRiskCount INT = NULL, + @PasswordCount INT = NULL, + @PasswordAtRiskCount INT = NULL, + @CriticalPasswordCount INT = NULL, + @CriticalPasswordAtRiskCount INT = NULL AS BEGIN SET NOCOUNT ON; @@ -18,6 +30,18 @@ BEGIN [ContentEncryptionKey] = @ContentEncryptionKey, [SummaryData] = @SummaryData, [ApplicationData] = @ApplicationData, - [RevisionDate] = @RevisionDate + [RevisionDate] = @RevisionDate, + [ApplicationCount] = @ApplicationCount, + [ApplicationAtRiskCount] = @ApplicationAtRiskCount, + [CriticalApplicationCount] = @CriticalApplicationCount, + [CriticalApplicationAtRiskCount] = @CriticalApplicationAtRiskCount, + [MemberCount] = @MemberCount, + [MemberAtRiskCount] = @MemberAtRiskCount, + [CriticalMemberCount] = @CriticalMemberCount, + [CriticalMemberAtRiskCount] = @CriticalMemberAtRiskCount, + [PasswordCount] = @PasswordCount, + [PasswordAtRiskCount] = @PasswordAtRiskCount, + [CriticalPasswordCount] = @CriticalPasswordCount, + [CriticalPasswordAtRiskCount] = @CriticalPasswordAtRiskCount WHERE [Id] = @Id; END; diff --git a/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_UpdateMetrics.sql b/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_UpdateMetrics.sql new file mode 100644 index 0000000000..8b06c90fe1 --- /dev/null +++ b/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_UpdateMetrics.sql @@ -0,0 +1,39 @@ +CREATE PROCEDURE [dbo].[OrganizationReport_UpdateMetrics] + @Id UNIQUEIDENTIFIER, + @ApplicationCount INT, + @ApplicationAtRiskCount INT, + @CriticalApplicationCount INT, + @CriticalApplicationAtRiskCount INT, + @MemberCount INT, + @MemberAtRiskCount INT, + @CriticalMemberCount INT, + @CriticalMemberAtRiskCount INT, + @PasswordCount INT, + @PasswordAtRiskCount INT, + @CriticalPasswordCount INT, + @CriticalPasswordAtRiskCount INT, + @RevisionDate DATETIME2(7) +AS +BEGIN + SET NOCOUNT ON; + + UPDATE + [dbo].[OrganizationReport] + SET + [ApplicationCount] = @ApplicationCount, + [ApplicationAtRiskCount] = @ApplicationAtRiskCount, + [CriticalApplicationCount] = @CriticalApplicationCount, + [CriticalApplicationAtRiskCount] = @CriticalApplicationAtRiskCount, + [MemberCount] = @MemberCount, + [MemberAtRiskCount] = @MemberAtRiskCount, + [CriticalMemberCount] = @CriticalMemberCount, + [CriticalMemberAtRiskCount] = @CriticalMemberAtRiskCount, + [PasswordCount] = @PasswordCount, + [PasswordAtRiskCount] = @PasswordAtRiskCount, + [CriticalPasswordCount] = @CriticalPasswordCount, + [CriticalPasswordAtRiskCount] = @CriticalPasswordAtRiskCount, + [RevisionDate] = @RevisionDate + WHERE + [Id] = @Id + +END diff --git a/src/Sql/dbo/Tables/Event.sql b/src/Sql/dbo/Dirt/Tables/Event.sql similarity index 100% rename from src/Sql/dbo/Tables/Event.sql rename to src/Sql/dbo/Dirt/Tables/Event.sql diff --git a/src/Sql/dbo/Dirt/Tables/OrganizationReport.sql b/src/Sql/dbo/Dirt/Tables/OrganizationReport.sql index 4c47eafad8..2ffedc3f41 100644 --- a/src/Sql/dbo/Dirt/Tables/OrganizationReport.sql +++ b/src/Sql/dbo/Dirt/Tables/OrganizationReport.sql @@ -1,12 +1,24 @@ CREATE TABLE [dbo].[OrganizationReport] ( - [Id] UNIQUEIDENTIFIER NOT NULL, - [OrganizationId] UNIQUEIDENTIFIER NOT NULL, - [ReportData] NVARCHAR(MAX) NOT NULL, - [CreationDate] DATETIME2 (7) NOT NULL, - [ContentEncryptionKey] VARCHAR(MAX) NOT NULL, - [SummaryData] NVARCHAR(MAX) NULL, - [ApplicationData] NVARCHAR(MAX) NULL, - [RevisionDate] DATETIME2 (7) NULL, + [Id] UNIQUEIDENTIFIER NOT NULL, + [OrganizationId] UNIQUEIDENTIFIER NOT NULL, + [ReportData] NVARCHAR(MAX) NOT NULL, + [CreationDate] DATETIME2 (7) NOT NULL, + [ContentEncryptionKey] VARCHAR(MAX) NOT NULL, + [SummaryData] NVARCHAR(MAX) NULL, + [ApplicationData] NVARCHAR(MAX) NULL, + [RevisionDate] DATETIME2 (7) NULL, + [ApplicationCount] INT NULL, + [ApplicationAtRiskCount] INT NULL, + [CriticalApplicationCount] INT NULL, + [CriticalApplicationAtRiskCount] INT NULL, + [MemberCount] INT NULL, + [MemberAtRiskCount] INT NULL, + [CriticalMemberCount] INT NULL, + [CriticalMemberAtRiskCount] INT NULL, + [PasswordCount] INT NULL, + [PasswordAtRiskCount] INT NULL, + [CriticalPasswordCount] INT NULL, + [CriticalPasswordAtRiskCount] INT NULL, CONSTRAINT [PK_OrganizationReport] PRIMARY KEY CLUSTERED ([Id] ASC), CONSTRAINT [FK_OrganizationReport_Organization] FOREIGN KEY ([OrganizationId]) REFERENCES [dbo].[Organization] ([Id]) ); diff --git a/src/Sql/dbo/Views/EventView.sql b/src/Sql/dbo/Dirt/Views/EventView.sql similarity index 100% rename from src/Sql/dbo/Views/EventView.sql rename to src/Sql/dbo/Dirt/Views/EventView.sql diff --git a/src/Sql/dbo/KeyManagement/Stored Procedures/UserSignatureKeyPair_ReadByUserId.sql b/src/Sql/dbo/KeyManagement/Stored Procedures/UserSignatureKeyPair_ReadByUserId.sql new file mode 100644 index 0000000000..8bfa0156af --- /dev/null +++ b/src/Sql/dbo/KeyManagement/Stored Procedures/UserSignatureKeyPair_ReadByUserId.sql @@ -0,0 +1,13 @@ +CREATE PROCEDURE [dbo].[UserSignatureKeyPair_ReadByUserId] + @UserId UNIQUEIDENTIFIER +AS +BEGIN + SET NOCOUNT ON; + + SELECT + * + FROM + [dbo].[UserSignatureKeyPairView] + WHERE + [UserId] = @UserId; +END diff --git a/src/Sql/dbo/KeyManagement/Stored Procedures/UserSignatureKeyPair_SetForRotation.sql b/src/Sql/dbo/KeyManagement/Stored Procedures/UserSignatureKeyPair_SetForRotation.sql new file mode 100644 index 0000000000..6ee33e2a40 --- /dev/null +++ b/src/Sql/dbo/KeyManagement/Stored Procedures/UserSignatureKeyPair_SetForRotation.sql @@ -0,0 +1,33 @@ +CREATE PROCEDURE [dbo].[UserSignatureKeyPair_SetForRotation] + @Id UNIQUEIDENTIFIER, + @UserId UNIQUEIDENTIFIER, + @SignatureAlgorithm TINYINT, + @SigningKey VARCHAR(MAX), + @VerifyingKey VARCHAR(MAX), + @CreationDate DATETIME2(7), + @RevisionDate DATETIME2(7) +AS +BEGIN + SET NOCOUNT ON; + + INSERT INTO [dbo].[UserSignatureKeyPair] + ( + [Id], + [UserId], + [SignatureAlgorithm], + [SigningKey], + [VerifyingKey], + [CreationDate], + [RevisionDate] + ) + VALUES + ( + @Id, + @UserId, + @SignatureAlgorithm, + @SigningKey, + @VerifyingKey, + @CreationDate, + @RevisionDate + ) +END diff --git a/src/Sql/dbo/KeyManagement/Stored Procedures/UserSignatureKeyPair_UpdateForRotation.sql b/src/Sql/dbo/KeyManagement/Stored Procedures/UserSignatureKeyPair_UpdateForRotation.sql new file mode 100644 index 0000000000..4f673019fc --- /dev/null +++ b/src/Sql/dbo/KeyManagement/Stored Procedures/UserSignatureKeyPair_UpdateForRotation.sql @@ -0,0 +1,19 @@ +CREATE PROCEDURE [dbo].[UserSignatureKeyPair_UpdateForRotation] + @UserId UNIQUEIDENTIFIER, + @SignatureAlgorithm TINYINT, + @SigningKey VARCHAR(MAX), + @VerifyingKey VARCHAR(MAX), + @RevisionDate DATETIME2(7) +AS +BEGIN + SET NOCOUNT ON; + UPDATE + [dbo].[UserSignatureKeyPair] + SET + [SignatureAlgorithm] = @SignatureAlgorithm, + [SigningKey] = @SigningKey, + [VerifyingKey] = @VerifyingKey, + [RevisionDate] = @RevisionDate + WHERE + [UserId] = @UserId; +END diff --git a/src/Sql/dbo/KeyManagement/Stored Procedures/User_UpdateKeyConnectorUserKey.sql b/src/Sql/dbo/KeyManagement/Stored Procedures/User_UpdateKeyConnectorUserKey.sql new file mode 100644 index 0000000000..7ab20a42af --- /dev/null +++ b/src/Sql/dbo/KeyManagement/Stored Procedures/User_UpdateKeyConnectorUserKey.sql @@ -0,0 +1,28 @@ +CREATE PROCEDURE [dbo].[User_UpdateKeyConnectorUserKey] + @Id UNIQUEIDENTIFIER, + @Key VARCHAR(MAX), + @Kdf TINYINT, + @KdfIterations INT, + @KdfMemory INT, + @KdfParallelism INT, + @UsesKeyConnector BIT, + @RevisionDate DATETIME2(7), + @AccountRevisionDate DATETIME2(7) +AS +BEGIN + SET NOCOUNT ON + + UPDATE + [dbo].[User] + SET + [Key] = @Key, + [Kdf] = @Kdf, + [KdfIterations] = @KdfIterations, + [KdfMemory] = @KdfMemory, + [KdfParallelism] = @KdfParallelism, + [UsesKeyConnector] = @UsesKeyConnector, + [RevisionDate] = @RevisionDate, + [AccountRevisionDate] = @AccountRevisionDate + WHERE + [Id] = @Id +END diff --git a/src/Sql/dbo/KeyManagement/Tables/UserSignatureKeyPair.sql b/src/Sql/dbo/KeyManagement/Tables/UserSignatureKeyPair.sql new file mode 100644 index 0000000000..94d4e48a0b --- /dev/null +++ b/src/Sql/dbo/KeyManagement/Tables/UserSignatureKeyPair.sql @@ -0,0 +1,16 @@ +CREATE TABLE [dbo].[UserSignatureKeyPair] ( + [Id] UNIQUEIDENTIFIER NOT NULL, + [UserId] UNIQUEIDENTIFIER NOT NULL, + [SignatureAlgorithm] TINYINT NOT NULL, + [SigningKey] VARCHAR(MAX) NOT NULL, + [VerifyingKey] VARCHAR(MAX) NOT NULL, + [CreationDate] DATETIME2 (7) NOT NULL, + [RevisionDate] DATETIME2 (7) NOT NULL, + CONSTRAINT [PK_UserSignatureKeyPair] PRIMARY KEY CLUSTERED ([Id] ASC), + CONSTRAINT [FK_UserSignatureKeyPair_User] FOREIGN KEY ([UserId]) REFERENCES [dbo].[User] ([Id]) ON DELETE CASCADE +); +GO + +CREATE UNIQUE NONCLUSTERED INDEX [IX_UserSignatureKeyPair_UserId] + ON [dbo].[UserSignatureKeyPair]([UserId] ASC); +GO diff --git a/src/Sql/dbo/KeyManagement/Views/UserSignatureKeyPairView.sql b/src/Sql/dbo/KeyManagement/Views/UserSignatureKeyPairView.sql new file mode 100644 index 0000000000..959305a3e7 --- /dev/null +++ b/src/Sql/dbo/KeyManagement/Views/UserSignatureKeyPairView.sql @@ -0,0 +1,6 @@ +CREATE VIEW [dbo].[UserSignatureKeyPairView] +AS +SELECT + * +FROM + [dbo].[UserSignatureKeyPair] diff --git a/src/Sql/dbo/SecretsManager/Tables/SecretVersion.sql b/src/Sql/dbo/SecretsManager/Tables/SecretVersion.sql new file mode 100644 index 0000000000..31ab443f56 --- /dev/null +++ b/src/Sql/dbo/SecretsManager/Tables/SecretVersion.sql @@ -0,0 +1,27 @@ +CREATE TABLE [dbo].[SecretVersion] ( + [Id] UNIQUEIDENTIFIER NOT NULL, + [SecretId] UNIQUEIDENTIFIER NOT NULL, + [Value] NVARCHAR (MAX) NOT NULL, + [VersionDate] DATETIME2 (7) NOT NULL, + [EditorServiceAccountId] UNIQUEIDENTIFIER NULL, + [EditorOrganizationUserId] UNIQUEIDENTIFIER NULL, + CONSTRAINT [PK_SecretVersion] PRIMARY KEY CLUSTERED ([Id] ASC), + CONSTRAINT [FK_SecretVersion_OrganizationUser] FOREIGN KEY ([EditorOrganizationUserId]) REFERENCES [dbo].[OrganizationUser] ([Id]) ON DELETE SET NULL, + CONSTRAINT [FK_SecretVersion_Secret] FOREIGN KEY ([SecretId]) REFERENCES [dbo].[Secret] ([Id]) ON DELETE CASCADE, + CONSTRAINT [FK_SecretVersion_ServiceAccount] FOREIGN KEY ([EditorServiceAccountId]) REFERENCES [dbo].[ServiceAccount] ([Id]) ON DELETE SET NULL +); + +GO +CREATE NONCLUSTERED INDEX [IX_SecretVersion_SecretId] + ON [dbo].[SecretVersion]([SecretId] ASC); + +GO +CREATE NONCLUSTERED INDEX [IX_SecretVersion_EditorServiceAccountId] + ON [dbo].[SecretVersion]([EditorServiceAccountId] ASC) + WHERE [EditorServiceAccountId] IS NOT NULL; + +GO +CREATE NONCLUSTERED INDEX [IX_SecretVersion_EditorOrganizationUserId] + ON [dbo].[SecretVersion]([EditorOrganizationUserId] ASC) + WHERE [EditorOrganizationUserId] IS NOT NULL; +GO \ No newline at end of file diff --git a/src/Sql/dbo/Stored Procedures/Collection_UpdateWithGroups.sql b/src/Sql/dbo/Stored Procedures/Collection_UpdateWithGroups.sql new file mode 100644 index 0000000000..7f7fc2e0d7 --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/Collection_UpdateWithGroups.sql @@ -0,0 +1,74 @@ +CREATE PROCEDURE [dbo].[Collection_UpdateWithGroups] + @Id UNIQUEIDENTIFIER, + @OrganizationId UNIQUEIDENTIFIER, + @Name VARCHAR(MAX), + @ExternalId NVARCHAR(300), + @CreationDate DATETIME2(7), + @RevisionDate DATETIME2(7), + @Groups AS [dbo].[CollectionAccessSelectionType] READONLY, + @DefaultUserCollectionEmail NVARCHAR(256) = NULL, + @Type TINYINT = 0 +AS +BEGIN + SET NOCOUNT ON + + EXEC [dbo].[Collection_Update] @Id, @OrganizationId, @Name, @ExternalId, @CreationDate, @RevisionDate, @DefaultUserCollectionEmail, @Type + + -- Groups + -- Delete groups that are no longer in source + DELETE + cg + FROM + [dbo].[CollectionGroup] cg + LEFT JOIN + @Groups g ON cg.GroupId = g.Id + WHERE + cg.CollectionId = @Id + AND g.Id IS NULL; + + -- Update existing groups + UPDATE + cg + SET + cg.ReadOnly = g.ReadOnly, + cg.HidePasswords = g.HidePasswords, + cg.Manage = g.Manage + FROM + [dbo].[CollectionGroup] cg + INNER JOIN + @Groups g ON cg.GroupId = g.Id + WHERE + cg.CollectionId = @Id + AND ( + cg.ReadOnly != g.ReadOnly + OR cg.HidePasswords != g.HidePasswords + OR cg.Manage != g.Manage + ); + + -- Insert new groups + INSERT INTO [dbo].[CollectionGroup] + ( + [CollectionId], + [GroupId], + [ReadOnly], + [HidePasswords], + [Manage] + ) + SELECT + @Id, + g.Id, + g.ReadOnly, + g.HidePasswords, + g.Manage + FROM + @Groups g + INNER JOIN + [dbo].[Group] grp ON grp.Id = g.Id + LEFT JOIN + [dbo].[CollectionGroup] cg ON cg.CollectionId = @Id AND cg.GroupId = g.Id + WHERE + grp.OrganizationId = @OrganizationId + AND cg.CollectionId IS NULL; + + EXEC [dbo].[User_BumpAccountRevisionDateByCollectionId] @Id, @OrganizationId +END diff --git a/src/Sql/dbo/Stored Procedures/Collection_UpdateWithUsers.sql b/src/Sql/dbo/Stored Procedures/Collection_UpdateWithUsers.sql new file mode 100644 index 0000000000..60fccc51d5 --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/Collection_UpdateWithUsers.sql @@ -0,0 +1,74 @@ +CREATE PROCEDURE [dbo].[Collection_UpdateWithUsers] + @Id UNIQUEIDENTIFIER, + @OrganizationId UNIQUEIDENTIFIER, + @Name VARCHAR(MAX), + @ExternalId NVARCHAR(300), + @CreationDate DATETIME2(7), + @RevisionDate DATETIME2(7), + @Users AS [dbo].[CollectionAccessSelectionType] READONLY, + @DefaultUserCollectionEmail NVARCHAR(256) = NULL, + @Type TINYINT = 0 +AS +BEGIN + SET NOCOUNT ON + + EXEC [dbo].[Collection_Update] @Id, @OrganizationId, @Name, @ExternalId, @CreationDate, @RevisionDate, @DefaultUserCollectionEmail, @Type + + -- Users + -- Delete users that are no longer in source + DELETE + cu + FROM + [dbo].[CollectionUser] cu + LEFT JOIN + @Users u ON cu.OrganizationUserId = u.Id + WHERE + cu.CollectionId = @Id + AND u.Id IS NULL; + + -- Update existing users + UPDATE + cu + SET + cu.ReadOnly = u.ReadOnly, + cu.HidePasswords = u.HidePasswords, + cu.Manage = u.Manage + FROM + [dbo].[CollectionUser] cu + INNER JOIN + @Users u ON cu.OrganizationUserId = u.Id + WHERE + cu.CollectionId = @Id + AND ( + cu.ReadOnly != u.ReadOnly + OR cu.HidePasswords != u.HidePasswords + OR cu.Manage != u.Manage + ); + + -- Insert new users + INSERT INTO [dbo].[CollectionUser] + ( + [CollectionId], + [OrganizationUserId], + [ReadOnly], + [HidePasswords], + [Manage] + ) + SELECT + @Id, + u.Id, + u.ReadOnly, + u.HidePasswords, + u.Manage + FROM + @Users u + INNER JOIN + [dbo].[OrganizationUser] ou ON ou.Id = u.Id + LEFT JOIN + [dbo].[CollectionUser] cu ON cu.CollectionId = @Id AND cu.OrganizationUserId = u.Id + WHERE + ou.OrganizationId = @OrganizationId + AND cu.CollectionId IS NULL; + + EXEC [dbo].[User_BumpAccountRevisionDateByCollectionId] @Id, @OrganizationId +END diff --git a/src/Sql/dbo/Stored Procedures/OrganizationDomain_HasVerifiedDomainWithBlockPolicy.sql b/src/Sql/dbo/Stored Procedures/OrganizationDomain_HasVerifiedDomainWithBlockPolicy.sql new file mode 100644 index 0000000000..bfa9d932c5 --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/OrganizationDomain_HasVerifiedDomainWithBlockPolicy.sql @@ -0,0 +1,34 @@ +CREATE PROCEDURE [dbo].[OrganizationDomain_HasVerifiedDomainWithBlockPolicy] + @DomainName NVARCHAR(255), + @ExcludeOrganizationId UNIQUEIDENTIFIER = NULL +AS +BEGIN + SET NOCOUNT ON + + -- Check if any organization has a verified domain matching the domain name + -- with the BlockClaimedDomainAccountCreation policy enabled (Type = 19) + -- If @ExcludeOrganizationId is provided, exclude that organization from the check + IF EXISTS ( + SELECT 1 + FROM [dbo].[OrganizationDomain] OD + INNER JOIN [dbo].[Organization] O + ON OD.OrganizationId = O.Id + INNER JOIN [dbo].[Policy] P + ON O.Id = P.OrganizationId + WHERE OD.DomainName = @DomainName + AND OD.VerifiedDate IS NOT NULL + AND O.Enabled = 1 + AND O.UsePolicies = 1 + AND O.UseOrganizationDomains = 1 + AND (@ExcludeOrganizationId IS NULL OR O.Id != @ExcludeOrganizationId) + AND P.Type = 19 -- BlockClaimedDomainAccountCreation + AND P.Enabled = 1 + ) + BEGIN + SELECT CAST(1 AS BIT) AS HasBlockPolicy + END + ELSE + BEGIN + SELECT CAST(0 AS BIT) AS HasBlockPolicy + END +END diff --git a/src/Sql/dbo/Stored Procedures/OrganizationIntegrationConfigurationDetails_ReadManyByEventTypeOrganizationIdIntegrationType.sql b/src/Sql/dbo/Stored Procedures/OrganizationIntegrationConfigurationDetails_ReadManyByEventTypeOrganizationIdIntegrationType.sql index 3240402916..7124be73fb 100644 --- a/src/Sql/dbo/Stored Procedures/OrganizationIntegrationConfigurationDetails_ReadManyByEventTypeOrganizationIdIntegrationType.sql +++ b/src/Sql/dbo/Stored Procedures/OrganizationIntegrationConfigurationDetails_ReadManyByEventTypeOrganizationIdIntegrationType.sql @@ -11,7 +11,7 @@ BEGIN FROM [dbo].[OrganizationIntegrationConfigurationDetailsView] oic WHERE - oic.[EventType] = @EventType + (oic.[EventType] = @EventType OR oic.[EventType] IS NULL) AND oic.[OrganizationId] = @OrganizationId AND diff --git a/src/Sql/dbo/Stored Procedures/OrganizationIntegration_ReadByTeamsConfigurationTenantIdTeamId.sql b/src/Sql/dbo/Stored Procedures/OrganizationIntegration_ReadByTeamsConfigurationTenantIdTeamId.sql new file mode 100644 index 0000000000..8e2102772b --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/OrganizationIntegration_ReadByTeamsConfigurationTenantIdTeamId.sql @@ -0,0 +1,17 @@ +CREATE PROCEDURE [dbo].[OrganizationIntegration_ReadByTeamsConfigurationTenantIdTeamId] + @TenantId NVARCHAR(200), + @TeamId NVARCHAR(200) +AS +BEGIN + SET NOCOUNT ON; + +SELECT TOP 1 * +FROM [dbo].[OrganizationIntegrationView] + CROSS APPLY OPENJSON([Configuration], '$.Teams') + WITH ( TeamId NVARCHAR(MAX) '$.id' ) t +WHERE [Type] = 7 + AND JSON_VALUE([Configuration], '$.TenantId') = @TenantId + AND t.TeamId = @TeamId + AND JSON_VALUE([Configuration], '$.ChannelId') IS NULL + AND JSON_VALUE([Configuration], '$.ServiceUrl') IS NULL; +END diff --git a/src/Sql/dbo/Stored Procedures/OrganizationUserUserDetails_ReadByOrganizationIdUserId.sql b/src/Sql/dbo/Stored Procedures/OrganizationUserUserDetails_ReadByOrganizationIdUserId.sql new file mode 100644 index 0000000000..6113664b76 --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/OrganizationUserUserDetails_ReadByOrganizationIdUserId.sql @@ -0,0 +1,17 @@ +CREATE PROCEDURE [dbo].[OrganizationUserUserDetails_ReadByOrganizationIdUserId] + @OrganizationId UNIQUEIDENTIFIER, + @UserId UNIQUEIDENTIFIER +AS +BEGIN + SET NOCOUNT ON + +SELECT + * +FROM + [dbo].[OrganizationUserUserDetailsView] +WHERE + [OrganizationId] = @OrganizationId +AND + [UserId] = @UserId +END +GO diff --git a/src/Sql/dbo/Stored Procedures/OrganizationUser_ConfirmById.sql b/src/Sql/dbo/Stored Procedures/OrganizationUser_ConfirmById.sql new file mode 100644 index 0000000000..7a1cd78a51 --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/OrganizationUser_ConfirmById.sql @@ -0,0 +1,30 @@ +CREATE PROCEDURE [dbo].[OrganizationUser_ConfirmById] + @Id UNIQUEIDENTIFIER, + @UserId UNIQUEIDENTIFIER, + @RevisionDate DATETIME2(7), + @Key NVARCHAR(MAX) = NULL +AS +BEGIN + SET NOCOUNT ON + + DECLARE @RowCount INT; + + UPDATE + [dbo].[OrganizationUser] + SET + [Status] = 2, -- Set to Confirmed + [RevisionDate] = @RevisionDate, + [Key] = @Key + WHERE + [Id] = @Id + AND [Status] = 1 -- Only update if status is Accepted + + SET @RowCount = @@ROWCOUNT; + + IF @RowCount > 0 + BEGIN + EXEC [dbo].[User_BumpAccountRevisionDate] @UserId + END + + SELECT @RowCount; +END diff --git a/src/Sql/dbo/Stored Procedures/OrganizationUser_ReadByUserIdWithPolicyDetails.sql b/src/Sql/dbo/Stored Procedures/OrganizationUser_ReadByUserIdWithPolicyDetails.sql index c2bc690a27..105170cd27 100644 --- a/src/Sql/dbo/Stored Procedures/OrganizationUser_ReadByUserIdWithPolicyDetails.sql +++ b/src/Sql/dbo/Stored Procedures/OrganizationUser_ReadByUserIdWithPolicyDetails.sql @@ -4,31 +4,70 @@ AS BEGIN SET NOCOUNT ON -SELECT - OU.[Id] AS OrganizationUserId, - P.[OrganizationId], - P.[Type] AS PolicyType, - P.[Enabled] AS PolicyEnabled, - P.[Data] AS PolicyData, - OU.[Type] AS OrganizationUserType, - OU.[Status] AS OrganizationUserStatus, - OU.[Permissions] AS OrganizationUserPermissionsData, - CASE WHEN EXISTS ( - SELECT 1 - FROM [dbo].[ProviderUserView] PU - INNER JOIN [dbo].[ProviderOrganizationView] PO ON PO.[ProviderId] = PU.[ProviderId] - WHERE PU.[UserId] = OU.[UserId] AND PO.[OrganizationId] = P.[OrganizationId] - ) THEN 1 ELSE 0 END AS IsProvider -FROM [dbo].[PolicyView] P -INNER JOIN [dbo].[OrganizationUserView] OU - ON P.[OrganizationId] = OU.[OrganizationId] -WHERE P.[Type] = @PolicyType AND + + DECLARE @UserEmail NVARCHAR(256) + SELECT @UserEmail = Email + FROM + [dbo].[UserView] + WHERE + Id = @UserId + + ;WITH OrgUsers AS ( - (OU.[Status] != 0 AND OU.[UserId] = @UserId) -- OrgUsers who have accepted their invite and are linked to a UserId - OR EXISTS ( - SELECT 1 - FROM [dbo].[UserView] U - WHERE U.[Id] = @UserId AND OU.[Email] = U.[Email] AND OU.[Status] = 0 -- 'Invited' OrgUsers are not linked to a UserId yet, so we have to look up their email - ) + -- All users except invited (Status <> 0): direct UserId match + SELECT + OU.[Id], + OU.[OrganizationId], + OU.[Type], + OU.[Status], + OU.[Permissions] + FROM + [dbo].[OrganizationUserView] OU + WHERE + OU.[Status] <> 0 + AND OU.[UserId] = @UserId + + UNION ALL + + -- Invited users: email match + SELECT + OU.[Id], + OU.[OrganizationId], + OU.[Type], + OU.[Status], + OU.[Permissions] + FROM + [dbo].[OrganizationUserView] OU + WHERE + OU.[Status] = 0 + AND OU.[Email] = @UserEmail + AND @UserEmail IS NOT NULL + ), + Providers AS + ( + SELECT + OrganizationId + FROM + [dbo].[UserProviderAccessView] + WHERE + UserId = @UserId ) -END \ No newline at end of file + SELECT + OU.[Id] AS [OrganizationUserId], + P.[OrganizationId], + P.[Type] AS [PolicyType], + P.[Enabled] AS [PolicyEnabled], + P.[Data] AS [PolicyData], + OU.[Type] AS [OrganizationUserType], + OU.[Status] AS [OrganizationUserStatus], + OU.[Permissions] AS [OrganizationUserPermissionsData], + CASE WHEN PR.[OrganizationId] IS NULL THEN 0 ELSE 1 END AS [IsProvider] + FROM + [dbo].[PolicyView] P + INNER JOIN + OrgUsers OU ON P.[OrganizationId] = OU.[OrganizationId] + LEFT JOIN + Providers PR ON PR.[OrganizationId] = OU.[OrganizationId] + WHERE + P.[Type] = @PolicyType +END diff --git a/src/Sql/dbo/Stored Procedures/Organization_Create.sql b/src/Sql/dbo/Stored Procedures/Organization_Create.sql index 295ebb51a8..4fc4681648 100644 --- a/src/Sql/dbo/Stored Procedures/Organization_Create.sql +++ b/src/Sql/dbo/Stored Procedures/Organization_Create.sql @@ -58,7 +58,9 @@ CREATE PROCEDURE [dbo].[Organization_Create] @LimitItemDeletion BIT = 0, @UseOrganizationDomains BIT = 0, @UseAdminSponsoredFamilies BIT = 0, - @SyncSeats BIT = 0 + @SyncSeats BIT = 0, + @UseAutomaticUserConfirmation BIT = 0, + @UsePhishingBlocker BIT = 0 AS BEGIN SET NOCOUNT ON @@ -124,69 +126,75 @@ BEGIN [LimitItemDeletion], [UseOrganizationDomains], [UseAdminSponsoredFamilies], - [SyncSeats] + [SyncSeats], + [UseAutomaticUserConfirmation], + [UsePhishingBlocker], + [MaxStorageGbIncreased] ) VALUES - ( - @Id, - @Identifier, - @Name, - @BusinessName, - @BusinessAddress1, - @BusinessAddress2, - @BusinessAddress3, - @BusinessCountry, - @BusinessTaxNumber, - @BillingEmail, - @Plan, - @PlanType, - @Seats, - @MaxCollections, - @UsePolicies, - @UseSso, - @UseGroups, - @UseDirectory, - @UseEvents, - @UseTotp, - @Use2fa, - @UseApi, - @UseResetPassword, - @SelfHost, - @UsersGetPremium, - @Storage, - @MaxStorageGb, - @Gateway, - @GatewayCustomerId, - @GatewaySubscriptionId, - @ReferenceData, - @Enabled, - @LicenseKey, - @PublicKey, - @PrivateKey, - @TwoFactorProviders, - @ExpirationDate, - @CreationDate, - @RevisionDate, - @OwnersNotifiedOfAutoscaling, - @MaxAutoscaleSeats, - @UseKeyConnector, - @UseScim, - @UseCustomPermissions, - @UseSecretsManager, - @Status, - @UsePasswordManager, - @SmSeats, - @SmServiceAccounts, - @MaxAutoscaleSmSeats, - @MaxAutoscaleSmServiceAccounts, - @SecretsManagerBeta, - @LimitCollectionCreation, - @LimitCollectionDeletion, - @AllowAdminAccessToAllCollectionItems, - @UseRiskInsights, - @LimitItemDeletion, - @UseOrganizationDomains, - @UseAdminSponsoredFamilies, - @SyncSeats - ) + ( + @Id, + @Identifier, + @Name, + @BusinessName, + @BusinessAddress1, + @BusinessAddress2, + @BusinessAddress3, + @BusinessCountry, + @BusinessTaxNumber, + @BillingEmail, + @Plan, + @PlanType, + @Seats, + @MaxCollections, + @UsePolicies, + @UseSso, + @UseGroups, + @UseDirectory, + @UseEvents, + @UseTotp, + @Use2fa, + @UseApi, + @UseResetPassword, + @SelfHost, + @UsersGetPremium, + @Storage, + @MaxStorageGb, + @Gateway, + @GatewayCustomerId, + @GatewaySubscriptionId, + @ReferenceData, + @Enabled, + @LicenseKey, + @PublicKey, + @PrivateKey, + @TwoFactorProviders, + @ExpirationDate, + @CreationDate, + @RevisionDate, + @OwnersNotifiedOfAutoscaling, + @MaxAutoscaleSeats, + @UseKeyConnector, + @UseScim, + @UseCustomPermissions, + @UseSecretsManager, + @Status, + @UsePasswordManager, + @SmSeats, + @SmServiceAccounts, + @MaxAutoscaleSmSeats, + @MaxAutoscaleSmServiceAccounts, + @SecretsManagerBeta, + @LimitCollectionCreation, + @LimitCollectionDeletion, + @AllowAdminAccessToAllCollectionItems, + @UseRiskInsights, + @LimitItemDeletion, + @UseOrganizationDomains, + @UseAdminSponsoredFamilies, + @SyncSeats, + @UseAutomaticUserConfirmation, + @UsePhishingBlocker, + @MaxStorageGb + ); END diff --git a/src/Sql/dbo/Stored Procedures/Organization_ReadAbilities.sql b/src/Sql/dbo/Stored Procedures/Organization_ReadAbilities.sql index 6a8ed9e0d0..9efefe8d54 100644 --- a/src/Sql/dbo/Stored Procedures/Organization_ReadAbilities.sql +++ b/src/Sql/dbo/Stored Procedures/Organization_ReadAbilities.sql @@ -27,7 +27,9 @@ BEGIN [UseRiskInsights], [LimitItemDeletion], [UseOrganizationDomains], - [UseAdminSponsoredFamilies] + [UseAdminSponsoredFamilies], + [UseAutomaticUserConfirmation], + [UsePhishingBlocker] FROM [dbo].[Organization] END diff --git a/src/Sql/dbo/Stored Procedures/Organization_Update.sql b/src/Sql/dbo/Stored Procedures/Organization_Update.sql index d60852bab6..946cf03e94 100644 --- a/src/Sql/dbo/Stored Procedures/Organization_Update.sql +++ b/src/Sql/dbo/Stored Procedures/Organization_Update.sql @@ -58,7 +58,9 @@ CREATE PROCEDURE [dbo].[Organization_Update] @LimitItemDeletion BIT = 0, @UseOrganizationDomains BIT = 0, @UseAdminSponsoredFamilies BIT = 0, - @SyncSeats BIT = 0 + @SyncSeats BIT = 0, + @UseAutomaticUserConfirmation BIT = 0, + @UsePhishingBlocker BIT = 0 AS BEGIN SET NOCOUNT ON @@ -124,7 +126,10 @@ BEGIN [LimitItemDeletion] = @LimitItemDeletion, [UseOrganizationDomains] = @UseOrganizationDomains, [UseAdminSponsoredFamilies] = @UseAdminSponsoredFamilies, - [SyncSeats] = @SyncSeats + [SyncSeats] = @SyncSeats, + [UseAutomaticUserConfirmation] = @UseAutomaticUserConfirmation, + [UsePhishingBlocker] = @UsePhishingBlocker, + [MaxStorageGbIncreased] = @MaxStorageGb WHERE - [Id] = @Id + [Id] = @Id; END diff --git a/src/Sql/dbo/Stored Procedures/ProviderUser_ReadManyByManyUserIds.sql b/src/Sql/dbo/Stored Procedures/ProviderUser_ReadManyByManyUserIds.sql new file mode 100644 index 0000000000..4fe8d153e4 --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/ProviderUser_ReadManyByManyUserIds.sql @@ -0,0 +1,13 @@ +CREATE PROCEDURE [dbo].[ProviderUser_ReadManyByManyUserIds] + @UserIds AS [dbo].[GuidIdArray] READONLY +AS +BEGIN + SET NOCOUNT ON + + SELECT + [pu].* + FROM + [dbo].[ProviderUserView] AS [pu] + INNER JOIN + @UserIds [u] ON [u].[Id] = [pu].[UserId] +END diff --git a/src/Sql/dbo/Stored Procedures/User_Create.sql b/src/Sql/dbo/Stored Procedures/User_Create.sql index 60d9b5eb32..cf0c12d1c5 100644 --- a/src/Sql/dbo/Stored Procedures/User_Create.sql +++ b/src/Sql/dbo/Stored Procedures/User_Create.sql @@ -41,7 +41,10 @@ @LastKdfChangeDate DATETIME2(7) = NULL, @LastKeyRotationDate DATETIME2(7) = NULL, @LastEmailChangeDate DATETIME2(7) = NULL, - @VerifyDevices BIT = 1 + @VerifyDevices BIT = 1, + @SecurityState VARCHAR(MAX) = NULL, + @SecurityVersion INT = NULL, + @SignedPublicKey VARCHAR(MAX) = NULL AS BEGIN SET NOCOUNT ON @@ -90,7 +93,11 @@ BEGIN [LastKdfChangeDate], [LastKeyRotationDate], [LastEmailChangeDate], - [VerifyDevices] + [VerifyDevices], + [SecurityState], + [SecurityVersion], + [SignedPublicKey], + [MaxStorageGbIncreased] ) VALUES ( @@ -136,6 +143,10 @@ BEGIN @LastKdfChangeDate, @LastKeyRotationDate, @LastEmailChangeDate, - @VerifyDevices + @VerifyDevices, + @SecurityState, + @SecurityVersion, + @SignedPublicKey, + @MaxStorageGb ) END diff --git a/src/Sql/dbo/Stored Procedures/User_ReadPremiumAccessByIds.sql b/src/Sql/dbo/Stored Procedures/User_ReadPremiumAccessByIds.sql new file mode 100644 index 0000000000..a4c73c39df --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/User_ReadPremiumAccessByIds.sql @@ -0,0 +1,15 @@ +CREATE PROCEDURE [dbo].[User_ReadPremiumAccessByIds] + @Ids [dbo].[GuidIdArray] READONLY +AS +BEGIN + SET NOCOUNT ON + + SELECT + UPA.[Id], + UPA.[PersonalPremium], + UPA.[OrganizationPremium] + FROM + [dbo].[UserPremiumAccessView] UPA + WHERE + UPA.[Id] IN (SELECT [Id] FROM @Ids) +END diff --git a/src/Sql/dbo/Stored Procedures/User_Update.sql b/src/Sql/dbo/Stored Procedures/User_Update.sql index 15d04d72f6..05e0d4b4de 100644 --- a/src/Sql/dbo/Stored Procedures/User_Update.sql +++ b/src/Sql/dbo/Stored Procedures/User_Update.sql @@ -41,7 +41,10 @@ @LastKdfChangeDate DATETIME2(7) = NULL, @LastKeyRotationDate DATETIME2(7) = NULL, @LastEmailChangeDate DATETIME2(7) = NULL, - @VerifyDevices BIT = 1 + @VerifyDevices BIT = 1, + @SecurityState VARCHAR(MAX) = NULL, + @SecurityVersion INT = NULL, + @SignedPublicKey VARCHAR(MAX) = NULL AS BEGIN SET NOCOUNT ON @@ -90,7 +93,11 @@ BEGIN [LastKdfChangeDate] = @LastKdfChangeDate, [LastKeyRotationDate] = @LastKeyRotationDate, [LastEmailChangeDate] = @LastEmailChangeDate, - [VerifyDevices] = @VerifyDevices + [VerifyDevices] = @VerifyDevices, + [SecurityState] = @SecurityState, + [SecurityVersion] = @SecurityVersion, + [SignedPublicKey] = @SignedPublicKey, + [MaxStorageGbIncreased] = @MaxStorageGb WHERE [Id] = @Id END diff --git a/src/Sql/dbo/Stored Procedures/User_UpdateAccountCryptographicState.sql b/src/Sql/dbo/Stored Procedures/User_UpdateAccountCryptographicState.sql new file mode 100644 index 0000000000..8f1fb664ea --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/User_UpdateAccountCryptographicState.sql @@ -0,0 +1,65 @@ +CREATE PROCEDURE [dbo].[User_UpdateAccountCryptographicState] + @Id UNIQUEIDENTIFIER, + @PublicKey NVARCHAR(MAX), + @PrivateKey NVARCHAR(MAX), + @SignedPublicKey NVARCHAR(MAX) = NULL, + @SecurityState NVARCHAR(MAX) = NULL, + @SecurityVersion INT = NULL, + @SignatureKeyPairId UNIQUEIDENTIFIER = NULL, + @SignatureAlgorithm TINYINT = NULL, + @SigningKey VARCHAR(MAX) = NULL, + @VerifyingKey VARCHAR(MAX) = NULL, + @RevisionDate DATETIME2(7), + @AccountRevisionDate DATETIME2(7) +AS +BEGIN + SET NOCOUNT ON + + UPDATE + [dbo].[User] + SET + [PublicKey] = @PublicKey, + [PrivateKey] = @PrivateKey, + [SignedPublicKey] = @SignedPublicKey, + [SecurityState] = @SecurityState, + [SecurityVersion] = @SecurityVersion, + [RevisionDate] = @RevisionDate, + [AccountRevisionDate] = @AccountRevisionDate + WHERE + [Id] = @Id + + IF EXISTS (SELECT 1 FROM [dbo].[UserSignatureKeyPair] WHERE [UserId] = @Id) + BEGIN + UPDATE [dbo].[UserSignatureKeyPair] + SET + [SignatureAlgorithm] = @SignatureAlgorithm, + [SigningKey] = @SigningKey, + [VerifyingKey] = @VerifyingKey, + [RevisionDate] = @RevisionDate + WHERE + [UserId] = @Id + END + ELSE + BEGIN + INSERT INTO [dbo].[UserSignatureKeyPair] + ( + [Id], + [UserId], + [SignatureAlgorithm], + [SigningKey], + [VerifyingKey], + [CreationDate], + [RevisionDate] + ) + VALUES + ( + @SignatureKeyPairId, + @Id, + @SignatureAlgorithm, + @SigningKey, + @VerifyingKey, + @RevisionDate, + @RevisionDate + ) + END +END diff --git a/src/Sql/dbo/Tables/Organization.sql b/src/Sql/dbo/Tables/Organization.sql index 897abef1cf..f07cd4ce0d 100644 --- a/src/Sql/dbo/Tables/Organization.sql +++ b/src/Sql/dbo/Tables/Organization.sql @@ -59,6 +59,9 @@ CREATE TABLE [dbo].[Organization] ( [UseOrganizationDomains] BIT NOT NULL CONSTRAINT [DF_Organization_UseOrganizationDomains] DEFAULT (0), [UseAdminSponsoredFamilies] BIT NOT NULL CONSTRAINT [DF_Organization_UseAdminSponsoredFamilies] DEFAULT (0), [SyncSeats] BIT NOT NULL CONSTRAINT [DF_Organization_SyncSeats] DEFAULT (0), + [UseAutomaticUserConfirmation] BIT NOT NULL CONSTRAINT [DF_Organization_UseAutomaticUserConfirmation] DEFAULT (0), + [MaxStorageGbIncreased] SMALLINT NULL, + [UsePhishingBlocker] BIT NOT NULL CONSTRAINT [DF_Organization_UsePhishingBlocker] DEFAULT (0), CONSTRAINT [PK_Organization] PRIMARY KEY CLUSTERED ([Id] ASC) ); @@ -66,7 +69,7 @@ CREATE TABLE [dbo].[Organization] ( GO CREATE NONCLUSTERED INDEX [IX_Organization_Enabled] ON [dbo].[Organization]([Id] ASC, [Enabled] ASC) - INCLUDE ([UseTotp]); + INCLUDE ([UseTotp], [UsersGetPremium]); GO CREATE UNIQUE NONCLUSTERED INDEX [IX_Organization_Identifier] diff --git a/src/Sql/dbo/Tables/User.sql b/src/Sql/dbo/Tables/User.sql index 239ee67f11..854fe34f4a 100644 --- a/src/Sql/dbo/Tables/User.sql +++ b/src/Sql/dbo/Tables/User.sql @@ -42,6 +42,10 @@ [LastKeyRotationDate] DATETIME2 (7) NULL, [LastEmailChangeDate] DATETIME2 (7) NULL, [VerifyDevices] BIT DEFAULT ((1)) NOT NULL, + [SecurityState] VARCHAR (MAX) NULL, + [SecurityVersion] INT NULL, + [SignedPublicKey] VARCHAR (MAX) NULL, + [MaxStorageGbIncreased] SMALLINT NULL, CONSTRAINT [PK_User] PRIMARY KEY CLUSTERED ([Id] ASC) ); diff --git a/src/Sql/dbo/Vault/Stored Procedures/Cipher/CipherDetails_CreateWithCollections.sql b/src/Sql/dbo/Vault/Stored Procedures/Cipher/CipherDetails_CreateWithCollections.sql index ee7e00b32a..6082e89efc 100644 --- a/src/Sql/dbo/Vault/Stored Procedures/Cipher/CipherDetails_CreateWithCollections.sql +++ b/src/Sql/dbo/Vault/Stored Procedures/Cipher/CipherDetails_CreateWithCollections.sql @@ -30,4 +30,10 @@ BEGIN DECLARE @UpdateCollectionsSuccess INT EXEC @UpdateCollectionsSuccess = [dbo].[Cipher_UpdateCollections] @Id, @UserId, @OrganizationId, @CollectionIds + + -- Bump the account revision date AFTER collections are assigned. + IF @UpdateCollectionsSuccess = 0 + BEGIN + EXEC [dbo].[User_BumpAccountRevisionDateByCipherId] @Id, @OrganizationId + END END diff --git a/src/Sql/dbo/Vault/Stored Procedures/Cipher/Cipher_CreateWithCollections.sql b/src/Sql/dbo/Vault/Stored Procedures/Cipher/Cipher_CreateWithCollections.sql index ac7be1bbae..c6816a1226 100644 --- a/src/Sql/dbo/Vault/Stored Procedures/Cipher/Cipher_CreateWithCollections.sql +++ b/src/Sql/dbo/Vault/Stored Procedures/Cipher/Cipher_CreateWithCollections.sql @@ -23,4 +23,10 @@ BEGIN DECLARE @UpdateCollectionsSuccess INT EXEC @UpdateCollectionsSuccess = [dbo].[Cipher_UpdateCollections] @Id, @UserId, @OrganizationId, @CollectionIds + + -- Bump the account revision date AFTER collections are assigned. + IF @UpdateCollectionsSuccess = 0 + BEGIN + EXEC [dbo].[User_BumpAccountRevisionDateByCipherId] @Id, @OrganizationId + END END diff --git a/src/Sql/dbo/Vault/Stored Procedures/Cipher/Cipher_UpdateWithCollections.sql b/src/Sql/dbo/Vault/Stored Procedures/Cipher/Cipher_UpdateWithCollections.sql index 55852c4d27..3fe877c168 100644 --- a/src/Sql/dbo/Vault/Stored Procedures/Cipher/Cipher_UpdateWithCollections.sql +++ b/src/Sql/dbo/Vault/Stored Procedures/Cipher/Cipher_UpdateWithCollections.sql @@ -38,8 +38,13 @@ BEGIN [Data] = @Data, [Attachments] = @Attachments, [RevisionDate] = @RevisionDate, - [DeletedDate] = @DeletedDate, [Key] = @Key, [ArchivedDate] = @ArchivedDate - -- No need to update CreationDate, Favorites, Folders, or Type since that data will not change + [DeletedDate] = @DeletedDate, + [Key] = @Key, + [ArchivedDate] = @ArchivedDate, + [Folders] = @Folders, + [Favorites] = @Favorites, + [Reprompt] = @Reprompt + -- No need to update CreationDate or Type since that data will not change WHERE [Id] = @Id diff --git a/src/Sql/dbo/Vault/Stored Procedures/SecurityTask/SecurityTask_MarkCompleteByCipherIds.sql b/src/Sql/dbo/Vault/Stored Procedures/SecurityTask/SecurityTask_MarkCompleteByCipherIds.sql new file mode 100644 index 0000000000..8e00d06e43 --- /dev/null +++ b/src/Sql/dbo/Vault/Stored Procedures/SecurityTask/SecurityTask_MarkCompleteByCipherIds.sql @@ -0,0 +1,15 @@ +CREATE PROCEDURE [dbo].[SecurityTask_MarkCompleteByCipherIds] + @CipherIds AS [dbo].[GuidIdArray] READONLY +AS +BEGIN + SET NOCOUNT ON + + UPDATE + [dbo].[SecurityTask] + SET + [Status] = 1, -- completed + [RevisionDate] = SYSUTCDATETIME() + WHERE + [CipherId] IN (SELECT [Id] FROM @CipherIds) + AND [Status] <> 1 -- Not already completed +END diff --git a/src/Sql/dbo/Views/OrganizationUserOrganizationDetailsView.sql b/src/Sql/dbo/Views/OrganizationUserOrganizationDetailsView.sql index ba7e765569..ffd6810b1b 100644 --- a/src/Sql/dbo/Views/OrganizationUserOrganizationDetailsView.sql +++ b/src/Sql/dbo/Views/OrganizationUserOrganizationDetailsView.sql @@ -24,7 +24,7 @@ SELECT O.[UseSecretsManager], O.[Seats], O.[MaxCollections], - O.[MaxStorageGb], + COALESCE(O.[MaxStorageGbIncreased], O.[MaxStorageGb]) AS [MaxStorageGb], O.[Identifier], OU.[Key], OU.[ResetPasswordKey], @@ -54,7 +54,9 @@ SELECT O.[LimitItemDeletion], O.[UseAdminSponsoredFamilies], O.[UseOrganizationDomains], - OS.[IsAdminInitiated] + OS.[IsAdminInitiated], + O.[UseAutomaticUserConfirmation], + O.[UsePhishingBlocker] FROM [dbo].[OrganizationUser] OU LEFT JOIN diff --git a/src/Sql/dbo/Views/OrganizationView.sql b/src/Sql/dbo/Views/OrganizationView.sql index 58989273fd..6e42d08338 100644 --- a/src/Sql/dbo/Views/OrganizationView.sql +++ b/src/Sql/dbo/Views/OrganizationView.sql @@ -1,6 +1,67 @@ CREATE VIEW [dbo].[OrganizationView] AS SELECT - * + [Id], + [Identifier], + [Name], + [BusinessName], + [BusinessAddress1], + [BusinessAddress2], + [BusinessAddress3], + [BusinessCountry], + [BusinessTaxNumber], + [BillingEmail], + [Plan], + [PlanType], + [Seats], + [MaxCollections], + [UsePolicies], + [UseSso], + [UseGroups], + [UseDirectory], + [UseEvents], + [UseTotp], + [Use2fa], + [UseApi], + [UseResetPassword], + [SelfHost], + [UsersGetPremium], + [Storage], + COALESCE([MaxStorageGbIncreased], [MaxStorageGb]) AS [MaxStorageGb], + [Gateway], + [GatewayCustomerId], + [GatewaySubscriptionId], + [ReferenceData], + [Enabled], + [LicenseKey], + [PublicKey], + [PrivateKey], + [TwoFactorProviders], + [ExpirationDate], + [CreationDate], + [RevisionDate], + [OwnersNotifiedOfAutoscaling], + [MaxAutoscaleSeats], + [UseKeyConnector], + [UseScim], + [UseCustomPermissions], + [UseSecretsManager], + [Status], + [UsePasswordManager], + [SmSeats], + [SmServiceAccounts], + [MaxAutoscaleSmSeats], + [MaxAutoscaleSmServiceAccounts], + [SecretsManagerBeta], + [LimitCollectionCreation], + [LimitCollectionDeletion], + [LimitItemDeletion], + [AllowAdminAccessToAllCollectionItems], + [UseRiskInsights], + [UseOrganizationDomains], + [UseAdminSponsoredFamilies], + [SyncSeats], + [UseAutomaticUserConfirmation], + [UsePhishingBlocker] FROM [dbo].[Organization] diff --git a/src/Sql/dbo/Views/ProviderUserProviderOrganizationDetailsView.sql b/src/Sql/dbo/Views/ProviderUserProviderOrganizationDetailsView.sql index bd2485b411..e1d5ef9144 100644 --- a/src/Sql/dbo/Views/ProviderUserProviderOrganizationDetailsView.sql +++ b/src/Sql/dbo/Views/ProviderUserProviderOrganizationDetailsView.sql @@ -16,12 +16,14 @@ SELECT O.[Use2fa], O.[UseApi], O.[UseResetPassword], + O.[UseSecretsManager], + O.[UsePasswordManager], O.[SelfHost], O.[UsersGetPremium], O.[UseCustomPermissions], O.[Seats], O.[MaxCollections], - O.[MaxStorageGb], + COALESCE(O.[MaxStorageGbIncreased], O.[MaxStorageGb]) AS [MaxStorageGb], O.[Identifier], PO.[Key], O.[PublicKey], @@ -39,7 +41,11 @@ SELECT O.[UseAdminSponsoredFamilies], P.[Type] ProviderType, O.[LimitItemDeletion], - O.[UseOrganizationDomains] + O.[UseOrganizationDomains], + O.[UseAutomaticUserConfirmation], + SS.[Enabled] SsoEnabled, + SS.[Data] SsoConfig, + O.[UsePhishingBlocker] FROM [dbo].[ProviderUser] PU INNER JOIN @@ -48,3 +54,5 @@ INNER JOIN [dbo].[Organization] O ON O.[Id] = PO.[OrganizationId] INNER JOIN [dbo].[Provider] P ON P.[Id] = PU.[ProviderId] +LEFT JOIN + [dbo].[SsoConfig] SS ON SS.[OrganizationId] = O.[Id] diff --git a/src/Sql/dbo/Views/UserPremiumAccessView.sql b/src/Sql/dbo/Views/UserPremiumAccessView.sql new file mode 100644 index 0000000000..a20cab8fb3 --- /dev/null +++ b/src/Sql/dbo/Views/UserPremiumAccessView.sql @@ -0,0 +1,21 @@ +CREATE VIEW [dbo].[UserPremiumAccessView] +AS +SELECT + U.[Id], + U.[Premium] AS [PersonalPremium], + CAST( + MAX(CASE + WHEN O.[Id] IS NOT NULL THEN 1 + ELSE 0 + END) AS BIT + ) AS [OrganizationPremium] +FROM + [dbo].[User] U +LEFT JOIN + [dbo].[OrganizationUser] OU ON OU.[UserId] = U.[Id] +LEFT JOIN + [dbo].[Organization] O ON O.[Id] = OU.[OrganizationId] + AND O.[UsersGetPremium] = 1 + AND O.[Enabled] = 1 +GROUP BY + U.[Id], U.[Premium]; diff --git a/src/Sql/dbo/Views/UserProviderAccessView.sql b/src/Sql/dbo/Views/UserProviderAccessView.sql new file mode 100644 index 0000000000..dedc380311 --- /dev/null +++ b/src/Sql/dbo/Views/UserProviderAccessView.sql @@ -0,0 +1,9 @@ +CREATE VIEW [dbo].[UserProviderAccessView] +AS +SELECT DISTINCT + PU.[UserId], + PO.[OrganizationId] +FROM + [dbo].[ProviderUserView] PU +INNER JOIN + [dbo].[ProviderOrganizationView] PO ON PO.[ProviderId] = PU.[ProviderId] diff --git a/src/Sql/dbo/Views/UserView.sql b/src/Sql/dbo/Views/UserView.sql index 82fa8a2c63..fa8dbf334b 100644 --- a/src/Sql/dbo/Views/UserView.sql +++ b/src/Sql/dbo/Views/UserView.sql @@ -1,6 +1,51 @@ CREATE VIEW [dbo].[UserView] AS SELECT - * + [Id], + [Name], + [Email], + [EmailVerified], + [MasterPassword], + [MasterPasswordHint], + [Culture], + [SecurityStamp], + [TwoFactorProviders], + [TwoFactorRecoveryCode], + [EquivalentDomains], + [ExcludedGlobalEquivalentDomains], + [AccountRevisionDate], + [Key], + [PublicKey], + [PrivateKey], + [Premium], + [PremiumExpirationDate], + [RenewalReminderDate], + [Storage], + COALESCE([MaxStorageGbIncreased], [MaxStorageGb]) AS [MaxStorageGb], + [Gateway], + [GatewayCustomerId], + [GatewaySubscriptionId], + [ReferenceData], + [LicenseKey], + [ApiKey], + [Kdf], + [KdfIterations], + [KdfMemory], + [KdfParallelism], + [CreationDate], + [RevisionDate], + [ForcePasswordReset], + [UsesKeyConnector], + [FailedLoginCount], + [LastFailedLoginDate], + [AvatarColor], + [LastPasswordChangeDate], + [LastKdfChangeDate], + [LastKeyRotationDate], + [LastEmailChangeDate], + [VerifyDevices], + [SecurityState], + [SecurityVersion], + [SignedPublicKey] FROM [dbo].[User] diff --git a/test/Admin.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs b/test/Admin.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs index 44ad5088cd..84ef5c7f3d 100644 --- a/test/Admin.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs +++ b/test/Admin.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs @@ -1,5 +1,7 @@ using Bit.Admin.AdminConsole.Controllers; using Bit.Admin.AdminConsole.Models; +using Bit.Admin.Enums; +using Bit.Admin.Services; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; @@ -276,5 +278,40 @@ public class OrganizationsControllerTests await providerBillingService.Received(1).ScaleSeats(provider, update.PlanType!.Value, update.Seats!.Value - organization.Seats.Value + organization.Seats.Value); } + [BitAutoData] + [SutProviderCustomize] + [Theory] + public async Task Edit_UseAutomaticUserConfirmation_FullUpdate_SavesFeatureCorrectly( + Organization organization, + SutProvider sutProvider) + { + // Arrange + var update = new OrganizationEditModel + { + PlanType = PlanType.TeamsMonthly, + UseAutomaticUserConfirmation = true + }; + + organization.UseAutomaticUserConfirmation = false; + + sutProvider.GetDependency() + .UserHasPermission(Permission.Org_Plan_Edit) + .Returns(true); + + var organizationRepository = sutProvider.GetDependency(); + + organizationRepository.GetByIdAsync(organization.Id).Returns(organization); + + // Act + _ = await sutProvider.Sut.Edit(organization.Id, update); + + // Assert + await organizationRepository.Received(1).ReplaceAsync(Arg.Is(o => o.Id == organization.Id + && o.UseAutomaticUserConfirmation == true)); + + // Annul + await organizationRepository.DeleteAsync(organization); + } + #endregion } diff --git a/test/Api.IntegrationTest/AdminConsole/Controllers/GroupsControllerPerformanceTests.cs b/test/Api.IntegrationTest/AdminConsole/Controllers/GroupsControllerPerformanceTests.cs new file mode 100644 index 0000000000..71c6bf104c --- /dev/null +++ b/test/Api.IntegrationTest/AdminConsole/Controllers/GroupsControllerPerformanceTests.cs @@ -0,0 +1,63 @@ +using System.Net; +using System.Text; +using System.Text.Json; +using Bit.Api.AdminConsole.Models.Request; +using Bit.Api.IntegrationTest.Factories; +using Bit.Api.IntegrationTest.Helpers; +using Bit.Api.Models.Request; +using Bit.Seeder.Recipes; +using Xunit; +using Xunit.Abstractions; + +namespace Bit.Api.IntegrationTest.AdminConsole.Controllers; + +public class GroupsControllerPerformanceTests(ITestOutputHelper testOutputHelper) +{ + /// + /// Tests PUT /organizations/{orgId}/groups/{id} + /// + [Theory(Skip = "Performance test")] + [InlineData(10, 5)] + //[InlineData(100, 10)] + //[InlineData(1000, 20)] + public async Task UpdateGroup_WithUsersAndCollections(int userCount, int collectionCount) + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + var collectionsSeeder = new CollectionsRecipe(db); + var groupsSeeder = new GroupsRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: userCount); + + var orgUserIds = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).ToList(); + var collectionIds = collectionsSeeder.AddToOrganization(orgId, collectionCount, orgUserIds, 0); + var groupIds = groupsSeeder.AddToOrganization(orgId, 1, orgUserIds, 0); + + var groupId = groupIds.First(); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var updateRequest = new GroupRequestModel + { + Name = "Updated Group Name", + Collections = collectionIds.Select(c => new SelectionReadOnlyRequestModel { Id = c, ReadOnly = false, HidePasswords = false, Manage = false }), + Users = orgUserIds + }; + + var requestContent = new StringContent(JsonSerializer.Serialize(updateRequest), Encoding.UTF8, "application/json"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.PutAsync($"/organizations/{orgId}/groups/{groupId}", requestContent); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"PUT /organizations/{{orgId}}/groups/{{id}} - Users: {orgUserIds.Count}; Collections: {collectionIds.Count}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } +} diff --git a/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerAutoConfirmTests.cs b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerAutoConfirmTests.cs new file mode 100644 index 0000000000..8df1fcaf2b --- /dev/null +++ b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerAutoConfirmTests.cs @@ -0,0 +1,225 @@ +using System.Net; +using Bit.Api.AdminConsole.Models.Request.Organizations; +using Bit.Api.IntegrationTest.Factories; +using Bit.Api.IntegrationTest.Helpers; +using Bit.Core; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Enums; +using Bit.Core.Enums; +using Bit.Core.Models.Data; +using Bit.Core.Repositories; +using Bit.Core.Services; +using NSubstitute; +using Xunit; + +namespace Bit.Api.IntegrationTest.AdminConsole.Controllers; + +public class OrganizationUserControllerAutoConfirmTests : IClassFixture, IAsyncLifetime +{ + private const string _mockEncryptedString = "2.AOs41Hd8OQiCPXjyJKCiDA==|O6OHgt2U2hJGBSNGnimJmg==|iD33s8B69C8JhYYhSa4V1tArjvLr8eEaGqOV7BRo5Jk="; + + private readonly HttpClient _client; + private readonly ApiApplicationFactory _factory; + private readonly LoginHelper _loginHelper; + + private string _ownerEmail = null!; + + public OrganizationUserControllerAutoConfirmTests(ApiApplicationFactory apiFactory) + { + _factory = apiFactory; + _factory.SubstituteService(featureService => + { + featureService + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + }); + _client = _factory.CreateClient(); + _loginHelper = new LoginHelper(_factory, _client); + } + + public async Task InitializeAsync() + { + _ownerEmail = $"org-owner-{Guid.NewGuid()}@example.com"; + await _factory.LoginWithNewAccount(_ownerEmail); + } + + [Fact] + public async Task AutoConfirm_WhenUserCannotManageOtherUsers_ThenShouldReturnForbidden() + { + var (organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, plan: PlanType.EnterpriseAnnually, + ownerEmail: _ownerEmail, passwordManagerSeats: 5, paymentMethod: PaymentMethodType.Card); + + organization.UseAutomaticUserConfirmation = true; + + await _factory.GetService() + .UpsertAsync(organization); + + var testKey = $"test-key-{Guid.NewGuid()}"; + + var userToConfirmEmail = $"org-user-to-confirm-{Guid.NewGuid()}@example.com"; + await _factory.LoginWithNewAccount(userToConfirmEmail); + + var (confirmingUserEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, organization.Id, OrganizationUserType.User); + await _loginHelper.LoginAsync(confirmingUserEmail); + + var organizationUser = await OrganizationTestHelpers.CreateUserAsync( + _factory, + organization.Id, + userToConfirmEmail, + OrganizationUserType.User, + false, + new Permissions { ManageUsers = false }, + OrganizationUserStatusType.Accepted); + + var result = await _client.PostAsJsonAsync($"organizations/{organization.Id}/users/{organizationUser.Id}/auto-confirm", + new OrganizationUserConfirmRequestModel + { + Key = testKey, + DefaultUserCollectionName = _mockEncryptedString + }); + + Assert.Equal(HttpStatusCode.Forbidden, result.StatusCode); + + await _factory.GetService().DeleteAsync(organization); + } + + [Fact] + public async Task AutoConfirm_WhenOwnerConfirmsValidUser_ThenShouldReturnNoContent() + { + var (organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, plan: PlanType.EnterpriseAnnually, + ownerEmail: _ownerEmail, passwordManagerSeats: 5, paymentMethod: PaymentMethodType.Card); + + organization.UseAutomaticUserConfirmation = true; + + await _factory.GetService() + .UpsertAsync(organization); + + var testKey = $"test-key-{Guid.NewGuid()}"; + + await _factory.GetService().CreateAsync(new Policy + { + OrganizationId = organization.Id, + Type = PolicyType.AutomaticUserConfirmation, + Enabled = true + }); + + await _factory.GetService().CreateAsync(new Policy + { + OrganizationId = organization.Id, + Type = PolicyType.OrganizationDataOwnership, + Enabled = true + }); + + var userToConfirmEmail = $"org-user-to-confirm-{Guid.NewGuid()}@example.com"; + await _factory.LoginWithNewAccount(userToConfirmEmail); + + await _loginHelper.LoginAsync(_ownerEmail); + var organizationUser = await OrganizationTestHelpers.CreateUserAsync( + _factory, + organization.Id, + userToConfirmEmail, + OrganizationUserType.User, + false, + new Permissions(), + OrganizationUserStatusType.Accepted); + + var result = await _client.PostAsJsonAsync($"organizations/{organization.Id}/users/{organizationUser.Id}/auto-confirm", + new OrganizationUserConfirmRequestModel + { + Key = testKey, + DefaultUserCollectionName = _mockEncryptedString + }); + + Assert.Equal(HttpStatusCode.NoContent, result.StatusCode); + + var orgUserRepository = _factory.GetService(); + var confirmedUser = await orgUserRepository.GetByIdAsync(organizationUser.Id); + Assert.NotNull(confirmedUser); + Assert.Equal(OrganizationUserStatusType.Confirmed, confirmedUser.Status); + Assert.Equal(testKey, confirmedUser.Key); + + var collectionRepository = _factory.GetService(); + var collections = await collectionRepository.GetManyByUserIdAsync(organizationUser.UserId!.Value); + + Assert.NotEmpty(collections); + Assert.Single(collections.Where(c => c.Type == CollectionType.DefaultUserCollection)); + + await _factory.GetService().DeleteAsync(organization); + } + + [Fact] + public async Task AutoConfirm_WhenUserIsConfirmedMultipleTimes_ThenShouldSuccessAndOnlyConfirmOneUser() + { + var (organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, plan: PlanType.EnterpriseAnnually, + ownerEmail: _ownerEmail, passwordManagerSeats: 5, paymentMethod: PaymentMethodType.Card); + + organization.UseAutomaticUserConfirmation = true; + + await _factory.GetService() + .UpsertAsync(organization); + + var testKey = $"test-key-{Guid.NewGuid()}"; + + var userToConfirmEmail = $"org-user-to-confirm-{Guid.NewGuid()}@example.com"; + await _factory.LoginWithNewAccount(userToConfirmEmail); + + await _factory.GetService().CreateAsync(new Policy + { + OrganizationId = organization.Id, + Type = PolicyType.AutomaticUserConfirmation, + Enabled = true + }); + + await _factory.GetService().CreateAsync(new Policy + { + OrganizationId = organization.Id, + Type = PolicyType.OrganizationDataOwnership, + Enabled = true + }); + + await _loginHelper.LoginAsync(_ownerEmail); + + var organizationUser = await OrganizationTestHelpers.CreateUserAsync( + _factory, + organization.Id, + userToConfirmEmail, + OrganizationUserType.User, + false, + new Permissions(), + OrganizationUserStatusType.Accepted); + + var tenRequests = Enumerable.Range(0, 10) + .Select(_ => _client.PostAsJsonAsync($"organizations/{organization.Id}/users/{organizationUser.Id}/auto-confirm", + new OrganizationUserConfirmRequestModel + { + Key = testKey, + DefaultUserCollectionName = _mockEncryptedString + })).ToList(); + + var results = await Task.WhenAll(tenRequests); + + Assert.Contains(results, r => r.StatusCode == HttpStatusCode.NoContent); + + var orgUserRepository = _factory.GetService(); + var confirmedUser = await orgUserRepository.GetByIdAsync(organizationUser.Id); + Assert.NotNull(confirmedUser); + Assert.Equal(OrganizationUserStatusType.Confirmed, confirmedUser.Status); + Assert.Equal(testKey, confirmedUser.Key); + + var collections = await _factory.GetService() + .GetManyByUserIdAsync(organizationUser.UserId!.Value); + Assert.NotEmpty(collections); + // validates user only received one default collection + Assert.Single(collections.Where(c => c.Type == CollectionType.DefaultUserCollection)); + + await _factory.GetService().DeleteAsync(organization); + } + + public Task DisposeAsync() + { + _client.Dispose(); + return Task.CompletedTask; + } +} diff --git a/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerBulkRevokeTests.cs b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerBulkRevokeTests.cs new file mode 100644 index 0000000000..6645f29eae --- /dev/null +++ b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerBulkRevokeTests.cs @@ -0,0 +1,347 @@ +using System.Net; +using Bit.Api.AdminConsole.Models.Request.Organizations; +using Bit.Api.AdminConsole.Models.Response.Organizations; +using Bit.Api.IntegrationTest.Factories; +using Bit.Api.IntegrationTest.Helpers; +using Bit.Api.Models.Response; +using Bit.Core; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Providers.Interfaces; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Enums; +using Bit.Core.Enums; +using Bit.Core.Models.Data; +using Bit.Core.Repositories; +using Bit.Core.Services; +using NSubstitute; +using Xunit; + +namespace Bit.Api.IntegrationTest.AdminConsole.Controllers; + +public class OrganizationUserControllerBulkRevokeTests : IClassFixture, IAsyncLifetime +{ + private readonly HttpClient _client; + private readonly ApiApplicationFactory _factory; + private readonly LoginHelper _loginHelper; + + private Organization _organization = null!; + private string _ownerEmail = null!; + + public OrganizationUserControllerBulkRevokeTests(ApiApplicationFactory apiFactory) + { + _factory = apiFactory; + _factory.SubstituteService(featureService => + { + featureService + .IsEnabled(FeatureFlagKeys.BulkRevokeUsersV2) + .Returns(true); + }); + _client = _factory.CreateClient(); + _loginHelper = new LoginHelper(_factory, _client); + } + + public async Task InitializeAsync() + { + _ownerEmail = $"org-user-bulk-revoke-test-{Guid.NewGuid()}@bitwarden.com"; + await _factory.LoginWithNewAccount(_ownerEmail); + + (_organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, plan: PlanType.EnterpriseMonthly, + ownerEmail: _ownerEmail, passwordManagerSeats: 10, paymentMethod: PaymentMethodType.Card); + } + + public Task DisposeAsync() + { + _client.Dispose(); + return Task.CompletedTask; + } + + [Fact] + public async Task BulkRevoke_Success() + { + var (ownerEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Owner); + + await _loginHelper.LoginAsync(ownerEmail); + + var (_, orgUser1) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, _organization.Id, OrganizationUserType.User); + var (_, orgUser2) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, _organization.Id, OrganizationUserType.User); + + var organizationUserRepository = _factory.GetService(); + + var request = new OrganizationUserBulkRequestModel + { + Ids = [orgUser1.Id, orgUser2.Id] + }; + + var httpResponse = await _client.PutAsJsonAsync($"organizations/{_organization.Id}/users/revoke", request); + var content = await httpResponse.Content.ReadFromJsonAsync>(); + + Assert.Equal(HttpStatusCode.OK, httpResponse.StatusCode); + Assert.NotNull(content); + Assert.Equal(2, content.Data.Count()); + Assert.All(content.Data, r => Assert.Empty(r.Error)); + + var actualUsers = await organizationUserRepository.GetManyAsync([orgUser1.Id, orgUser2.Id]); + Assert.All(actualUsers, u => Assert.Equal(OrganizationUserStatusType.Revoked, u.Status)); + } + + [Fact] + public async Task BulkRevoke_AsAdmin_Success() + { + var (adminEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Admin); + + await _loginHelper.LoginAsync(adminEmail); + + var (_, orgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, _organization.Id, OrganizationUserType.User); + + var request = new OrganizationUserBulkRequestModel + { + Ids = [orgUser.Id] + }; + + var httpResponse = await _client.PutAsJsonAsync($"organizations/{_organization.Id}/users/revoke", request); + var content = await httpResponse.Content.ReadFromJsonAsync>(); + + Assert.Equal(HttpStatusCode.OK, httpResponse.StatusCode); + Assert.NotNull(content); + Assert.Single(content.Data); + Assert.All(content.Data, r => Assert.Empty(r.Error)); + + var actualUser = await _factory.GetService().GetByIdAsync(orgUser.Id); + Assert.NotNull(actualUser); + Assert.Equal(OrganizationUserStatusType.Revoked, actualUser.Status); + } + + [Fact] + public async Task BulkRevoke_CannotRevokeSelf_ReturnsError() + { + var (userEmail, orgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Admin); + + await _loginHelper.LoginAsync(userEmail); + + var request = new OrganizationUserBulkRequestModel + { + Ids = [orgUser.Id] + }; + + var httpResponse = await _client.PutAsJsonAsync($"organizations/{_organization.Id}/users/revoke", request); + var content = await httpResponse.Content.ReadFromJsonAsync>(); + + Assert.Equal(HttpStatusCode.OK, httpResponse.StatusCode); + Assert.NotNull(content); + Assert.Single(content.Data); + Assert.Contains(content.Data, r => r.Id == orgUser.Id && r.Error == "You cannot revoke yourself."); + + var actualUser = await _factory.GetService().GetByIdAsync(orgUser.Id); + Assert.NotNull(actualUser); + Assert.Equal(OrganizationUserStatusType.Confirmed, actualUser.Status); + } + + [Fact] + public async Task BulkRevoke_AlreadyRevoked_ReturnsError() + { + var (ownerEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Owner); + + await _loginHelper.LoginAsync(ownerEmail); + + var (_, orgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, _organization.Id, OrganizationUserType.User); + + var organizationUserRepository = _factory.GetService(); + + await organizationUserRepository.RevokeAsync(orgUser.Id); + + var request = new OrganizationUserBulkRequestModel + { + Ids = [orgUser.Id] + }; + + var httpResponse = await _client.PutAsJsonAsync($"organizations/{_organization.Id}/users/revoke", request); + var content = await httpResponse.Content.ReadFromJsonAsync>(); + + Assert.Equal(HttpStatusCode.OK, httpResponse.StatusCode); + Assert.NotNull(content); + Assert.Single(content.Data); + Assert.Contains(content.Data, r => r.Id == orgUser.Id && r.Error == "Already revoked."); + + var actualUser = await organizationUserRepository.GetByIdAsync(orgUser.Id); + Assert.NotNull(actualUser); + Assert.Equal(OrganizationUserStatusType.Revoked, actualUser.Status); + } + + [Fact] + public async Task BulkRevoke_AdminCannotRevokeOwner_ReturnsError() + { + var (adminEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Admin); + + await _loginHelper.LoginAsync(adminEmail); + + var (_, ownerOrgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, _organization.Id, OrganizationUserType.Owner); + + var request = new OrganizationUserBulkRequestModel + { + Ids = [ownerOrgUser.Id] + }; + + var httpResponse = await _client.PutAsJsonAsync($"organizations/{_organization.Id}/users/revoke", request); + var content = await httpResponse.Content.ReadFromJsonAsync>(); + + Assert.Equal(HttpStatusCode.OK, httpResponse.StatusCode); + Assert.NotNull(content); + Assert.Single(content.Data); + Assert.Contains(content.Data, r => r.Id == ownerOrgUser.Id && r.Error == "Only owners can revoke other owners."); + + var actualUser = await _factory.GetService().GetByIdAsync(ownerOrgUser.Id); + Assert.NotNull(actualUser); + Assert.Equal(OrganizationUserStatusType.Confirmed, actualUser.Status); + } + + [Fact] + public async Task BulkRevoke_MixedResults() + { + var (ownerEmail, requestingOwner) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Owner); + + await _loginHelper.LoginAsync(ownerEmail); + + var (_, validOrgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, _organization.Id, OrganizationUserType.User); + var (_, alreadyRevokedOrgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, _organization.Id, OrganizationUserType.User); + + var organizationUserRepository = _factory.GetService(); + + await organizationUserRepository.RevokeAsync(alreadyRevokedOrgUser.Id); + + var request = new OrganizationUserBulkRequestModel + { + Ids = [validOrgUser.Id, alreadyRevokedOrgUser.Id, requestingOwner.Id] + }; + + var httpResponse = await _client.PutAsJsonAsync($"organizations/{_organization.Id}/users/revoke", request); + var content = await httpResponse.Content.ReadFromJsonAsync>(); + + Assert.Equal(HttpStatusCode.OK, httpResponse.StatusCode); + Assert.NotNull(content); + Assert.Equal(3, content.Data.Count()); + + Assert.Contains(content.Data, r => r.Id == validOrgUser.Id && r.Error == string.Empty); + Assert.Contains(content.Data, r => r.Id == alreadyRevokedOrgUser.Id && r.Error == "Already revoked."); + Assert.Contains(content.Data, r => r.Id == requestingOwner.Id && r.Error == "You cannot revoke yourself."); + + var actualUsers = await organizationUserRepository.GetManyAsync([validOrgUser.Id, alreadyRevokedOrgUser.Id, requestingOwner.Id]); + Assert.Equal(OrganizationUserStatusType.Revoked, actualUsers.First(u => u.Id == validOrgUser.Id).Status); + Assert.Equal(OrganizationUserStatusType.Revoked, actualUsers.First(u => u.Id == alreadyRevokedOrgUser.Id).Status); + Assert.Equal(OrganizationUserStatusType.Confirmed, actualUsers.First(u => u.Id == requestingOwner.Id).Status); + } + + [Theory] + [InlineData(OrganizationUserType.User)] + [InlineData(OrganizationUserType.Custom)] + public async Task BulkRevoke_WithoutManageUsersPermission_ReturnsForbidden(OrganizationUserType organizationUserType) + { + var (userEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, organizationUserType, new Permissions { ManageUsers = false }); + + await _loginHelper.LoginAsync(userEmail); + + var request = new OrganizationUserBulkRequestModel + { + Ids = [Guid.NewGuid()] + }; + + var httpResponse = await _client.PutAsJsonAsync($"organizations/{_organization.Id}/users/revoke", request); + + Assert.Equal(HttpStatusCode.Forbidden, httpResponse.StatusCode); + } + + [Fact] + public async Task BulkRevoke_WithEmptyIds_ReturnsBadRequest() + { + await _loginHelper.LoginAsync(_ownerEmail); + + var request = new OrganizationUserBulkRequestModel + { + Ids = [] + }; + + var httpResponse = await _client.PutAsJsonAsync($"organizations/{_organization.Id}/users/revoke", request); + + Assert.Equal(HttpStatusCode.BadRequest, httpResponse.StatusCode); + } + + [Fact] + public async Task BulkRevoke_WithInvalidOrganizationId_ReturnsForbidden() + { + var (ownerEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Owner); + + await _loginHelper.LoginAsync(ownerEmail); + + var (_, orgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, _organization.Id, OrganizationUserType.User); + + var invalidOrgId = Guid.NewGuid(); + + var request = new OrganizationUserBulkRequestModel + { + Ids = [orgUser.Id] + }; + + var httpResponse = await _client.PutAsJsonAsync($"organizations/{invalidOrgId}/users/revoke", request); + + Assert.Equal(HttpStatusCode.Forbidden, httpResponse.StatusCode); + } + + [Fact] + public async Task BulkRevoke_ProviderRevokesOwner_ReturnsOk() + { + var providerEmail = $"provider-user{Guid.NewGuid()}@example.com"; + + // create user for provider + await _factory.LoginWithNewAccount(providerEmail); + + // create provider and provider user + await _factory.GetService() + .CreateBusinessUnitAsync( + new Provider + { + Name = "provider", + Type = ProviderType.BusinessUnit + }, + providerEmail, + PlanType.EnterpriseAnnually2023, + 10); + + await _loginHelper.LoginAsync(providerEmail); + + var providerUserUser = await _factory.GetService().GetByEmailAsync(providerEmail); + + var providerUserCollection = await _factory.GetService() + .GetManyByUserAsync(providerUserUser!.Id); + + var providerUser = providerUserCollection.First(); + + await _factory.GetService().CreateAsync(new ProviderOrganization + { + ProviderId = providerUser.ProviderId, + OrganizationId = _organization.Id, + Key = null, + Settings = null + }); + + var (_, ownerOrgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Owner); + + var request = new OrganizationUserBulkRequestModel + { + Ids = [ownerOrgUser.Id] + }; + + var httpResponse = await _client.PutAsJsonAsync($"organizations/{_organization.Id}/users/revoke", request); + + Assert.Equal(HttpStatusCode.OK, httpResponse.StatusCode); + } +} diff --git a/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerTests.cs b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerTests.cs index 7c61a88bd8..0fef4a0cd0 100644 --- a/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerTests.cs +++ b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUserControllerTests.cs @@ -218,7 +218,7 @@ public class OrganizationUserControllerTests : IClassFixture + /// Tests GET /organizations/{orgId}/users?includeCollections=true + ///
    [Theory(Skip = "Performance test")] - [InlineData(100)] - [InlineData(60000)] - public async Task GetAsync(int seats) + [InlineData(10)] + //[InlineData(100)] + //[InlineData(1000)] + public async Task GetAllUsers_WithCollections(int seats) { await using var factory = new SqlServerApiApplicationFactory(); var client = factory.CreateClient(); var db = factory.GetDatabaseContext(); - var seeder = new OrganizationWithUsersRecipe(db); + var orgSeeder = new OrganizationWithUsersRecipe(db); + var collectionsSeeder = new CollectionsRecipe(db); + var groupsSeeder = new GroupsRecipe(db); - var orgId = seeder.Seed("Org", seats, "large.test"); + var domain = OrganizationTestHelpers.GenerateRandomDomain(); - var tokens = await factory.LoginAsync("admin@large.test", "c55hlJ/cfdvTd4awTXUqow6X3cOQCfGwn11o3HblnPs="); - client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", tokens.Token); + var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: seats); + + var orgUserIds = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).ToList(); + collectionsSeeder.AddToOrganization(orgId, 10, orgUserIds); + groupsSeeder.AddToOrganization(orgId, 5, orgUserIds); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); var stopwatch = System.Diagnostics.Stopwatch.StartNew(); var response = await client.GetAsync($"/organizations/{orgId}/users?includeCollections=true"); Assert.Equal(HttpStatusCode.OK, response.StatusCode); - var result = await response.Content.ReadAsStringAsync(); - Assert.NotEmpty(result); + stopwatch.Stop(); + testOutputHelper.WriteLine($"GET /users - Seats: {seats}; Request duration: {stopwatch.ElapsedMilliseconds} ms"); + } + + /// + /// Tests GET /organizations/{orgId}/users/mini-details + /// + [Theory(Skip = "Performance test")] + [InlineData(10)] + //[InlineData(100)] + //[InlineData(1000)] + public async Task GetAllUsers_MiniDetails(int seats) + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + var collectionsSeeder = new CollectionsRecipe(db); + var groupsSeeder = new GroupsRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: seats); + + var orgUserIds = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).ToList(); + collectionsSeeder.AddToOrganization(orgId, 10, orgUserIds); + groupsSeeder.AddToOrganization(orgId, 5, orgUserIds); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.GetAsync($"/organizations/{orgId}/users/mini-details"); stopwatch.Stop(); - testOutputHelper.WriteLine($"Seed: {seats}; Request duration: {stopwatch.ElapsedMilliseconds} ms"); + + testOutputHelper.WriteLine($"GET /users/mini-details - Seats: {seats}; Request duration: {stopwatch.ElapsedMilliseconds} ms"); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + /// + /// Tests GET /organizations/{orgId}/users/{id}?includeGroups=true + /// + [Fact(Skip = "Performance test")] + public async Task GetSingleUser_WithGroups() + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + var groupsSeeder = new GroupsRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: 1); + + var orgUserId = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).FirstOrDefault(); + groupsSeeder.AddToOrganization(orgId, 2, [orgUserId]); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.GetAsync($"/organizations/{orgId}/users/{orgUserId}?includeGroups=true"); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"GET /users/{{id}} - Request duration: {stopwatch.ElapsedMilliseconds} ms"); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + /// + /// Tests GET /organizations/{orgId}/users/{id}/reset-password-details + /// + [Fact(Skip = "Performance test")] + public async Task GetResetPasswordDetails_ForSingleUser() + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: 1); + + var orgUserId = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).FirstOrDefault(); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.GetAsync($"/organizations/{orgId}/users/{orgUserId}/reset-password-details"); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"GET /users/{{id}}/reset-password-details - Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + /// + /// Tests POST /organizations/{orgId}/users/confirm + /// + [Theory(Skip = "Performance test")] + [InlineData(10)] + //[InlineData(100)] + //[InlineData(1000)] + public async Task BulkConfirmUsers(int userCount) + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed( + name: "Org", + domain: domain, + users: userCount, + usersStatus: OrganizationUserStatusType.Accepted); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var acceptedUserIds = db.OrganizationUsers + .Where(ou => ou.OrganizationId == orgId && ou.Status == OrganizationUserStatusType.Accepted) + .Select(ou => ou.Id) + .ToList(); + + var confirmRequest = new OrganizationUserBulkConfirmRequestModel + { + Keys = acceptedUserIds.Select(id => new OrganizationUserBulkConfirmRequestModelEntry { Id = id, Key = "test-key-" + id }), + DefaultUserCollectionName = "2.AOs41Hd8OQiCPXjyJKCiDA==|O6OHgt2U2hJGBSNGnimJmg==|iD33s8B69C8JhYYhSa4V1tArjvLr8eEaGqOV7BRo5Jk=" + }; + + var requestContent = new StringContent(JsonSerializer.Serialize(confirmRequest), Encoding.UTF8, "application/json"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.PostAsync($"/organizations/{orgId}/users/confirm", requestContent); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"POST /users/confirm - Users: {acceptedUserIds.Count}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.True(response.IsSuccessStatusCode); + } + + /// + /// Tests POST /organizations/{orgId}/users/remove + /// + [Theory(Skip = "Performance test")] + [InlineData(10)] + //[InlineData(100)] + //[InlineData(1000)] + public async Task BulkRemoveUsers(int userCount) + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: userCount); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var usersToRemove = db.OrganizationUsers + .Where(ou => ou.OrganizationId == orgId && ou.Type == OrganizationUserType.User) + .Select(ou => ou.Id) + .ToList(); + + var removeRequest = new OrganizationUserBulkRequestModel { Ids = usersToRemove }; + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var requestContent = new StringContent(JsonSerializer.Serialize(removeRequest), Encoding.UTF8, "application/json"); + + var response = await client.PostAsync($"/organizations/{orgId}/users/remove", requestContent); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"POST /users/remove - Users: {usersToRemove.Count}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.True(response.IsSuccessStatusCode); + } + + /// + /// Tests PUT /organizations/{orgId}/users/revoke + /// + [Theory(Skip = "Performance test")] + [InlineData(10)] + //[InlineData(100)] + //[InlineData(1000)] + public async Task BulkRevokeUsers(int userCount) + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed( + name: "Org", + domain: domain, + users: userCount, + usersStatus: OrganizationUserStatusType.Confirmed); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var usersToRevoke = db.OrganizationUsers + .Where(ou => ou.OrganizationId == orgId && ou.Type == OrganizationUserType.User) + .Select(ou => ou.Id) + .ToList(); + + var revokeRequest = new OrganizationUserBulkRequestModel { Ids = usersToRevoke }; + + var requestContent = new StringContent(JsonSerializer.Serialize(revokeRequest), Encoding.UTF8, "application/json"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.PutAsync($"/organizations/{orgId}/users/revoke", requestContent); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"PUT /users/revoke - Users: {usersToRevoke.Count}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.True(response.IsSuccessStatusCode); + } + + /// + /// Tests PUT /organizations/{orgId}/users/restore + /// + [Theory(Skip = "Performance test")] + [InlineData(10)] + //[InlineData(100)] + //[InlineData(1000)] + public async Task BulkRestoreUsers(int userCount) + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed( + name: "Org", + domain: domain, + users: userCount, + usersStatus: OrganizationUserStatusType.Revoked); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var usersToRestore = db.OrganizationUsers + .Where(ou => ou.OrganizationId == orgId && ou.Type == OrganizationUserType.User) + .Select(ou => ou.Id) + .ToList(); + + var restoreRequest = new OrganizationUserBulkRequestModel { Ids = usersToRestore }; + + var requestContent = new StringContent(JsonSerializer.Serialize(restoreRequest), Encoding.UTF8, "application/json"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.PutAsync($"/organizations/{orgId}/users/restore", requestContent); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"PUT /users/restore - Users: {usersToRestore.Count}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.True(response.IsSuccessStatusCode); + } + + /// + /// Tests POST /organizations/{orgId}/users/delete-account + /// + [Theory(Skip = "Performance test")] + [InlineData(10)] + //[InlineData(100)] + //[InlineData(1000)] + public async Task BulkDeleteAccounts(int userCount) + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + var domainSeeder = new OrganizationDomainRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + + var orgId = orgSeeder.Seed( + name: "Org", + domain: domain, + users: userCount, + usersStatus: OrganizationUserStatusType.Confirmed); + + domainSeeder.AddVerifiedDomainToOrganization(orgId, domain); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var usersToDelete = db.OrganizationUsers + .Where(ou => ou.OrganizationId == orgId && ou.Type == OrganizationUserType.User) + .Select(ou => ou.Id) + .ToList(); + + var deleteRequest = new OrganizationUserBulkRequestModel { Ids = usersToDelete }; + + var requestContent = new StringContent(JsonSerializer.Serialize(deleteRequest), Encoding.UTF8, "application/json"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.PostAsync($"/organizations/{orgId}/users/delete-account", requestContent); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"POST /users/delete-account - Users: {usersToDelete.Count}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.True(response.IsSuccessStatusCode); + } + + /// + /// Tests PUT /organizations/{orgId}/users/{id} + /// + [Fact(Skip = "Performance test")] + public async Task UpdateSingleUser_WithCollectionsAndGroups() + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + var collectionsSeeder = new CollectionsRecipe(db); + var groupsSeeder = new GroupsRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: 1); + + var orgUserIds = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).ToList(); + var collectionIds = collectionsSeeder.AddToOrganization(orgId, 3, orgUserIds, 0); + var groupIds = groupsSeeder.AddToOrganization(orgId, 2, orgUserIds, 0); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var userToUpdate = db.OrganizationUsers + .FirstOrDefault(ou => ou.OrganizationId == orgId && ou.Type == OrganizationUserType.User); + + var updateRequest = new OrganizationUserUpdateRequestModel + { + Type = OrganizationUserType.Custom, + Collections = collectionIds.Select(c => new SelectionReadOnlyRequestModel { Id = c, ReadOnly = false, HidePasswords = false, Manage = false }), + Groups = groupIds, + AccessSecretsManager = false, + Permissions = new Permissions { AccessEventLogs = true } + }; + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.PutAsync($"/organizations/{orgId}/users/{userToUpdate.Id}", + new StringContent(JsonSerializer.Serialize(updateRequest), Encoding.UTF8, "application/json")); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"PUT /users/{{id}} - Collections: {collectionIds.Count}; Groups: {groupIds.Count}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.True(response.IsSuccessStatusCode); + } + + /// + /// Tests PUT /organizations/{orgId}/users/enable-secrets-manager + /// + [Theory(Skip = "Performance test")] + [InlineData(10)] + //[InlineData(100)] + //[InlineData(1000)] + public async Task BulkEnableSecretsManager(int userCount) + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: userCount); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var usersToEnable = db.OrganizationUsers + .Where(ou => ou.OrganizationId == orgId && ou.Type == OrganizationUserType.User) + .Select(ou => ou.Id) + .ToList(); + + var enableRequest = new OrganizationUserBulkRequestModel { Ids = usersToEnable }; + + var requestContent = new StringContent(JsonSerializer.Serialize(enableRequest), Encoding.UTF8, "application/json"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.PutAsync($"/organizations/{orgId}/users/enable-secrets-manager", requestContent); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"PUT /users/enable-secrets-manager - Users: {usersToEnable.Count}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.True(response.IsSuccessStatusCode); + } + + /// + /// Tests DELETE /organizations/{orgId}/users/{id}/delete-account + /// + [Fact(Skip = "Performance test")] + public async Task DeleteSingleUserAccount_FromVerifiedDomain() + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + var domainSeeder = new OrganizationDomainRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed( + name: "Org", + domain: domain, + users: 2, + usersStatus: OrganizationUserStatusType.Confirmed); + + domainSeeder.AddVerifiedDomainToOrganization(orgId, domain); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var userToDelete = db.OrganizationUsers + .FirstOrDefault(ou => ou.OrganizationId == orgId && ou.Type == OrganizationUserType.User); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.DeleteAsync($"/organizations/{orgId}/users/{userToDelete.Id}/delete-account"); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"DELETE /users/{{id}}/delete-account - Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + /// + /// Tests POST /organizations/{orgId}/users/invite + /// + [Theory(Skip = "Performance test")] + [InlineData(1)] + //[InlineData(5)] + //[InlineData(20)] + public async Task InviteUsers(int emailCount) + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + var collectionsSeeder = new CollectionsRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: 1); + + var orgUserIds = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).ToList(); + var collectionIds = collectionsSeeder.AddToOrganization(orgId, 2, orgUserIds, 0); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var emails = Enumerable.Range(0, emailCount).Select(i => $"{i:D4}@{domain}").ToArray(); + var inviteRequest = new OrganizationUserInviteRequestModel + { + Emails = emails, + Type = OrganizationUserType.User, + AccessSecretsManager = false, + Collections = Array.Empty(), + Groups = Array.Empty(), + Permissions = null + }; + + var requestContent = new StringContent(JsonSerializer.Serialize(inviteRequest), Encoding.UTF8, "application/json"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.PostAsync($"/organizations/{orgId}/users/invite", requestContent); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"POST /users/invite - Emails: {emails.Length}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + /// + /// Tests POST /organizations/{orgId}/users/reinvite + /// + [Theory(Skip = "Performance test")] + [InlineData(10)] + //[InlineData(100)] + //[InlineData(1000)] + public async Task BulkReinviteUsers(int userCount) + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed( + name: "Org", + domain: domain, + users: userCount, + usersStatus: OrganizationUserStatusType.Invited); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var usersToReinvite = db.OrganizationUsers + .Where(ou => ou.OrganizationId == orgId && ou.Status == OrganizationUserStatusType.Invited) + .Select(ou => ou.Id) + .ToList(); + + var reinviteRequest = new OrganizationUserBulkRequestModel { Ids = usersToReinvite }; + + var requestContent = new StringContent(JsonSerializer.Serialize(reinviteRequest), Encoding.UTF8, "application/json"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.PostAsync($"/organizations/{orgId}/users/reinvite", requestContent); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"POST /users/reinvite - Users: {usersToReinvite.Count}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.True(response.IsSuccessStatusCode); } } diff --git a/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUsersControllerPutResetPasswordTests.cs b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUsersControllerPutResetPasswordTests.cs new file mode 100644 index 0000000000..38e3cac863 --- /dev/null +++ b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUsersControllerPutResetPasswordTests.cs @@ -0,0 +1,188 @@ +using System.Net; +using Bit.Api.AdminConsole.Authorization; +using Bit.Api.IntegrationTest.Factories; +using Bit.Api.IntegrationTest.Helpers; +using Bit.Api.Models.Request.Organizations; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Enums; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Models.Api; +using Bit.Core.Repositories; +using Xunit; + +namespace Bit.Api.IntegrationTest.AdminConsole.Controllers; + +public class OrganizationUsersControllerPutResetPasswordTests : IClassFixture, IAsyncLifetime +{ + private readonly HttpClient _client; + private readonly ApiApplicationFactory _factory; + private readonly LoginHelper _loginHelper; + + private Organization _organization = null!; + private string _ownerEmail = null!; + + public OrganizationUsersControllerPutResetPasswordTests(ApiApplicationFactory apiFactory) + { + _factory = apiFactory; + _client = _factory.CreateClient(); + _loginHelper = new LoginHelper(_factory, _client); + } + + public async Task InitializeAsync() + { + _ownerEmail = $"reset-password-test-{Guid.NewGuid()}@example.com"; + await _factory.LoginWithNewAccount(_ownerEmail); + + (_organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, plan: PlanType.EnterpriseAnnually, + ownerEmail: _ownerEmail, passwordManagerSeats: 5, paymentMethod: PaymentMethodType.Card); + + // Enable reset password and policies for the organization + var organizationRepository = _factory.GetService(); + _organization.UseResetPassword = true; + _organization.UsePolicies = true; + await organizationRepository.ReplaceAsync(_organization); + + // Enable the ResetPassword policy + var policyRepository = _factory.GetService(); + await policyRepository.CreateAsync(new Policy + { + OrganizationId = _organization.Id, + Type = PolicyType.ResetPassword, + Enabled = true, + Data = "{}" + }); + } + + public Task DisposeAsync() + { + _client.Dispose(); + return Task.CompletedTask; + } + + /// + /// Helper method to set the ResetPasswordKey on an organization user, which is required for account recovery + /// + private async Task SetResetPasswordKeyAsync(OrganizationUser orgUser) + { + var organizationUserRepository = _factory.GetService(); + orgUser.ResetPasswordKey = "encrypted-reset-password-key"; + await organizationUserRepository.ReplaceAsync(orgUser); + } + + [Fact] + public async Task PutResetPassword_AsHigherRole_CanRecoverLowerRole() + { + // Arrange + var (ownerEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Owner); + await _loginHelper.LoginAsync(ownerEmail); + + var (_, targetOrgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync( + _factory, _organization.Id, OrganizationUserType.User); + await SetResetPasswordKeyAsync(targetOrgUser); + + var resetPasswordRequest = new OrganizationUserResetPasswordRequestModel + { + NewMasterPasswordHash = "new-master-password-hash", + Key = "encrypted-recovery-key" + }; + + // Act + var response = await _client.PutAsJsonAsync( + $"organizations/{_organization.Id}/users/{targetOrgUser.Id}/reset-password", + resetPasswordRequest); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + [Fact] + public async Task PutResetPassword_AsLowerRole_CannotRecoverHigherRole() + { + // Arrange + var (adminEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Admin); + await _loginHelper.LoginAsync(adminEmail); + + var (_, targetOwnerOrgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync( + _factory, _organization.Id, OrganizationUserType.Owner); + await SetResetPasswordKeyAsync(targetOwnerOrgUser); + + var resetPasswordRequest = new OrganizationUserResetPasswordRequestModel + { + NewMasterPasswordHash = "new-master-password-hash", + Key = "encrypted-recovery-key" + }; + + // Act + var response = await _client.PutAsJsonAsync( + $"organizations/{_organization.Id}/users/{targetOwnerOrgUser.Id}/reset-password", + resetPasswordRequest); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var model = await response.Content.ReadFromJsonAsync(); + Assert.Contains(RecoverAccountAuthorizationHandler.FailureReason, model.Message); + } + + [Fact] + public async Task PutResetPassword_CannotRecoverProviderAccount() + { + // Arrange - Create owner who will try to recover the provider account + var (ownerEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Owner); + await _loginHelper.LoginAsync(ownerEmail); + + // Create a user who is also a provider user + var (targetUserEmail, targetOrgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync( + _factory, _organization.Id, OrganizationUserType.User); + await SetResetPasswordKeyAsync(targetOrgUser); + + // Add the target user as a provider user to a different provider + var providerRepository = _factory.GetService(); + var providerUserRepository = _factory.GetService(); + var userRepository = _factory.GetService(); + + var provider = await providerRepository.CreateAsync(new Provider + { + Name = "Test Provider", + BusinessName = "Test Provider Business", + BillingEmail = "provider@example.com", + Type = ProviderType.Msp, + Status = ProviderStatusType.Created, + Enabled = true + }); + + var targetUser = await userRepository.GetByEmailAsync(targetUserEmail); + Assert.NotNull(targetUser); + + await providerUserRepository.CreateAsync(new ProviderUser + { + ProviderId = provider.Id, + UserId = targetUser.Id, + Status = ProviderUserStatusType.Confirmed, + Type = ProviderUserType.ProviderAdmin + }); + + var resetPasswordRequest = new OrganizationUserResetPasswordRequestModel + { + NewMasterPasswordHash = "new-master-password-hash", + Key = "encrypted-recovery-key" + }; + + // Act + var response = await _client.PutAsJsonAsync( + $"organizations/{_organization.Id}/users/{targetOrgUser.Id}/reset-password", + resetPasswordRequest); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var model = await response.Content.ReadFromJsonAsync(); + Assert.Equal(RecoverAccountAuthorizationHandler.ProviderFailureReason, model.Message); + } +} diff --git a/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationsControllerPerformanceTests.cs b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationsControllerPerformanceTests.cs new file mode 100644 index 0000000000..238a9a5d53 --- /dev/null +++ b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationsControllerPerformanceTests.cs @@ -0,0 +1,163 @@ +using System.Net; +using System.Text; +using System.Text.Json; +using Bit.Api.AdminConsole.Models.Request.Organizations; +using Bit.Api.Auth.Models.Request.Accounts; +using Bit.Api.IntegrationTest.Factories; +using Bit.Api.IntegrationTest.Helpers; +using Bit.Core.AdminConsole.Models.Business.Tokenables; +using Bit.Core.Billing.Enums; +using Bit.Core.Tokens; +using Bit.Seeder.Recipes; +using Xunit; +using Xunit.Abstractions; + +namespace Bit.Api.IntegrationTest.AdminConsole.Controllers; + +public class OrganizationsControllerPerformanceTests(ITestOutputHelper testOutputHelper) +{ + /// + /// Tests DELETE /organizations/{id} with password verification + /// + [Theory(Skip = "Performance test")] + [InlineData(10, 5, 3)] + //[InlineData(100, 20, 10)] + //[InlineData(1000, 50, 25)] + public async Task DeleteOrganization_WithPasswordVerification(int userCount, int collectionCount, int groupCount) + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + var collectionsSeeder = new CollectionsRecipe(db); + var groupsSeeder = new GroupsRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: userCount); + + var orgUserIds = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).ToList(); + collectionsSeeder.AddToOrganization(orgId, collectionCount, orgUserIds, 0); + groupsSeeder.AddToOrganization(orgId, groupCount, orgUserIds, 0); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var deleteRequest = new SecretVerificationRequestModel + { + MasterPasswordHash = "c55hlJ/cfdvTd4awTXUqow6X3cOQCfGwn11o3HblnPs=" + }; + + var request = new HttpRequestMessage(HttpMethod.Delete, $"/organizations/{orgId}") + { + Content = new StringContent(JsonSerializer.Serialize(deleteRequest), Encoding.UTF8, "application/json") + }; + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + + var response = await client.SendAsync(request); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"DELETE /organizations/{{id}} - Users: {userCount}; Collections: {collectionCount}; Groups: {groupCount}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + /// + /// Tests POST /organizations/{id}/delete-recover-token with token verification + /// + [Theory(Skip = "Performance test")] + [InlineData(10, 5, 3)] + //[InlineData(100, 20, 10)] + //[InlineData(1000, 50, 25)] + public async Task DeleteOrganization_WithTokenVerification(int userCount, int collectionCount, int groupCount) + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var db = factory.GetDatabaseContext(); + var orgSeeder = new OrganizationWithUsersRecipe(db); + var collectionsSeeder = new CollectionsRecipe(db); + var groupsSeeder = new GroupsRecipe(db); + + var domain = OrganizationTestHelpers.GenerateRandomDomain(); + var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: userCount); + + var orgUserIds = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).ToList(); + collectionsSeeder.AddToOrganization(orgId, collectionCount, orgUserIds, 0); + groupsSeeder.AddToOrganization(orgId, groupCount, orgUserIds, 0); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); + + var organization = db.Organizations.FirstOrDefault(o => o.Id == orgId); + Assert.NotNull(organization); + + var tokenFactory = factory.GetService>(); + var tokenable = new OrgDeleteTokenable(organization, 24); + var token = tokenFactory.Protect(tokenable); + + var deleteRequest = new OrganizationVerifyDeleteRecoverRequestModel + { + Token = token + }; + + var requestContent = new StringContent(JsonSerializer.Serialize(deleteRequest), Encoding.UTF8, "application/json"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.PostAsync($"/organizations/{orgId}/delete-recover-token", requestContent); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"POST /organizations/{{id}}/delete-recover-token - Users: {userCount}; Collections: {collectionCount}; Groups: {groupCount}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + /// + /// Tests POST /organizations/create-without-payment + /// + [Fact(Skip = "Performance test")] + public async Task CreateOrganization_WithoutPayment() + { + await using var factory = new SqlServerApiApplicationFactory(); + var client = factory.CreateClient(); + + var email = $"user@{OrganizationTestHelpers.GenerateRandomDomain()}"; + var masterPasswordHash = "c55hlJ/cfdvTd4awTXUqow6X3cOQCfGwn11o3HblnPs="; + + await factory.LoginWithNewAccount(email, masterPasswordHash); + + await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, email, masterPasswordHash); + + var createRequest = new OrganizationNoPaymentCreateRequest + { + Name = "Test Organization", + BusinessName = "Test Business Name", + BillingEmail = email, + PlanType = PlanType.EnterpriseAnnually, + Key = "2.AOs41Hd8OQiCPXjyJKCiDA==|O6OHgt2U2hJGBSNGnimJmg==|iD33s8B69C8JhYYhSa4V1tArjvLr8eEaGqOV7BRo5Jk=", + AdditionalSeats = 1, + AdditionalStorageGb = 1, + UseSecretsManager = true, + AdditionalSmSeats = 1, + AdditionalServiceAccounts = 2, + MaxAutoscaleSeats = 100, + PremiumAccessAddon = false, + CollectionName = "2.AOs41Hd8OQiCPXjyJKCiDA==|O6OHgt2U2hJGBSNGnimJmg==|iD33s8B69C8JhYYhSa4V1tArjvLr8eEaGqOV7BRo5Jk=" + }; + + var requestContent = new StringContent(JsonSerializer.Serialize(createRequest), Encoding.UTF8, "application/json"); + + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var response = await client.PostAsync("/organizations/create-without-payment", requestContent); + + stopwatch.Stop(); + + testOutputHelper.WriteLine($"POST /organizations/create-without-payment - AdditionalSeats: {createRequest.AdditionalSeats}; AdditionalStorageGb: {createRequest.AdditionalStorageGb}; AdditionalSmSeats: {createRequest.AdditionalSmSeats}; AdditionalServiceAccounts: {createRequest.AdditionalServiceAccounts}; MaxAutoscaleSeats: {createRequest.MaxAutoscaleSeats}; Request duration: {stopwatch.ElapsedMilliseconds} ms; Status: {response.StatusCode}"); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } +} diff --git a/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationsControllerTests.cs b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationsControllerTests.cs new file mode 100644 index 0000000000..c234e77bc8 --- /dev/null +++ b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationsControllerTests.cs @@ -0,0 +1,196 @@ +using System.Net; +using Bit.Api.AdminConsole.Models.Request.Organizations; +using Bit.Api.IntegrationTest.Factories; +using Bit.Api.IntegrationTest.Helpers; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.Billing.Enums; +using Bit.Core.Enums; +using Bit.Core.Repositories; +using Xunit; + +namespace Bit.Api.IntegrationTest.AdminConsole.Controllers; + +public class OrganizationsControllerTests : IClassFixture, IAsyncLifetime +{ + private readonly HttpClient _client; + private readonly ApiApplicationFactory _factory; + private readonly LoginHelper _loginHelper; + + private Organization _organization = null!; + private string _ownerEmail = null!; + private readonly string _billingEmail = "billing@example.com"; + private readonly string _organizationName = "Organizations Controller Test Org"; + + public OrganizationsControllerTests(ApiApplicationFactory apiFactory) + { + _factory = apiFactory; + _client = _factory.CreateClient(); + _loginHelper = new LoginHelper(_factory, _client); + } + + public async Task InitializeAsync() + { + _ownerEmail = $"org-integration-test-{Guid.NewGuid()}@example.com"; + await _factory.LoginWithNewAccount(_ownerEmail); + + (_organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, + name: _organizationName, + billingEmail: _billingEmail, + plan: PlanType.EnterpriseAnnually, + ownerEmail: _ownerEmail, + passwordManagerSeats: 5, + paymentMethod: PaymentMethodType.Card); + } + + public Task DisposeAsync() + { + _client.Dispose(); + return Task.CompletedTask; + } + + [Fact] + public async Task Put_AsOwner_WithoutProvider_CanUpdateOrganization() + { + // Arrange - Regular organization owner (no provider) + await _loginHelper.LoginAsync(_ownerEmail); + + var updateRequest = new OrganizationUpdateRequestModel + { + Name = "Updated Organization Name", + BillingEmail = "newbillingemail@example.com" + }; + + // Act + var response = await _client.PutAsJsonAsync($"/organizations/{_organization.Id}", updateRequest); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + // Verify the organization name was updated + var organizationRepository = _factory.GetService(); + var updatedOrg = await organizationRepository.GetByIdAsync(_organization.Id); + Assert.NotNull(updatedOrg); + Assert.Equal("Updated Organization Name", updatedOrg.Name); + Assert.Equal("newbillingemail@example.com", updatedOrg.BillingEmail); + } + + [Fact] + public async Task Put_AsProvider_CanUpdateOrganization() + { + // Create and login as a new account to be the provider user (not the owner) + var providerUserEmail = $"provider-{Guid.NewGuid()}@example.com"; + var (token, _) = await _factory.LoginWithNewAccount(providerUserEmail); + + // Set up provider linked to org and ProviderUser entry + var provider = await ProviderTestHelpers.CreateProviderAndLinkToOrganizationAsync(_factory, _organization.Id, + ProviderType.Msp); + await ProviderTestHelpers.CreateProviderUserAsync(_factory, provider.Id, providerUserEmail, + ProviderUserType.ProviderAdmin); + + await _loginHelper.LoginAsync(providerUserEmail); + + var updateRequest = new OrganizationUpdateRequestModel + { + Name = "Updated Organization Name", + BillingEmail = "newbillingemail@example.com" + }; + + // Act + var response = await _client.PutAsJsonAsync($"/organizations/{_organization.Id}", updateRequest); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + // Verify the organization name was updated + var organizationRepository = _factory.GetService(); + var updatedOrg = await organizationRepository.GetByIdAsync(_organization.Id); + Assert.NotNull(updatedOrg); + Assert.Equal("Updated Organization Name", updatedOrg.Name); + Assert.Equal("newbillingemail@example.com", updatedOrg.BillingEmail); + } + + [Fact] + public async Task Put_NotMemberOrProvider_CannotUpdateOrganization() + { + // Create and login as a new account to be unrelated to the org + var userEmail = "stranger@example.com"; + await _factory.LoginWithNewAccount(userEmail); + await _loginHelper.LoginAsync(userEmail); + + var updateRequest = new OrganizationUpdateRequestModel + { + Name = "Updated Organization Name", + BillingEmail = "newbillingemail@example.com" + }; + + // Act + var response = await _client.PutAsJsonAsync($"/organizations/{_organization.Id}", updateRequest); + + // Assert + Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode); + + // Verify the organization name was not updated + var organizationRepository = _factory.GetService(); + var updatedOrg = await organizationRepository.GetByIdAsync(_organization.Id); + Assert.NotNull(updatedOrg); + Assert.Equal(_organizationName, updatedOrg.Name); + Assert.Equal(_billingEmail, updatedOrg.BillingEmail); + } + + [Fact] + public async Task Put_AsOwner_WithProvider_CanRenameOrganization() + { + // Arrange - Create provider and link to organization + // The active user is ONLY an org owner, NOT a provider user + await ProviderTestHelpers.CreateProviderAndLinkToOrganizationAsync(_factory, _organization.Id, ProviderType.Msp); + await _loginHelper.LoginAsync(_ownerEmail); + + var updateRequest = new OrganizationUpdateRequestModel + { + Name = "Updated Organization Name", + BillingEmail = null + }; + + // Act + var response = await _client.PutAsJsonAsync($"/organizations/{_organization.Id}", updateRequest); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + // Verify the organization name was actually updated + var organizationRepository = _factory.GetService(); + var updatedOrg = await organizationRepository.GetByIdAsync(_organization.Id); + Assert.NotNull(updatedOrg); + Assert.Equal("Updated Organization Name", updatedOrg.Name); + Assert.Equal(_billingEmail, updatedOrg.BillingEmail); + } + + [Fact] + public async Task Put_AsOwner_WithProvider_CannotChangeBillingEmail() + { + // Arrange - Create provider and link to organization + // The active user is ONLY an org owner, NOT a provider user + await ProviderTestHelpers.CreateProviderAndLinkToOrganizationAsync(_factory, _organization.Id, ProviderType.Msp); + await _loginHelper.LoginAsync(_ownerEmail); + + var updateRequest = new OrganizationUpdateRequestModel + { + Name = "Updated Organization Name", + BillingEmail = "updatedbilling@example.com" + }; + + // Act + var response = await _client.PutAsJsonAsync($"/organizations/{_organization.Id}", updateRequest); + + // Assert + Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode); + + // Verify the organization was not updated + var organizationRepository = _factory.GetService(); + var updatedOrg = await organizationRepository.GetByIdAsync(_organization.Id); + Assert.NotNull(updatedOrg); + Assert.Equal(_organizationName, updatedOrg.Name); + Assert.Equal(_billingEmail, updatedOrg.BillingEmail); + } +} diff --git a/test/Api.IntegrationTest/AdminConsole/Controllers/PoliciesControllerTests.cs b/test/Api.IntegrationTest/AdminConsole/Controllers/PoliciesControllerTests.cs index 1efc2f843d..e4098ce9a9 100644 --- a/test/Api.IntegrationTest/AdminConsole/Controllers/PoliciesControllerTests.cs +++ b/test/Api.IntegrationTest/AdminConsole/Controllers/PoliciesControllerTests.cs @@ -67,7 +67,6 @@ public class PoliciesControllerTests : IClassFixture, IAs { Policy = new PolicyRequestModel { - Type = policyType, Enabled = true, }, Metadata = new Dictionary @@ -148,7 +147,6 @@ public class PoliciesControllerTests : IClassFixture, IAs { Policy = new PolicyRequestModel { - Type = policyType, Enabled = true, Data = new Dictionary { @@ -211,4 +209,192 @@ public class PoliciesControllerTests : IClassFixture, IAs } } + [Fact] + public async Task Put_MasterPasswordPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.MasterPassword; + var request = new PolicyRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "minLength", "not a number" }, // Wrong type - should be int + { "requireUpper", true } + } + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains("minLength", content); // Verify field name is in error message + } + + [Fact] + public async Task Put_SendOptionsPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.SendOptions; + var request = new PolicyRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "disableHideEmail", "not a boolean" } // Wrong type - should be bool + } + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task Put_ResetPasswordPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.ResetPassword; + var request = new PolicyRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "autoEnrollEnabled", 123 } // Wrong type - should be bool + } + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task PutVNext_MasterPasswordPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.MasterPassword; + var request = new SavePolicyRequest + { + Policy = new PolicyRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "minComplexity", "not a number" }, // Wrong type - should be int + { "minLength", 12 } + } + } + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}/vnext", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains("minComplexity", content); // Verify field name is in error message + } + + [Fact] + public async Task PutVNext_SendOptionsPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.SendOptions; + var request = new SavePolicyRequest + { + Policy = new PolicyRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "disableHideEmail", "not a boolean" } // Wrong type - should be bool + } + } + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}/vnext", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task PutVNext_ResetPasswordPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.ResetPassword; + var request = new SavePolicyRequest + { + Policy = new PolicyRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "autoEnrollEnabled", 123 } // Wrong type - should be bool + } + } + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}/vnext", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task Put_PolicyWithNullData_Success() + { + // Arrange + var policyType = PolicyType.SingleOrg; + var request = new PolicyRequestModel + { + Enabled = true, + Data = null + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + [Fact] + public async Task PutVNext_PolicyWithNullData_Success() + { + // Arrange + var policyType = PolicyType.TwoFactorAuthentication; + var request = new SavePolicyRequest + { + Policy = new PolicyRequestModel + { + Enabled = true, + Data = null + }, + Metadata = null + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}/vnext", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } } diff --git a/test/Api.IntegrationTest/AdminConsole/Import/ImportOrganizationUsersAndGroupsCommandTests.cs b/test/Api.IntegrationTest/AdminConsole/Import/ImportOrganizationUsersAndGroupsCommandTests.cs index 32c7f75a2b..6ba65f6453 100644 --- a/test/Api.IntegrationTest/AdminConsole/Import/ImportOrganizationUsersAndGroupsCommandTests.cs +++ b/test/Api.IntegrationTest/AdminConsole/Import/ImportOrganizationUsersAndGroupsCommandTests.cs @@ -33,7 +33,7 @@ public class ImportOrganizationUsersAndGroupsCommandTests : IClassFixture, IAsy await _factory.LoginWithNewAccount(_ownerEmail); // Create the organization - (_organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, plan: PlanType.EnterpriseAnnually2023, + (_organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, plan: PlanType.EnterpriseAnnually, ownerEmail: _ownerEmail, passwordManagerSeats: 10, paymentMethod: PaymentMethodType.Card); // Authorize with the organization api key @@ -64,6 +64,17 @@ public class MembersControllerTests : IClassFixture, IAsy var (userEmail4, orgUser4) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, _organization.Id, OrganizationUserType.Admin); + var collection1 = await OrganizationTestHelpers.CreateCollectionAsync(_factory, _organization.Id, "Test Collection 1", users: + [ + new CollectionAccessSelection { Id = orgUser1.Id, ReadOnly = false, HidePasswords = false, Manage = true }, + new CollectionAccessSelection { Id = orgUser3.Id, ReadOnly = true, HidePasswords = false, Manage = false } + ]); + + var collection2 = await OrganizationTestHelpers.CreateCollectionAsync(_factory, _organization.Id, "Test Collection 2", users: + [ + new CollectionAccessSelection { Id = orgUser1.Id, ReadOnly = false, HidePasswords = true, Manage = false } + ]); + var response = await _client.GetAsync($"/public/members"); Assert.Equal(HttpStatusCode.OK, response.StatusCode); var result = await response.Content.ReadFromJsonAsync>(); @@ -71,23 +82,47 @@ public class MembersControllerTests : IClassFixture, IAsy Assert.Equal(5, result.Data.Count()); // The owner - Assert.NotNull(result.Data.SingleOrDefault(m => - m.Email == _ownerEmail && m.Type == OrganizationUserType.Owner)); + var ownerResult = result.Data.SingleOrDefault(m => m.Email == _ownerEmail && m.Type == OrganizationUserType.Owner); + Assert.NotNull(ownerResult); + Assert.Empty(ownerResult.Collections); - // The custom user + // The custom user with collections var user1Result = result.Data.Single(m => m.Email == userEmail1); Assert.Equal(OrganizationUserType.Custom, user1Result.Type); AssertHelper.AssertPropertyEqual( new PermissionsModel { AccessImportExport = true, ManagePolicies = true, AccessReports = true }, user1Result.Permissions); + // Verify collections + Assert.NotNull(user1Result.Collections); + Assert.Equal(2, user1Result.Collections.Count()); + var user1Collection1 = user1Result.Collections.Single(c => c.Id == collection1.Id); + Assert.False(user1Collection1.ReadOnly); + Assert.False(user1Collection1.HidePasswords); + Assert.True(user1Collection1.Manage); + var user1Collection2 = user1Result.Collections.Single(c => c.Id == collection2.Id); + Assert.False(user1Collection2.ReadOnly); + Assert.True(user1Collection2.HidePasswords); + Assert.False(user1Collection2.Manage); - // Everyone else - Assert.NotNull(result.Data.SingleOrDefault(m => - m.Email == userEmail2 && m.Type == OrganizationUserType.Owner)); - Assert.NotNull(result.Data.SingleOrDefault(m => - m.Email == userEmail3 && m.Type == OrganizationUserType.User)); - Assert.NotNull(result.Data.SingleOrDefault(m => - m.Email == userEmail4 && m.Type == OrganizationUserType.Admin)); + // The other owner + var user2Result = result.Data.SingleOrDefault(m => m.Email == userEmail2 && m.Type == OrganizationUserType.Owner); + Assert.NotNull(user2Result); + Assert.Empty(user2Result.Collections); + + // The user with one collection + var user3Result = result.Data.SingleOrDefault(m => m.Email == userEmail3 && m.Type == OrganizationUserType.User); + Assert.NotNull(user3Result); + Assert.NotNull(user3Result.Collections); + Assert.Single(user3Result.Collections); + var user3Collection1 = user3Result.Collections.Single(c => c.Id == collection1.Id); + Assert.True(user3Collection1.ReadOnly); + Assert.False(user3Collection1.HidePasswords); + Assert.False(user3Collection1.Manage); + + // The admin with no collections + var user4Result = result.Data.SingleOrDefault(m => m.Email == userEmail4 && m.Type == OrganizationUserType.Admin); + Assert.NotNull(user4Result); + Assert.Empty(user4Result.Collections); } [Fact] diff --git a/test/Api.IntegrationTest/AdminConsole/Public/Controllers/PoliciesControllerTests.cs b/test/Api.IntegrationTest/AdminConsole/Public/Controllers/PoliciesControllerTests.cs index f034426f98..6144d7eebb 100644 --- a/test/Api.IntegrationTest/AdminConsole/Public/Controllers/PoliciesControllerTests.cs +++ b/test/Api.IntegrationTest/AdminConsole/Public/Controllers/PoliciesControllerTests.cs @@ -39,7 +39,7 @@ public class PoliciesControllerTests : IClassFixture, IAs await _factory.LoginWithNewAccount(_ownerEmail); // Create the organization - (_organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, plan: PlanType.EnterpriseAnnually2023, + (_organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, plan: PlanType.EnterpriseAnnually, ownerEmail: _ownerEmail, passwordManagerSeats: 10, paymentMethod: PaymentMethodType.Card); // Authorize with the organization api key @@ -160,4 +160,86 @@ public class PoliciesControllerTests : IClassFixture, IAs Assert.Equal(15, data.MinLength); Assert.Equal(true, data.RequireUpper); } + + [Fact] + public async Task Put_MasterPasswordPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.MasterPassword; + var request = new PolicyUpdateRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "minLength", "not a number" }, // Wrong type - should be int + { "requireUpper", true } + } + }; + + // Act + var response = await _client.PutAsync($"/public/policies/{policyType}", JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task Put_SendOptionsPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.SendOptions; + var request = new PolicyUpdateRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "disableHideEmail", "not a boolean" } // Wrong type - should be bool + } + }; + + // Act + var response = await _client.PutAsync($"/public/policies/{policyType}", JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task Put_ResetPasswordPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.ResetPassword; + var request = new PolicyUpdateRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "autoEnrollEnabled", 123 } // Wrong type - should be bool + } + }; + + // Act + var response = await _client.PutAsync($"/public/policies/{policyType}", JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task Put_PolicyWithNullData_Success() + { + // Arrange + var policyType = PolicyType.DisableSend; + var request = new PolicyUpdateRequestModel + { + Enabled = true, + Data = null + }; + + // Act + var response = await _client.PutAsync($"/public/policies/{policyType}", JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } } diff --git a/test/Api.IntegrationTest/Controllers/AccountsControllerTest.cs b/test/Api.IntegrationTest/Controllers/AccountsControllerTest.cs index 4e5a6850e7..09ec5b010f 100644 --- a/test/Api.IntegrationTest/Controllers/AccountsControllerTest.cs +++ b/test/Api.IntegrationTest/Controllers/AccountsControllerTest.cs @@ -1,31 +1,81 @@ -using System.Net.Http.Headers; +using System.Net; +using Bit.Api.Auth.Models.Request.Accounts; using Bit.Api.IntegrationTest.Factories; +using Bit.Api.IntegrationTest.Helpers; +using Bit.Api.KeyManagement.Models.Requests; using Bit.Api.Models.Response; +using Bit.Core; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Platform.Push; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.AspNetCore.Identity; +using NSubstitute; using Xunit; namespace Bit.Api.IntegrationTest.Controllers; -public class AccountsControllerTest : IClassFixture +public class AccountsControllerTest : IClassFixture, IAsyncLifetime { - private readonly ApiApplicationFactory _factory; + private static readonly string _masterKeyWrappedUserKey = + "2.AOs41Hd8OQiCPXjyJKCiDA==|O6OHgt2U2hJGBSNGnimJmg==|iD33s8B69C8JhYYhSa4V1tArjvLr8eEaGqOV7BRo5Jk="; - public AccountsControllerTest(ApiApplicationFactory factory) => _factory = factory; + private static readonly string _masterPasswordHash = "master_password_hash"; + private static readonly string _newMasterPasswordHash = "new_master_password_hash"; + + private static readonly KdfRequestModel _defaultKdfRequest = + new() { KdfType = KdfType.PBKDF2_SHA256, Iterations = 600_000 }; + + private readonly HttpClient _client; + private readonly ApiApplicationFactory _factory; + private readonly LoginHelper _loginHelper; + private readonly IUserRepository _userRepository; + private readonly IPushNotificationService _pushNotificationService; + private readonly IFeatureService _featureService; + private readonly IPasswordHasher _passwordHasher; + + private string _ownerEmail = null!; + + public AccountsControllerTest(ApiApplicationFactory factory) + { + _factory = factory; + _factory.SubstituteService(_ => { }); + _factory.SubstituteService(_ => { }); + _client = factory.CreateClient(); + _loginHelper = new LoginHelper(_factory, _client); + _userRepository = _factory.GetService(); + _pushNotificationService = _factory.GetService(); + _featureService = _factory.GetService(); + _passwordHasher = _factory.GetService>(); + } + + public async Task InitializeAsync() + { + _ownerEmail = $"integration-test{Guid.NewGuid()}@bitwarden.com"; + await _factory.LoginWithNewAccount(_ownerEmail); + } + + public Task DisposeAsync() + { + _client.Dispose(); + return Task.CompletedTask; + } [Fact] public async Task GetAccountsProfile_success() { - var tokens = await _factory.LoginWithNewAccount(); - var client = _factory.CreateClient(); + await _loginHelper.LoginAsync(_ownerEmail); using var message = new HttpRequestMessage(HttpMethod.Get, "/accounts/profile"); - message.Headers.Authorization = new AuthenticationHeaderValue("Bearer", tokens.Token); - var response = await client.SendAsync(message); + var response = await _client.SendAsync(message); response.EnsureSuccessStatusCode(); var content = await response.Content.ReadFromJsonAsync(); Assert.NotNull(content); - Assert.Equal("integration-test@bitwarden.com", content.Email); + Assert.Equal(_ownerEmail, content.Email); Assert.NotNull(content.Name); Assert.True(content.EmailVerified); Assert.False(content.Premium); @@ -35,4 +85,354 @@ public class AccountsControllerTest : IClassFixture Assert.NotNull(content.PrivateKey); Assert.NotNull(content.SecurityStamp); } + + [Theory] + [BitAutoData(KdfType.PBKDF2_SHA256, 600001, null, null)] + [BitAutoData(KdfType.Argon2id, 4, 65, 5)] + public async Task PostKdf_ValidRequestLogoutOnKdfChangeFeatureFlagOff_SuccessLogout(KdfType kdf, + int kdfIterations, int? kdfMemory, int? kdfParallelism) + { + var userBeforeKdfChange = await _userRepository.GetByEmailAsync(_ownerEmail); + Assert.NotNull(userBeforeKdfChange); + + _featureService.IsEnabled(FeatureFlagKeys.NoLogoutOnKdfChange).Returns(false); + + await _loginHelper.LoginAsync(_ownerEmail); + + var kdfRequest = new KdfRequestModel + { + KdfType = kdf, + Iterations = kdfIterations, + Memory = kdfMemory, + Parallelism = kdfParallelism, + }; + + var response = await PostKdfWithKdfRequestAsync(kdfRequest); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + // Validate that the user fields were updated correctly + var user = await _userRepository.GetByEmailAsync(_ownerEmail); + Assert.NotNull(user); + Assert.Equal(kdfRequest.KdfType, user.Kdf); + Assert.Equal(kdfRequest.Iterations, user.KdfIterations); + Assert.Equal(kdfRequest.Memory, user.KdfMemory); + Assert.Equal(kdfRequest.Parallelism, user.KdfParallelism); + Assert.Equal(_masterKeyWrappedUserKey, user.Key); + Assert.NotNull(user.LastKdfChangeDate); + Assert.True(user.LastKdfChangeDate > DateTime.UtcNow.AddMinutes(-1)); + Assert.True(user.RevisionDate > DateTime.UtcNow.AddMinutes(-1)); + Assert.True(user.AccountRevisionDate > DateTime.UtcNow.AddMinutes(-1)); + Assert.NotEqual(userBeforeKdfChange.SecurityStamp, user.SecurityStamp); + Assert.Equal(PasswordVerificationResult.Success, + _passwordHasher.VerifyHashedPassword(user, user.MasterPassword!, _newMasterPasswordHash)); + + // Validate push notification + await _pushNotificationService.Received(1).PushLogOutAsync(user.Id); + } + + [Theory] + [BitAutoData(KdfType.PBKDF2_SHA256, 600001, null, null)] + [BitAutoData(KdfType.Argon2id, 4, 65, 5)] + public async Task PostKdf_ValidRequestLogoutOnKdfChangeFeatureFlagOn_SuccessSyncAndLogoutWithReason(KdfType kdf, + int kdfIterations, int? kdfMemory, int? kdfParallelism) + { + var userBeforeKdfChange = await _userRepository.GetByEmailAsync(_ownerEmail); + Assert.NotNull(userBeforeKdfChange); + + _featureService.IsEnabled(FeatureFlagKeys.NoLogoutOnKdfChange).Returns(true); + + await _loginHelper.LoginAsync(_ownerEmail); + + var kdfRequest = new KdfRequestModel + { + KdfType = kdf, + Iterations = kdfIterations, + Memory = kdfMemory, + Parallelism = kdfParallelism, + }; + + var response = await PostKdfWithKdfRequestAsync(kdfRequest); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + // Validate that the user fields were updated correctly + var user = await _userRepository.GetByEmailAsync(_ownerEmail); + Assert.NotNull(user); + Assert.Equal(kdfRequest.KdfType, user.Kdf); + Assert.Equal(kdfRequest.Iterations, user.KdfIterations); + Assert.Equal(kdfRequest.Memory, user.KdfMemory); + Assert.Equal(kdfRequest.Parallelism, user.KdfParallelism); + Assert.Equal(_masterKeyWrappedUserKey, user.Key); + Assert.NotNull(user.LastKdfChangeDate); + Assert.True(user.LastKdfChangeDate > DateTime.UtcNow.AddMinutes(-1)); + Assert.True(user.RevisionDate > DateTime.UtcNow.AddMinutes(-1)); + Assert.True(user.AccountRevisionDate > DateTime.UtcNow.AddMinutes(-1)); + Assert.Equal(userBeforeKdfChange.SecurityStamp, user.SecurityStamp); + Assert.Equal(PasswordVerificationResult.Success, + _passwordHasher.VerifyHashedPassword(user, user.MasterPassword!, _newMasterPasswordHash)); + + // Validate push notification + await _pushNotificationService.Received(1) + .PushLogOutAsync(user.Id, false, PushNotificationLogOutReason.KdfChange); + await _pushNotificationService.Received(1).PushSyncSettingsAsync(user.Id); + } + + [Fact] + public async Task PostKdf_Unauthorized_ReturnsUnauthorized() + { + // Don't call LoginAsync to test unauthorized access + + var response = await PostKdfWithKdfRequestAsync(_defaultKdfRequest); + + Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode); + } + + [Theory] + [InlineData(false, true)] + [InlineData(true, false)] + [InlineData(true, true)] + public async Task PostKdf_AuthenticationDataOrUnlockDataNull_BadRequest(bool authenticationDataNull, + bool unlockDataNull) + { + await _loginHelper.LoginAsync(_ownerEmail); + + var authenticationData = authenticationDataNull + ? null + : new MasterPasswordAuthenticationDataRequestModel + { + Kdf = _defaultKdfRequest, + MasterPasswordAuthenticationHash = _newMasterPasswordHash, + Salt = _ownerEmail + }; + + var unlockData = unlockDataNull + ? null + : new MasterPasswordUnlockDataRequestModel + { + Kdf = _defaultKdfRequest, + MasterKeyWrappedUserKey = _masterKeyWrappedUserKey, + Salt = _ownerEmail + }; + + var response = await PostKdfAsync(authenticationData, unlockData); + + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains("AuthenticationData and UnlockData must be provided.", content); + } + + [Fact] + public async Task PostKdf_InvalidMasterPasswordHash_BadRequest() + { + await _loginHelper.LoginAsync(_ownerEmail); + + var authenticationData = new MasterPasswordAuthenticationDataRequestModel + { + Kdf = _defaultKdfRequest, + MasterPasswordAuthenticationHash = _newMasterPasswordHash, + Salt = _ownerEmail + }; + + var unlockData = new MasterPasswordUnlockDataRequestModel + { + Kdf = _defaultKdfRequest, + MasterKeyWrappedUserKey = _masterKeyWrappedUserKey, + Salt = _ownerEmail + }; + + var requestModel = new PasswordRequestModel + { + MasterPasswordHash = "wrong-master-password-hash", + NewMasterPasswordHash = _newMasterPasswordHash, + Key = _masterKeyWrappedUserKey, + AuthenticationData = authenticationData, + UnlockData = unlockData + }; + + using var message = new HttpRequestMessage(HttpMethod.Post, "/accounts/kdf"); + message.Content = JsonContent.Create(requestModel); + var response = await _client.SendAsync(message); + + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains("Incorrect password", content); + } + + [Fact] + public async Task PostKdf_ChangedSaltInAuthenticationData_BadRequest() + { + await _loginHelper.LoginAsync(_ownerEmail); + + var authenticationData = new MasterPasswordAuthenticationDataRequestModel + { + Kdf = _defaultKdfRequest, + MasterPasswordAuthenticationHash = _newMasterPasswordHash, + Salt = "wrong-salt@bitwarden.com" + }; + + var unlockData = new MasterPasswordUnlockDataRequestModel + { + Kdf = _defaultKdfRequest, + MasterKeyWrappedUserKey = _masterKeyWrappedUserKey, + Salt = _ownerEmail + }; + + var response = await PostKdfAsync(authenticationData, unlockData); + + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains("Invalid master password salt.", content); + } + + [Fact] + public async Task PostKdf_ChangedSaltInUnlockData_BadRequest() + { + await _loginHelper.LoginAsync(_ownerEmail); + + var authenticationData = new MasterPasswordAuthenticationDataRequestModel + { + Kdf = _defaultKdfRequest, + MasterPasswordAuthenticationHash = _newMasterPasswordHash, + Salt = _ownerEmail + }; + + var unlockData = new MasterPasswordUnlockDataRequestModel + { + Kdf = _defaultKdfRequest, + MasterKeyWrappedUserKey = _masterKeyWrappedUserKey, + Salt = "wrong-salt@bitwarden.com" + }; + + var response = await PostKdfAsync(authenticationData, unlockData); + + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains("Invalid master password salt.", content); + } + + [Fact] + public async Task PostKdf_KdfNotMatching_BadRequest() + { + await _loginHelper.LoginAsync(_ownerEmail); + + var authenticationData = new MasterPasswordAuthenticationDataRequestModel + { + Kdf = new KdfRequestModel { KdfType = KdfType.PBKDF2_SHA256, Iterations = 600_000 }, + MasterPasswordAuthenticationHash = _newMasterPasswordHash, + Salt = _ownerEmail + }; + + var unlockData = new MasterPasswordUnlockDataRequestModel + { + Kdf = new KdfRequestModel { KdfType = KdfType.PBKDF2_SHA256, Iterations = 600_001 }, + MasterKeyWrappedUserKey = _masterKeyWrappedUserKey, + Salt = _ownerEmail + }; + + var response = await PostKdfAsync(authenticationData, unlockData); + + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains("KDF settings must be equal for authentication and unlock.", content); + } + + [Theory] + [InlineData(KdfType.PBKDF2_SHA256, 1, null, null)] + [InlineData(KdfType.Argon2id, 4, null, 5)] + [InlineData(KdfType.Argon2id, 4, 65, null)] + public async Task PostKdf_InvalidKdf_BadRequest(KdfType kdf, int kdfIterations, int? kdfMemory, int? kdfParallelism) + { + await _loginHelper.LoginAsync(_ownerEmail); + + var kdfRequest = new KdfRequestModel + { + KdfType = kdf, + Iterations = kdfIterations, + Memory = kdfMemory, + Parallelism = kdfParallelism + }; + + var response = await PostKdfWithKdfRequestAsync(kdfRequest); + + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains("KDF settings are invalid", content); + } + + [Fact] + public async Task PostKdf_InvalidNewMasterPassword_BadRequest() + { + var newMasterPasswordHash = "too-short"; + + await _loginHelper.LoginAsync(_ownerEmail); + + var authenticationData = new MasterPasswordAuthenticationDataRequestModel + { + Kdf = _defaultKdfRequest, + MasterPasswordAuthenticationHash = newMasterPasswordHash, + Salt = _ownerEmail + }; + + var unlockData = new MasterPasswordUnlockDataRequestModel + { + Kdf = _defaultKdfRequest, + MasterKeyWrappedUserKey = _masterKeyWrappedUserKey, + Salt = _ownerEmail + }; + + var requestModel = new PasswordRequestModel + { + MasterPasswordHash = _masterPasswordHash, + NewMasterPasswordHash = newMasterPasswordHash, + Key = _masterKeyWrappedUserKey, + AuthenticationData = authenticationData, + UnlockData = unlockData + }; + + using var message = new HttpRequestMessage(HttpMethod.Post, "/accounts/kdf"); + message.Content = JsonContent.Create(requestModel); + var response = await _client.SendAsync(message); + + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains("Passwords must be at least", content); + } + + private async Task PostKdfWithKdfRequestAsync(KdfRequestModel kdfRequest) + { + var authenticationData = new MasterPasswordAuthenticationDataRequestModel + { + Kdf = kdfRequest, + MasterPasswordAuthenticationHash = _newMasterPasswordHash, + Salt = _ownerEmail + }; + + var unlockData = new MasterPasswordUnlockDataRequestModel + { + Kdf = kdfRequest, + MasterKeyWrappedUserKey = _masterKeyWrappedUserKey, + Salt = _ownerEmail + }; + + return await PostKdfAsync(authenticationData, unlockData); + } + + private async Task PostKdfAsync( + MasterPasswordAuthenticationDataRequestModel? authenticationDataRequest, + MasterPasswordUnlockDataRequestModel? unlockDataRequest) + { + var requestModel = new PasswordRequestModel + { + MasterPasswordHash = _masterPasswordHash, + NewMasterPasswordHash = _newMasterPasswordHash, + Key = _masterKeyWrappedUserKey, + AuthenticationData = authenticationDataRequest, + UnlockData = unlockDataRequest + }; + + using var message = new HttpRequestMessage(HttpMethod.Post, "/accounts/kdf"); + message.Content = JsonContent.Create(requestModel); + return await _client.SendAsync(message); + } } diff --git a/test/Api.IntegrationTest/Controllers/Public/CollectionsControllerTests.cs b/test/Api.IntegrationTest/Controllers/Public/CollectionsControllerTests.cs new file mode 100644 index 0000000000..a729abb849 --- /dev/null +++ b/test/Api.IntegrationTest/Controllers/Public/CollectionsControllerTests.cs @@ -0,0 +1,117 @@ +using Bit.Api.AdminConsole.Public.Models.Request; +using Bit.Api.IntegrationTest.Factories; +using Bit.Api.IntegrationTest.Helpers; +using Bit.Api.Models.Public.Request; +using Bit.Api.Models.Public.Response; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Enums; +using Bit.Core.Enums; +using Bit.Core.Models.Data; +using Bit.Core.Platform.Push; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Xunit; + +namespace Bit.Api.IntegrationTest.Controllers.Public; + +public class CollectionsControllerTests : IClassFixture, IAsyncLifetime +{ + + private readonly HttpClient _client; + private readonly ApiApplicationFactory _factory; + private readonly LoginHelper _loginHelper; + + private string _ownerEmail = null!; + private Organization _organization = null!; + + public CollectionsControllerTests(ApiApplicationFactory factory) + { + _factory = factory; + _factory.SubstituteService(_ => { }); + _factory.SubstituteService(_ => { }); + _client = factory.CreateClient(); + _loginHelper = new LoginHelper(_factory, _client); + } + + public async Task InitializeAsync() + { + _ownerEmail = $"integration-test{Guid.NewGuid()}@bitwarden.com"; + await _factory.LoginWithNewAccount(_ownerEmail); + + (_organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, + plan: PlanType.EnterpriseAnnually, + ownerEmail: _ownerEmail, + passwordManagerSeats: 10, + paymentMethod: PaymentMethodType.Card); + + await _loginHelper.LoginWithOrganizationApiKeyAsync(_organization.Id); + } + + public Task DisposeAsync() + { + _client.Dispose(); + return Task.CompletedTask; + } + + [Fact] + public async Task CreateCollectionWithMultipleUsersAndVariedPermissions_Success() + { + // Arrange + _organization.AllowAdminAccessToAllCollectionItems = true; + await _factory.GetService().UpsertAsync(_organization); + + var groupRepository = _factory.GetService(); + var group = await groupRepository.CreateAsync(new Group + { + OrganizationId = _organization.Id, + Name = "CollectionControllerTests.CreateCollectionWithMultipleUsersAndVariedPermissions_Success", + ExternalId = $"CollectionControllerTests.CreateCollectionWithMultipleUsersAndVariedPermissions_Success{Guid.NewGuid()}", + }); + + var (_, user) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync( + _factory, + _organization.Id, + OrganizationUserType.User); + + var collection = await OrganizationTestHelpers.CreateCollectionAsync( + _factory, + _organization.Id, + "Shared Collection with a group", + externalId: "shared-collection-with-group", + groups: + [ + new CollectionAccessSelection { Id = group.Id, ReadOnly = false, HidePasswords = false, Manage = true } + ], + users: + [ + new CollectionAccessSelection { Id = user.Id, ReadOnly = false, HidePasswords = false, Manage = true } + ]); + + var getCollectionsResponse = await _client.GetFromJsonAsync>("public/collections"); + var getCollectionResponse = await _client.GetFromJsonAsync($"public/collections/{collection.Id}"); + + var firstCollection = getCollectionsResponse.Data.First(x => x.ExternalId == "shared-collection-with-group"); + + var update = new CollectionUpdateRequestModel + { + ExternalId = firstCollection.ExternalId, + Groups = firstCollection.Groups?.Select(x => new AssociationWithPermissionsRequestModel + { + Id = x.Id, + ReadOnly = x.ReadOnly, + HidePasswords = x.HidePasswords, + Manage = x.Manage + }), + }; + + await _client.PutAsJsonAsync($"public/collections/{firstCollection.Id}", update); + + var result = await _factory.GetService() + .GetByIdWithAccessAsync(firstCollection.Id); + + Assert.NotNull(result); + Assert.NotEmpty(result.Item2.Groups); + Assert.NotEmpty(result.Item2.Users); + } +} diff --git a/test/Api.IntegrationTest/Helpers/OrganizationTestHelpers.cs b/test/Api.IntegrationTest/Helpers/OrganizationTestHelpers.cs index 3cd73c4b1c..887ef989ce 100644 --- a/test/Api.IntegrationTest/Helpers/OrganizationTestHelpers.cs +++ b/test/Api.IntegrationTest/Helpers/OrganizationTestHelpers.cs @@ -151,6 +151,30 @@ public static class OrganizationTestHelpers return group; } + /// + /// Creates a collection with optional user and group associations. + /// + public static async Task CreateCollectionAsync( + ApiApplicationFactory factory, + Guid organizationId, + string name, + IEnumerable? users = null, + IEnumerable? groups = null, + string? externalId = null) + { + var collectionRepository = factory.GetService(); + var collection = new Collection + { + OrganizationId = organizationId, + Name = name, + Type = CollectionType.SharedCollection, + ExternalId = externalId + }; + + await collectionRepository.CreateAsync(collection, groups, users); + return collection; + } + /// /// Enables the Organization Data Ownership policy for the specified organization. /// @@ -170,6 +194,15 @@ public static class OrganizationTestHelpers await policyRepository.CreateAsync(policy); } + /// + /// Generates a unique random domain name for testing purposes. + /// + /// A domain string like "a1b2c3d4.com" + public static string GenerateRandomDomain() + { + return $"{Guid.NewGuid().ToString("N").Substring(0, 8)}.com"; + } + /// /// Creates a user account without a Master Password and adds them as a member to the specified organization. /// diff --git a/test/Api.IntegrationTest/Helpers/PerformanceTestHelpers.cs b/test/Api.IntegrationTest/Helpers/PerformanceTestHelpers.cs new file mode 100644 index 0000000000..ca26266dfa --- /dev/null +++ b/test/Api.IntegrationTest/Helpers/PerformanceTestHelpers.cs @@ -0,0 +1,32 @@ +using System.Net.Http.Headers; +using Bit.Api.IntegrationTest.Factories; + +namespace Bit.Api.IntegrationTest.Helpers; + +/// +/// Helper methods for performance tests to reduce code duplication. +/// +public static class PerformanceTestHelpers +{ + /// + /// Standard password hash used across performance tests. + /// + public const string StandardPasswordHash = "c55hlJ/cfdvTd4awTXUqow6X3cOQCfGwn11o3HblnPs="; + + /// + /// Authenticates an HttpClient with a bearer token for the specified user. + /// + /// The application factory to use for login. + /// The HttpClient to authenticate. + /// The user's email address. + /// The user's master password hash. Defaults to StandardPasswordHash. + public static async Task AuthenticateClientAsync( + SqlServerApiApplicationFactory factory, + HttpClient client, + string email, + string? masterPasswordHash = null) + { + var tokens = await factory.LoginAsync(email, masterPasswordHash ?? StandardPasswordHash); + client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", tokens.Token); + } +} diff --git a/test/Api.IntegrationTest/Helpers/ProviderTestHelpers.cs b/test/Api.IntegrationTest/Helpers/ProviderTestHelpers.cs new file mode 100644 index 0000000000..ab52bcd076 --- /dev/null +++ b/test/Api.IntegrationTest/Helpers/ProviderTestHelpers.cs @@ -0,0 +1,77 @@ +using Bit.Api.IntegrationTest.Factories; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Repositories; + +namespace Bit.Api.IntegrationTest.Helpers; + +public static class ProviderTestHelpers +{ + /// + /// Creates a provider and links it to an organization. + /// This does NOT create any provider users. + /// + /// The API application factory + /// The organization ID to link to the provider + /// The type of provider to create + /// The provider status (defaults to Created) + /// The created provider + public static async Task CreateProviderAndLinkToOrganizationAsync( + ApiApplicationFactory factory, + Guid organizationId, + ProviderType providerType, + ProviderStatusType providerStatus = ProviderStatusType.Created) + { + var providerRepository = factory.GetService(); + var providerOrganizationRepository = factory.GetService(); + + // Create the provider + var provider = await providerRepository.CreateAsync(new Provider + { + Name = $"Test {providerType} Provider", + BusinessName = $"Test {providerType} Provider Business", + BillingEmail = $"provider-{providerType.ToString().ToLower()}@example.com", + Type = providerType, + Status = providerStatus, + Enabled = true + }); + + // Link the provider to the organization + await providerOrganizationRepository.CreateAsync(new ProviderOrganization + { + ProviderId = provider.Id, + OrganizationId = organizationId, + Key = "test-provider-key" + }); + + return provider; + } + + /// + /// Creates a providerUser for a provider. + /// + public static async Task CreateProviderUserAsync( + ApiApplicationFactory factory, + Guid providerId, + string userEmail, + ProviderUserType providerUserType) + { + var userRepository = factory.GetService(); + var user = await userRepository.GetByEmailAsync(userEmail); + if (user is null) + { + throw new Exception("No user found in test setup."); + } + + var providerUserRepository = factory.GetService(); + return await providerUserRepository.CreateAsync(new ProviderUser + { + ProviderId = providerId, + Status = ProviderUserStatusType.Confirmed, + UserId = user.Id, + Key = Guid.NewGuid().ToString(), + Type = providerUserType + }); + } +} diff --git a/test/Api.IntegrationTest/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs b/test/Api.IntegrationTest/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs index bf27d7f0d1..eddffb6b36 100644 --- a/test/Api.IntegrationTest/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs +++ b/test/Api.IntegrationTest/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs @@ -3,19 +3,28 @@ using System.Net; using Bit.Api.IntegrationTest.Factories; using Bit.Api.IntegrationTest.Helpers; using Bit.Api.KeyManagement.Models.Requests; +using Bit.Api.KeyManagement.Models.Responses; using Bit.Api.Tools.Models.Request; using Bit.Api.Vault.Models; using Bit.Api.Vault.Models.Request; +using Bit.Core; +using Bit.Core.AdminConsole.Entities; using Bit.Core.Auth.Entities; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models.Api.Request.Accounts; using Bit.Core.Billing.Enums; using Bit.Core.Entities; using Bit.Core.Enums; +using Bit.Core.KeyManagement.Entities; +using Bit.Core.KeyManagement.Enums; +using Bit.Core.KeyManagement.Models.Api.Request; +using Bit.Core.KeyManagement.Repositories; using Bit.Core.Repositories; +using Bit.Core.Services; using Bit.Core.Vault.Enums; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.AspNetCore.Identity; +using NSubstitute; using Xunit; namespace Bit.Api.IntegrationTest.KeyManagement.Controllers; @@ -24,6 +33,8 @@ public class AccountsKeyManagementControllerTests : IClassFixture _passwordHasher; private readonly IOrganizationRepository _organizationRepository; + private readonly IUserSignatureKeyPairRepository _userSignatureKeyPairRepository; private string _ownerEmail = null!; public AccountsKeyManagementControllerTests(ApiApplicationFactory factory) { _factory = factory; - _factory.UpdateConfiguration("globalSettings:launchDarkly:flagValues:pm-12241-private-key-regeneration", - "true"); + _factory.SubstituteService(featureService => + { + featureService.IsEnabled(FeatureFlagKeys.PrivateKeyRegeneration, Arg.Any()) + .Returns(true); + }); _client = factory.CreateClient(); _loginHelper = new LoginHelper(_factory, _client); _userRepository = _factory.GetService(); @@ -49,6 +64,7 @@ public class AccountsKeyManagementControllerTests : IClassFixture(); _passwordHasher = _factory.GetService>(); _organizationRepository = _factory.GetService(); + _userSignatureKeyPairRepository = _factory.GetService(); } public async Task InitializeAsync() @@ -69,8 +85,11 @@ public class AccountsKeyManagementControllerTests : IClassFixture(featureService => + { + featureService.IsEnabled(FeatureFlagKeys.PrivateKeyRegeneration, Arg.Any()) + .Returns(false); + }); var localClient = localFactory.CreateClient(); var localEmail = $"integration-test{Guid.NewGuid()}@bitwarden.com"; var localLoginHelper = new LoginHelper(localFactory, localClient); @@ -200,6 +219,7 @@ public class AccountsKeyManagementControllerTests : IClassFixture(); + + Assert.NotNull(result); + Assert.Equal(organization.Name, result.OrganizationName); + } + + private async Task<(string, Organization)> SetupKeyConnectorTestAsync(OrganizationUserStatusType userStatusType, + string organizationSsoIdentifier = "test-sso-identifier") { var (organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, PlanType.EnterpriseAnnually, _ownerEmail, passwordManagerSeats: 10, @@ -289,69 +659,8 @@ public class AccountsKeyManagementControllerTests : IClassFixture, IAsyncLifetime +{ + private readonly string _mockEncryptedString = + "2.3Uk+WNBIoU5xzmVFNcoWzz==|1MsPIYuRfdOHfu/0uY6H2Q==|/98sp4wb6pHP1VTZ9JcNCYgQjEUMFPlqJgCwRk1YXKg="; + + private readonly HttpClient _client; + private readonly ApiApplicationFactory _factory; + private readonly ISecretRepository _secretRepository; + private readonly ISecretVersionRepository _secretVersionRepository; + private readonly IAccessPolicyRepository _accessPolicyRepository; + private readonly LoginHelper _loginHelper; + + private string _email = null!; + private SecretsManagerOrganizationHelper _organizationHelper = null!; + + public SecretVersionsControllerTests(ApiApplicationFactory factory) + { + _factory = factory; + _client = _factory.CreateClient(); + _secretRepository = _factory.GetService(); + _secretVersionRepository = _factory.GetService(); + _accessPolicyRepository = _factory.GetService(); + _loginHelper = new LoginHelper(_factory, _client); + } + + public async Task InitializeAsync() + { + _email = $"integration-test{Guid.NewGuid()}@bitwarden.com"; + await _factory.LoginWithNewAccount(_email); + _organizationHelper = new SecretsManagerOrganizationHelper(_factory, _email); + } + + public Task DisposeAsync() + { + _client.Dispose(); + return Task.CompletedTask; + } + + [Theory] + [InlineData(false, false, false)] + [InlineData(false, false, true)] + [InlineData(false, true, false)] + [InlineData(false, true, true)] + [InlineData(true, false, false)] + [InlineData(true, false, true)] + [InlineData(true, true, false)] + public async Task GetVersionsBySecretId_SmAccessDenied_NotFound(bool useSecrets, bool accessSecrets, bool organizationEnabled) + { + var (org, _) = await _organizationHelper.Initialize(useSecrets, accessSecrets, organizationEnabled); + await _loginHelper.LoginAsync(_email); + + var secret = await _secretRepository.CreateAsync(new Secret + { + OrganizationId = org.Id, + Key = _mockEncryptedString, + Value = _mockEncryptedString, + Note = _mockEncryptedString + }); + + var response = await _client.GetAsync($"/secrets/{secret.Id}/versions"); + Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); + } + + [Theory] + [InlineData(PermissionType.RunAsAdmin)] + [InlineData(PermissionType.RunAsUserWithPermission)] + public async Task GetVersionsBySecretId_Success(PermissionType permissionType) + { + var (org, _) = await _organizationHelper.Initialize(true, true, true); + await _loginHelper.LoginAsync(_email); + + var secret = await _secretRepository.CreateAsync(new Secret + { + OrganizationId = org.Id, + Key = _mockEncryptedString, + Value = _mockEncryptedString, + Note = _mockEncryptedString + }); + + // Create some versions + var version1 = await _secretVersionRepository.CreateAsync(new SecretVersion + { + SecretId = secret.Id, + Value = _mockEncryptedString, + VersionDate = DateTime.UtcNow.AddDays(-2) + }); + + var version2 = await _secretVersionRepository.CreateAsync(new SecretVersion + { + SecretId = secret.Id, + Value = _mockEncryptedString, + VersionDate = DateTime.UtcNow.AddDays(-1) + }); + + if (permissionType == PermissionType.RunAsUserWithPermission) + { + var (email, orgUser) = await _organizationHelper.CreateNewUser(OrganizationUserType.User, true); + await _loginHelper.LoginAsync(email); + + var accessPolicies = new List + { + new UserSecretAccessPolicy + { + GrantedSecretId = secret.Id, + OrganizationUserId = orgUser.Id, + Read = true, + Write = true + } + }; + await _accessPolicyRepository.CreateManyAsync(accessPolicies); + } + + var response = await _client.GetAsync($"/secrets/{secret.Id}/versions"); + response.EnsureSuccessStatusCode(); + + var result = await response.Content.ReadFromJsonAsync>(); + + Assert.NotNull(result); + Assert.Equal(2, result.Data.Count()); + } + + [Fact] + public async Task GetVersionById_Success() + { + var (org, _) = await _organizationHelper.Initialize(true, true, true); + await _loginHelper.LoginAsync(_email); + + var secret = await _secretRepository.CreateAsync(new Secret + { + OrganizationId = org.Id, + Key = _mockEncryptedString, + Value = _mockEncryptedString, + Note = _mockEncryptedString + }); + + var version = await _secretVersionRepository.CreateAsync(new SecretVersion + { + SecretId = secret.Id, + Value = _mockEncryptedString, + VersionDate = DateTime.UtcNow + }); + + var response = await _client.GetAsync($"/secret-versions/{version.Id}"); + response.EnsureSuccessStatusCode(); + + var result = await response.Content.ReadFromJsonAsync(); + + Assert.NotNull(result); + Assert.Equal(version.Id, result.Id); + Assert.Equal(secret.Id, result.SecretId); + } + + [Fact] + public async Task RestoreVersion_Success() + { + var (org, _) = await _organizationHelper.Initialize(true, true, true); + await _loginHelper.LoginAsync(_email); + + var secret = await _secretRepository.CreateAsync(new Secret + { + OrganizationId = org.Id, + Key = _mockEncryptedString, + Value = "OriginalValue", + Note = _mockEncryptedString + }); + + var version = await _secretVersionRepository.CreateAsync(new SecretVersion + { + SecretId = secret.Id, + Value = "OldValue", + VersionDate = DateTime.UtcNow.AddDays(-1) + }); + + var request = new RestoreSecretVersionRequestModel + { + VersionId = version.Id + }; + + var response = await _client.PutAsJsonAsync($"/secrets/{secret.Id}/versions/restore", request); + response.EnsureSuccessStatusCode(); + + var result = await response.Content.ReadFromJsonAsync(); + + Assert.NotNull(result); + Assert.Equal("OldValue", result.Value); + } + + [Fact] + public async Task BulkDelete_Success() + { + var (org, _) = await _organizationHelper.Initialize(true, true, true); + await _loginHelper.LoginAsync(_email); + + var secret = await _secretRepository.CreateAsync(new Secret + { + OrganizationId = org.Id, + Key = _mockEncryptedString, + Value = _mockEncryptedString, + Note = _mockEncryptedString + }); + + var version1 = await _secretVersionRepository.CreateAsync(new SecretVersion + { + SecretId = secret.Id, + Value = _mockEncryptedString, + VersionDate = DateTime.UtcNow.AddDays(-2) + }); + + var version2 = await _secretVersionRepository.CreateAsync(new SecretVersion + { + SecretId = secret.Id, + Value = _mockEncryptedString, + VersionDate = DateTime.UtcNow.AddDays(-1) + }); + + var ids = new List { version1.Id, version2.Id }; + + var response = await _client.PostAsJsonAsync("/secret-versions/delete", ids); + response.EnsureSuccessStatusCode(); + + var versions = await _secretVersionRepository.GetManyBySecretIdAsync(secret.Id); + Assert.Empty(versions); + } + + [Fact] + public async Task GetVersionsBySecretId_ReturnsOrderedByVersionDate() + { + var (org, _) = await _organizationHelper.Initialize(true, true, true); + await _loginHelper.LoginAsync(_email); + + var secret = await _secretRepository.CreateAsync(new Secret + { + OrganizationId = org.Id, + Key = _mockEncryptedString, + Value = _mockEncryptedString, + Note = _mockEncryptedString + }); + + // Create versions in random order + await _secretVersionRepository.CreateAsync(new SecretVersion + { + SecretId = secret.Id, + Value = "Version2", + VersionDate = DateTime.UtcNow.AddDays(-1) + }); + + await _secretVersionRepository.CreateAsync(new SecretVersion + { + SecretId = secret.Id, + Value = "Version3", + VersionDate = DateTime.UtcNow + }); + + await _secretVersionRepository.CreateAsync(new SecretVersion + { + SecretId = secret.Id, + Value = "Version1", + VersionDate = DateTime.UtcNow.AddDays(-2) + }); + + var response = await _client.GetAsync($"/secrets/{secret.Id}/versions"); + response.EnsureSuccessStatusCode(); + + var result = await response.Content.ReadFromJsonAsync>(); + + Assert.NotNull(result); + Assert.Equal(3, result.Data.Count()); + + var versions = result.Data.ToList(); + // Should be ordered by VersionDate descending (newest first) + Assert.Equal("Version3", versions[0].Value); + Assert.Equal("Version2", versions[1].Value); + Assert.Equal("Version1", versions[2].Value); + } +} diff --git a/test/Api.Test/AdminConsole/Authorization/RecoverAccountAuthorizationHandlerTests.cs b/test/Api.Test/AdminConsole/Authorization/RecoverAccountAuthorizationHandlerTests.cs new file mode 100644 index 0000000000..92efb641f1 --- /dev/null +++ b/test/Api.Test/AdminConsole/Authorization/RecoverAccountAuthorizationHandlerTests.cs @@ -0,0 +1,296 @@ +using System.Security.Claims; +using Bit.Api.AdminConsole.Authorization; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Context; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Models.Data; +using Bit.Core.Test.AutoFixture.OrganizationUserFixtures; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.AspNetCore.Authorization; +using NSubstitute; +using Xunit; + +namespace Bit.Api.Test.AdminConsole.Authorization; + +[SutProviderCustomize] +public class RecoverAccountAuthorizationHandlerTests +{ + [Theory, BitAutoData] + public async Task HandleRequirementAsync_CurrentUserIsProvider_TargetUserNotProvider_Authorized( + SutProvider sutProvider, + [OrganizationUser] OrganizationUser targetOrganizationUser, + ClaimsPrincipal claimsPrincipal) + { + // Arrange + var context = new AuthorizationHandlerContext( + [new RecoverAccountAuthorizationRequirement()], + claimsPrincipal, + targetOrganizationUser); + + MockOrganizationClaims(sutProvider, claimsPrincipal, targetOrganizationUser, null); + MockCurrentUserIsProvider(sutProvider, claimsPrincipal, targetOrganizationUser); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + Assert.True(context.HasSucceeded); + } + + [Theory, BitAutoData] + public async Task HandleRequirementAsync_CurrentUserIsNotMemberOrProvider_NotAuthorized( + SutProvider sutProvider, + [OrganizationUser] OrganizationUser targetOrganizationUser, + ClaimsPrincipal claimsPrincipal) + { + // Arrange + var context = new AuthorizationHandlerContext( + [new RecoverAccountAuthorizationRequirement()], + claimsPrincipal, + targetOrganizationUser); + + MockOrganizationClaims(sutProvider, claimsPrincipal, targetOrganizationUser, null); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + AssertFailed(context, RecoverAccountAuthorizationHandler.FailureReason); + } + + // Pairing of CurrentContextOrganization (current user permissions) and target user role + // Read this as: a ___ can recover the account for a ___ + public static IEnumerable AuthorizedRoleCombinations => new object[][] + { + [new CurrentContextOrganization { Type = OrganizationUserType.Owner }, OrganizationUserType.Owner], + [new CurrentContextOrganization { Type = OrganizationUserType.Owner }, OrganizationUserType.Admin], + [new CurrentContextOrganization { Type = OrganizationUserType.Owner }, OrganizationUserType.Custom], + [new CurrentContextOrganization { Type = OrganizationUserType.Owner }, OrganizationUserType.User], + [new CurrentContextOrganization { Type = OrganizationUserType.Admin }, OrganizationUserType.Admin], + [new CurrentContextOrganization { Type = OrganizationUserType.Admin }, OrganizationUserType.Custom], + [new CurrentContextOrganization { Type = OrganizationUserType.Admin }, OrganizationUserType.User], + [new CurrentContextOrganization { Type = OrganizationUserType.Custom, Permissions = new Permissions { ManageResetPassword = true}}, OrganizationUserType.Custom], + [new CurrentContextOrganization { Type = OrganizationUserType.Custom, Permissions = new Permissions { ManageResetPassword = true}}, OrganizationUserType.User], + }; + + [Theory, BitMemberAutoData(nameof(AuthorizedRoleCombinations))] + public async Task AuthorizeMemberAsync_RecoverEqualOrLesserRoles_TargetUserNotProvider_Authorized( + CurrentContextOrganization currentContextOrganization, + OrganizationUserType targetOrganizationUserType, + SutProvider sutProvider, + [OrganizationUser] OrganizationUser targetOrganizationUser, + ClaimsPrincipal claimsPrincipal) + { + // Arrange + targetOrganizationUser.Type = targetOrganizationUserType; + currentContextOrganization.Id = targetOrganizationUser.OrganizationId; + + var context = new AuthorizationHandlerContext( + [new RecoverAccountAuthorizationRequirement()], + claimsPrincipal, + targetOrganizationUser); + + MockOrganizationClaims(sutProvider, claimsPrincipal, targetOrganizationUser, currentContextOrganization); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + Assert.True(context.HasSucceeded); + } + + // Pairing of CurrentContextOrganization (current user permissions) and target user role + // Read this as: a ___ cannot recover the account for a ___ + public static IEnumerable UnauthorizedRoleCombinations => new object[][] + { + // These roles should fail because you cannot recover a greater role + [new CurrentContextOrganization { Type = OrganizationUserType.Admin }, OrganizationUserType.Owner], + [new CurrentContextOrganization { Type = OrganizationUserType.Custom, Permissions = new Permissions { ManageResetPassword = true}}, OrganizationUserType.Owner], + [new CurrentContextOrganization { Type = OrganizationUserType.Custom, Permissions = new Permissions { ManageResetPassword = true} }, OrganizationUserType.Admin], + + // These roles are never authorized to recover any account + [new CurrentContextOrganization { Type = OrganizationUserType.User }, OrganizationUserType.Owner], + [new CurrentContextOrganization { Type = OrganizationUserType.User }, OrganizationUserType.Admin], + [new CurrentContextOrganization { Type = OrganizationUserType.User }, OrganizationUserType.Custom], + [new CurrentContextOrganization { Type = OrganizationUserType.User }, OrganizationUserType.User], + [new CurrentContextOrganization { Type = OrganizationUserType.Custom }, OrganizationUserType.Owner], + [new CurrentContextOrganization { Type = OrganizationUserType.Custom }, OrganizationUserType.Admin], + [new CurrentContextOrganization { Type = OrganizationUserType.Custom }, OrganizationUserType.Custom], + [new CurrentContextOrganization { Type = OrganizationUserType.Custom }, OrganizationUserType.User], + }; + + [Theory, BitMemberAutoData(nameof(UnauthorizedRoleCombinations))] + public async Task AuthorizeMemberAsync_InvalidRoles_TargetUserNotProvider_Unauthorized( + CurrentContextOrganization currentContextOrganization, + OrganizationUserType targetOrganizationUserType, + SutProvider sutProvider, + [OrganizationUser] OrganizationUser targetOrganizationUser, + ClaimsPrincipal claimsPrincipal) + { + // Arrange + targetOrganizationUser.Type = targetOrganizationUserType; + currentContextOrganization.Id = targetOrganizationUser.OrganizationId; + + var context = new AuthorizationHandlerContext( + [new RecoverAccountAuthorizationRequirement()], + claimsPrincipal, + targetOrganizationUser); + + MockOrganizationClaims(sutProvider, claimsPrincipal, targetOrganizationUser, currentContextOrganization); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + AssertFailed(context, RecoverAccountAuthorizationHandler.FailureReason); + } + + [Theory, BitAutoData] + public async Task HandleRequirementAsync_TargetUserIdIsNull_DoesNotBlock( + SutProvider sutProvider, + OrganizationUser targetOrganizationUser, + ClaimsPrincipal claimsPrincipal) + { + // Arrange + targetOrganizationUser.UserId = null; + MockCurrentUserIsOwner(sutProvider, claimsPrincipal, targetOrganizationUser); + + var context = new AuthorizationHandlerContext( + [new RecoverAccountAuthorizationRequirement()], + claimsPrincipal, + targetOrganizationUser); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + Assert.True(context.HasSucceeded); + // This should shortcut the provider escalation check + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .GetManyByUserAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task HandleRequirementAsync_CurrentUserIsMemberOfAllTargetUserProviders_DoesNotBlock( + SutProvider sutProvider, + [OrganizationUser] OrganizationUser targetOrganizationUser, + ClaimsPrincipal claimsPrincipal, + Guid providerId1, + Guid providerId2) + { + // Arrange + var targetUserProviders = new List + { + new() { ProviderId = providerId1, UserId = targetOrganizationUser.UserId }, + new() { ProviderId = providerId2, UserId = targetOrganizationUser.UserId } + }; + + var context = new AuthorizationHandlerContext( + [new RecoverAccountAuthorizationRequirement()], + claimsPrincipal, + targetOrganizationUser); + + MockCurrentUserIsProvider(sutProvider, claimsPrincipal, targetOrganizationUser); + + sutProvider.GetDependency() + .GetManyByUserAsync(targetOrganizationUser.UserId!.Value) + .Returns(targetUserProviders); + + sutProvider.GetDependency() + .ProviderUser(providerId1) + .Returns(true); + + sutProvider.GetDependency() + .ProviderUser(providerId2) + .Returns(true); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + Assert.True(context.HasSucceeded); + } + + [Theory, BitAutoData] + public async Task HandleRequirementAsync_CurrentUserMissingProviderMembership_Blocks( + SutProvider sutProvider, + [OrganizationUser] OrganizationUser targetOrganizationUser, + ClaimsPrincipal claimsPrincipal, + Guid providerId1, + Guid providerId2) + { + // Arrange + var targetUserProviders = new List + { + new() { ProviderId = providerId1, UserId = targetOrganizationUser.UserId }, + new() { ProviderId = providerId2, UserId = targetOrganizationUser.UserId } + }; + + var context = new AuthorizationHandlerContext( + [new RecoverAccountAuthorizationRequirement()], + claimsPrincipal, + targetOrganizationUser); + + MockCurrentUserIsOwner(sutProvider, claimsPrincipal, targetOrganizationUser); + + sutProvider.GetDependency() + .GetManyByUserAsync(targetOrganizationUser.UserId!.Value) + .Returns(targetUserProviders); + + sutProvider.GetDependency() + .ProviderUser(providerId1) + .Returns(true); + + // Not a member of this provider + sutProvider.GetDependency() + .ProviderUser(providerId2) + .Returns(false); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + AssertFailed(context, RecoverAccountAuthorizationHandler.ProviderFailureReason); + } + + private static void MockOrganizationClaims(SutProvider sutProvider, + ClaimsPrincipal currentUser, OrganizationUser targetOrganizationUser, + CurrentContextOrganization? currentContextOrganization) + { + sutProvider.GetDependency() + .GetOrganizationClaims(currentUser, targetOrganizationUser.OrganizationId) + .Returns(currentContextOrganization); + } + + private static void MockCurrentUserIsProvider(SutProvider sutProvider, + ClaimsPrincipal currentUser, OrganizationUser targetOrganizationUser) + { + sutProvider.GetDependency() + .IsProviderUserForOrganization(currentUser, targetOrganizationUser.OrganizationId) + .Returns(true); + } + + private static void MockCurrentUserIsOwner(SutProvider sutProvider, + ClaimsPrincipal currentUser, OrganizationUser targetOrganizationUser) + { + var currentContextOrganization = new CurrentContextOrganization + { + Id = targetOrganizationUser.OrganizationId, + Type = OrganizationUserType.Owner + }; + + sutProvider.GetDependency() + .GetOrganizationClaims(currentUser, targetOrganizationUser.OrganizationId) + .Returns(currentContextOrganization); + } + + private static void AssertFailed(AuthorizationHandlerContext context, string expectedMessage) + { + Assert.True(context.HasFailed); + var failureReason = Assert.Single(context.FailureReasons); + Assert.Equal(expectedMessage, failureReason.Message); + } +} diff --git a/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationControllerTests.cs b/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationControllerTests.cs deleted file mode 100644 index 1dd0e86f39..0000000000 --- a/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationControllerTests.cs +++ /dev/null @@ -1,255 +0,0 @@ -using Bit.Api.AdminConsole.Controllers; -using Bit.Api.AdminConsole.Models.Request.Organizations; -using Bit.Api.AdminConsole.Models.Response.Organizations; -using Bit.Core.AdminConsole.Entities; -using Bit.Core.Context; -using Bit.Core.Enums; -using Bit.Core.Exceptions; -using Bit.Core.Repositories; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using Microsoft.AspNetCore.Mvc; -using NSubstitute; -using NSubstitute.ReturnsExtensions; -using Xunit; - -namespace Bit.Api.Test.AdminConsole.Controllers; - -[ControllerCustomize(typeof(OrganizationIntegrationController))] -[SutProviderCustomize] -public class OrganizationIntegrationControllerTests -{ - private OrganizationIntegrationRequestModel _webhookRequestModel = new OrganizationIntegrationRequestModel() - { - Configuration = null, - Type = IntegrationType.Webhook - }; - - [Theory, BitAutoData] - public async Task GetAsync_UserIsNotOrganizationAdmin_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(false); - - await Assert.ThrowsAsync(() => sutProvider.Sut.GetAsync(organizationId)); - } - - [Theory, BitAutoData] - public async Task GetAsync_IntegrationsExist_ReturnsIntegrations( - SutProvider sutProvider, - Guid organizationId, - List integrations) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetManyByOrganizationAsync(organizationId) - .Returns(integrations); - - var result = await sutProvider.Sut.GetAsync(organizationId); - - await sutProvider.GetDependency().Received(1) - .GetManyByOrganizationAsync(organizationId); - - Assert.Equal(integrations.Count, result.Count); - Assert.All(result, r => Assert.IsType(r)); - } - - [Theory, BitAutoData] - public async Task GetAsync_NoIntegrations_ReturnsEmptyList( - SutProvider sutProvider, - Guid organizationId) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetManyByOrganizationAsync(organizationId) - .Returns([]); - - var result = await sutProvider.Sut.GetAsync(organizationId); - - Assert.Empty(result); - } - - [Theory, BitAutoData] - public async Task CreateAsync_Webhook_AllParamsProvided_Succeeds( - SutProvider sutProvider, - Guid organizationId) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(callInfo => callInfo.Arg()); - var response = await sutProvider.Sut.CreateAsync(organizationId, _webhookRequestModel); - - await sutProvider.GetDependency().Received(1) - .CreateAsync(Arg.Any()); - Assert.IsType(response); - Assert.Equal(IntegrationType.Webhook, response.Type); - } - - [Theory, BitAutoData] - public async Task CreateAsync_UserIsNotOrganizationAdmin_ThrowsNotFound(SutProvider sutProvider, Guid organizationId) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(false); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(organizationId, _webhookRequestModel)); - } - - [Theory, BitAutoData] - public async Task DeleteAsync_AllParamsProvided_Succeeds( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration) - { - organizationIntegration.OrganizationId = organizationId; - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - - await sutProvider.Sut.DeleteAsync(organizationId, organizationIntegration.Id); - - await sutProvider.GetDependency().Received(1) - .GetByIdAsync(organizationIntegration.Id); - await sutProvider.GetDependency().Received(1) - .DeleteAsync(organizationIntegration); - } - - [Theory, BitAutoData] - public async Task DeleteAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration) - { - organizationIntegration.OrganizationId = Guid.NewGuid(); - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.DeleteAsync(organizationId, Guid.Empty)); - } - - [Theory, BitAutoData] - public async Task DeleteAsync_IntegrationDoesNotExist_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.DeleteAsync(organizationId, Guid.Empty)); - } - - [Theory, BitAutoData] - public async Task DeleteAsync_UserIsNotOrganizationAdmin_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(false); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.DeleteAsync(organizationId, Guid.Empty)); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_AllParamsProvided_Succeeds( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.Webhook; - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - - var response = await sutProvider.Sut.UpdateAsync(organizationId, organizationIntegration.Id, _webhookRequestModel); - - await sutProvider.GetDependency().Received(1) - .GetByIdAsync(organizationIntegration.Id); - await sutProvider.GetDependency().Received(1) - .ReplaceAsync(organizationIntegration); - Assert.IsType(response); - Assert.Equal(IntegrationType.Webhook, response.Type); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration) - { - organizationIntegration.OrganizationId = Guid.NewGuid(); - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAsync(organizationId, Guid.Empty, _webhookRequestModel)); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_IntegrationDoesNotExist_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAsync(organizationId, Guid.Empty, _webhookRequestModel)); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_UserIsNotOrganizationAdmin_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(false); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAsync(organizationId, Guid.Empty, _webhookRequestModel)); - } -} diff --git a/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationsConfigurationControllerTests.cs b/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationsConfigurationControllerTests.cs deleted file mode 100644 index 4ccfa70308..0000000000 --- a/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationsConfigurationControllerTests.cs +++ /dev/null @@ -1,832 +0,0 @@ -using System.Text.Json; -using Bit.Api.AdminConsole.Controllers; -using Bit.Api.AdminConsole.Models.Request.Organizations; -using Bit.Api.AdminConsole.Models.Response.Organizations; -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; -using Bit.Core.Context; -using Bit.Core.Enums; -using Bit.Core.Exceptions; -using Bit.Core.Repositories; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using Microsoft.AspNetCore.Mvc; -using NSubstitute; -using NSubstitute.ReturnsExtensions; -using Xunit; - -namespace Bit.Api.Test.AdminConsole.Controllers; - -[ControllerCustomize(typeof(OrganizationIntegrationConfigurationController))] -[SutProviderCustomize] -public class OrganizationIntegrationsConfigurationControllerTests -{ - [Theory, BitAutoData] - public async Task DeleteAsync_AllParamsProvided_Succeeds( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegrationConfiguration.OrganizationIntegrationId = organizationIntegration.Id; - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - - await sutProvider.Sut.DeleteAsync(organizationId, organizationIntegration.Id, organizationIntegrationConfiguration.Id); - - await sutProvider.GetDependency().Received(1) - .GetByIdAsync(organizationIntegration.Id); - await sutProvider.GetDependency().Received(1) - .GetByIdAsync(organizationIntegrationConfiguration.Id); - await sutProvider.GetDependency().Received(1) - .DeleteAsync(organizationIntegrationConfiguration); - } - - [Theory, BitAutoData] - public async Task DeleteAsync_IntegrationConfigurationDoesNotExist_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration) - { - organizationIntegration.OrganizationId = organizationId; - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.DeleteAsync(organizationId, Guid.Empty, Guid.Empty)); - } - - [Theory, BitAutoData] - public async Task DeleteAsync_IntegrationDoesNotExist_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.DeleteAsync(organizationId, Guid.Empty, Guid.Empty)); - } - - [Theory, BitAutoData] - public async Task DeleteAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.DeleteAsync(organizationId, organizationIntegration.Id, Guid.Empty)); - } - - [Theory, BitAutoData] - public async Task DeleteAsync_IntegrationConfigDoesNotBelongToIntegration_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegrationConfiguration.OrganizationIntegrationId = Guid.Empty; - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.DeleteAsync(organizationId, organizationIntegration.Id, Guid.Empty)); - } - - [Theory, BitAutoData] - public async Task DeleteAsync_UserIsNotOrganizationAdmin_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(false); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.DeleteAsync(organizationId, Guid.Empty, Guid.Empty)); - } - - [Theory, BitAutoData] - public async Task GetAsync_ConfigurationsExist_Succeeds( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - List organizationIntegrationConfigurations) - { - organizationIntegration.OrganizationId = organizationId; - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetManyByIntegrationAsync(Arg.Any()) - .Returns(organizationIntegrationConfigurations); - - var result = await sutProvider.Sut.GetAsync(organizationId, organizationIntegration.Id); - Assert.NotNull(result); - Assert.Equal(organizationIntegrationConfigurations.Count, result.Count); - Assert.All(result, r => Assert.IsType(r)); - - await sutProvider.GetDependency().Received(1) - .GetByIdAsync(organizationIntegration.Id); - await sutProvider.GetDependency().Received(1) - .GetManyByIntegrationAsync(organizationIntegration.Id); - } - - [Theory, BitAutoData] - public async Task GetAsync_NoConfigurationsExist_ReturnsEmptyList( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration) - { - organizationIntegration.OrganizationId = organizationId; - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetManyByIntegrationAsync(Arg.Any()) - .Returns([]); - - var result = await sutProvider.Sut.GetAsync(organizationId, organizationIntegration.Id); - Assert.NotNull(result); - Assert.Empty(result); - - await sutProvider.GetDependency().Received(1) - .GetByIdAsync(organizationIntegration.Id); - await sutProvider.GetDependency().Received(1) - .GetManyByIntegrationAsync(organizationIntegration.Id); - } - - // [Theory, BitAutoData] - // public async Task GetAsync_IntegrationConfigurationDoesNotExist_ThrowsNotFound( - // SutProvider sutProvider, - // Guid organizationId, - // OrganizationIntegration organizationIntegration) - // { - // organizationIntegration.OrganizationId = organizationId; - // sutProvider.Sut.Url = Substitute.For(); - // sutProvider.GetDependency() - // .OrganizationOwner(organizationId) - // .Returns(true); - // sutProvider.GetDependency() - // .GetByIdAsync(Arg.Any()) - // .Returns(organizationIntegration); - // sutProvider.GetDependency() - // .GetByIdAsync(Arg.Any()) - // .ReturnsNull(); - // - // await Assert.ThrowsAsync(async () => await sutProvider.Sut.GetAsync(organizationId, Guid.Empty, Guid.Empty)); - // } - // - [Theory, BitAutoData] - public async Task GetAsync_IntegrationDoesNotExist_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.GetAsync(organizationId, Guid.NewGuid())); - } - - [Theory, BitAutoData] - public async Task GetAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.GetAsync(organizationId, organizationIntegration.Id)); - } - - [Theory, BitAutoData] - public async Task GetAsync_UserIsNotOrganizationAdmin_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(false); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.GetAsync(organizationId, Guid.NewGuid())); - } - - [Theory, BitAutoData] - public async Task PostAsync_AllParamsProvided_Slack_Succeeds( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.Slack; - var slackConfig = new SlackIntegrationConfiguration(ChannelId: "C123456"); - model.Configuration = JsonSerializer.Serialize(slackConfig); - model.Template = "Template String"; - model.Filters = null; - - var expected = new OrganizationIntegrationConfigurationResponseModel(organizationIntegrationConfiguration); - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - var requestAction = await sutProvider.Sut.CreateAsync(organizationId, organizationIntegration.Id, model); - - await sutProvider.GetDependency().Received(1) - .CreateAsync(Arg.Any()); - Assert.IsType(requestAction); - Assert.Equal(expected.Id, requestAction.Id); - Assert.Equal(expected.Configuration, requestAction.Configuration); - Assert.Equal(expected.EventType, requestAction.EventType); - Assert.Equal(expected.Template, requestAction.Template); - } - - [Theory, BitAutoData] - public async Task PostAsync_AllParamsProvided_Webhook_Succeeds( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.Webhook; - var webhookConfig = new WebhookIntegrationConfiguration(Uri: new Uri("https://localhost"), Scheme: "Bearer", Token: "AUTH-TOKEN"); - model.Configuration = JsonSerializer.Serialize(webhookConfig); - model.Template = "Template String"; - model.Filters = null; - - var expected = new OrganizationIntegrationConfigurationResponseModel(organizationIntegrationConfiguration); - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - var requestAction = await sutProvider.Sut.CreateAsync(organizationId, organizationIntegration.Id, model); - - await sutProvider.GetDependency().Received(1) - .CreateAsync(Arg.Any()); - Assert.IsType(requestAction); - Assert.Equal(expected.Id, requestAction.Id); - Assert.Equal(expected.Configuration, requestAction.Configuration); - Assert.Equal(expected.EventType, requestAction.EventType); - Assert.Equal(expected.Template, requestAction.Template); - } - - [Theory, BitAutoData] - public async Task PostAsync_OnlyUrlProvided_Webhook_Succeeds( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.Webhook; - var webhookConfig = new WebhookIntegrationConfiguration(Uri: new Uri("https://localhost")); - model.Configuration = JsonSerializer.Serialize(webhookConfig); - model.Template = "Template String"; - model.Filters = null; - - var expected = new OrganizationIntegrationConfigurationResponseModel(organizationIntegrationConfiguration); - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - var requestAction = await sutProvider.Sut.CreateAsync(organizationId, organizationIntegration.Id, model); - - await sutProvider.GetDependency().Received(1) - .CreateAsync(Arg.Any()); - Assert.IsType(requestAction); - Assert.Equal(expected.Id, requestAction.Id); - Assert.Equal(expected.Configuration, requestAction.Configuration); - Assert.Equal(expected.EventType, requestAction.EventType); - Assert.Equal(expected.Template, requestAction.Template); - } - - [Theory, BitAutoData] - public async Task PostAsync_IntegrationTypeCloudBillingSync_ThrowsBadRequestException( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.CloudBillingSync; - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync( - organizationId, - organizationIntegration.Id, - model)); - } - - [Theory, BitAutoData] - public async Task PostAsync_IntegrationTypeScim_ThrowsBadRequestException( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.Scim; - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync( - organizationId, - organizationIntegration.Id, - model)); - } - - [Theory, BitAutoData] - public async Task PostAsync_IntegrationDoesNotExist_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync( - organizationId, - Guid.Empty, - new OrganizationIntegrationConfigurationRequestModel())); - } - - [Theory, BitAutoData] - public async Task PostAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync( - organizationId, - organizationIntegration.Id, - new OrganizationIntegrationConfigurationRequestModel())); - } - - [Theory, BitAutoData] - public async Task PostAsync_InvalidConfiguration_ThrowsBadRequestException( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.Webhook; - model.Configuration = null; - model.Template = "Template String"; - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync( - organizationId, - organizationIntegration.Id, - model)); - } - - [Theory, BitAutoData] - public async Task PostAsync_InvalidTemplate_ThrowsBadRequestException( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.Webhook; - var webhookConfig = new WebhookIntegrationConfiguration(Uri: new Uri("https://localhost"), Scheme: "Bearer", Token: "AUTH-TOKEN"); - model.Configuration = JsonSerializer.Serialize(webhookConfig); - model.Template = null; - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync( - organizationId, - organizationIntegration.Id, - model)); - } - - [Theory, BitAutoData] - public async Task PostAsync_UserIsNotOrganizationAdmin_ThrowsNotFound(SutProvider sutProvider, Guid organizationId) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(false); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(organizationId, Guid.Empty, new OrganizationIntegrationConfigurationRequestModel())); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_AllParamsProvided_Slack_Succeeds( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegrationConfiguration.OrganizationIntegrationId = organizationIntegration.Id; - organizationIntegration.Type = IntegrationType.Slack; - var slackConfig = new SlackIntegrationConfiguration(ChannelId: "C123456"); - model.Configuration = JsonSerializer.Serialize(slackConfig); - model.Template = "Template String"; - model.Filters = null; - - var expected = new OrganizationIntegrationConfigurationResponseModel(model.ToOrganizationIntegrationConfiguration(organizationIntegrationConfiguration)); - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - var requestAction = await sutProvider.Sut.UpdateAsync( - organizationId, - organizationIntegration.Id, - organizationIntegrationConfiguration.Id, - model); - - await sutProvider.GetDependency().Received(1) - .ReplaceAsync(Arg.Any()); - Assert.IsType(requestAction); - Assert.Equal(expected.Id, requestAction.Id); - Assert.Equal(expected.Configuration, requestAction.Configuration); - Assert.Equal(expected.EventType, requestAction.EventType); - Assert.Equal(expected.Template, requestAction.Template); - } - - - [Theory, BitAutoData] - public async Task UpdateAsync_AllParamsProvided_Webhook_Succeeds( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegrationConfiguration.OrganizationIntegrationId = organizationIntegration.Id; - organizationIntegration.Type = IntegrationType.Webhook; - var webhookConfig = new WebhookIntegrationConfiguration(Uri: new Uri("https://localhost"), Scheme: "Bearer", Token: "AUTH-TOKEN"); - model.Configuration = JsonSerializer.Serialize(webhookConfig); - model.Template = "Template String"; - model.Filters = null; - - var expected = new OrganizationIntegrationConfigurationResponseModel(model.ToOrganizationIntegrationConfiguration(organizationIntegrationConfiguration)); - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - var requestAction = await sutProvider.Sut.UpdateAsync( - organizationId, - organizationIntegration.Id, - organizationIntegrationConfiguration.Id, - model); - - await sutProvider.GetDependency().Received(1) - .ReplaceAsync(Arg.Any()); - Assert.IsType(requestAction); - Assert.Equal(expected.Id, requestAction.Id); - Assert.Equal(expected.Configuration, requestAction.Configuration); - Assert.Equal(expected.EventType, requestAction.EventType); - Assert.Equal(expected.Template, requestAction.Template); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_OnlyUrlProvided_Webhook_Succeeds( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegrationConfiguration.OrganizationIntegrationId = organizationIntegration.Id; - organizationIntegration.Type = IntegrationType.Webhook; - var webhookConfig = new WebhookIntegrationConfiguration(Uri: new Uri("https://localhost")); - model.Configuration = JsonSerializer.Serialize(webhookConfig); - model.Template = "Template String"; - model.Filters = null; - - var expected = new OrganizationIntegrationConfigurationResponseModel(model.ToOrganizationIntegrationConfiguration(organizationIntegrationConfiguration)); - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - var requestAction = await sutProvider.Sut.UpdateAsync( - organizationId, - organizationIntegration.Id, - organizationIntegrationConfiguration.Id, - model); - - await sutProvider.GetDependency().Received(1) - .ReplaceAsync(Arg.Any()); - Assert.IsType(requestAction); - Assert.Equal(expected.Id, requestAction.Id); - Assert.Equal(expected.Configuration, requestAction.Configuration); - Assert.Equal(expected.EventType, requestAction.EventType); - Assert.Equal(expected.Template, requestAction.Template); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_IntegrationConfigurationDoesNotExist_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegration.Type = IntegrationType.Webhook; - var webhookConfig = new WebhookIntegrationConfiguration(Uri: new Uri("https://localhost"), Scheme: "Bearer", Token: "AUTH-TOKEN"); - model.Configuration = JsonSerializer.Serialize(webhookConfig); - model.Template = "Template String"; - model.Filters = null; - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAsync( - organizationId, - organizationIntegration.Id, - Guid.Empty, - model)); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_IntegrationDoesNotExist_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .ReturnsNull(); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAsync( - organizationId, - Guid.Empty, - Guid.Empty, - new OrganizationIntegrationConfigurationRequestModel())); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAsync( - organizationId, - organizationIntegration.Id, - Guid.Empty, - new OrganizationIntegrationConfigurationRequestModel())); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_InvalidConfiguration_ThrowsBadRequestException( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegrationConfiguration.OrganizationIntegrationId = organizationIntegration.Id; - organizationIntegration.Type = IntegrationType.Slack; - model.Configuration = null; - model.Template = "Template String"; - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAsync( - organizationId, - organizationIntegration.Id, - organizationIntegrationConfiguration.Id, - model)); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_InvalidTemplate_ThrowsBadRequestException( - SutProvider sutProvider, - Guid organizationId, - OrganizationIntegration organizationIntegration, - OrganizationIntegrationConfiguration organizationIntegrationConfiguration, - OrganizationIntegrationConfigurationRequestModel model) - { - organizationIntegration.OrganizationId = organizationId; - organizationIntegrationConfiguration.OrganizationIntegrationId = organizationIntegration.Id; - organizationIntegration.Type = IntegrationType.Slack; - var slackConfig = new SlackIntegrationConfiguration(ChannelId: "C123456"); - model.Configuration = JsonSerializer.Serialize(slackConfig); - model.Template = null; - - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegration); - sutProvider.GetDependency() - .GetByIdAsync(Arg.Any()) - .Returns(organizationIntegrationConfiguration); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAsync( - organizationId, - organizationIntegration.Id, - organizationIntegrationConfiguration.Id, - model)); - } - - [Theory, BitAutoData] - public async Task UpdateAsync_UserIsNotOrganizationAdmin_ThrowsNotFound(SutProvider sutProvider, Guid organizationId) - { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(false); - - await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAsync( - organizationId, - Guid.Empty, - Guid.Empty, - new OrganizationIntegrationConfigurationRequestModel())); - } -} diff --git a/test/Api.Test/AdminConsole/Controllers/OrganizationUsersControllerTests.cs b/test/Api.Test/AdminConsole/Controllers/OrganizationUsersControllerTests.cs index e5aa03f067..43f0123a3f 100644 --- a/test/Api.Test/AdminConsole/Controllers/OrganizationUsersControllerTests.cs +++ b/test/Api.Test/AdminConsole/Controllers/OrganizationUsersControllerTests.cs @@ -1,21 +1,28 @@ using System.Security.Claims; +using Bit.Api.AdminConsole.Authorization; using Bit.Api.AdminConsole.Controllers; using Bit.Api.AdminConsole.Models.Request.Organizations; +using Bit.Api.Models.Request.Organizations; using Bit.Api.Vault.AuthorizationHandlers.Collections; using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Utilities.v2.Results; using Bit.Core.Auth.Entities; using Bit.Core.Auth.Repositories; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; +using Bit.Core.Models.Api; using Bit.Core.Models.Business; using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations; @@ -29,8 +36,11 @@ using Bit.Test.Common.AutoFixture.Attributes; using Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests; using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.HttpResults; +using Microsoft.AspNetCore.Mvc.ModelBinding; using NSubstitute; +using OneOf.Types; using Xunit; namespace Bit.Api.Test.AdminConsole.Controllers; @@ -440,4 +450,349 @@ public class OrganizationUsersControllerTests Assert.Equal("Master Password reset is required, but not provided.", exception.Message); } + + [Theory] + [BitAutoData] + public async Task PutResetPassword_WhenOrganizationUserNotFound_ReturnsNotFound( + Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, + SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns((OrganizationUser)null); + + var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); + + Assert.IsType(result); + } + + [Theory] + [BitAutoData] + public async Task PutResetPassword_WhenOrganizationIdMismatch_ReturnsNotFound( + Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, OrganizationUser organizationUser, + SutProvider sutProvider) + { + organizationUser.OrganizationId = Guid.NewGuid(); + sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns(organizationUser); + + var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); + + Assert.IsType(result); + } + + [Theory] + [BitAutoData] + public async Task PutResetPassword_WhenAuthorizationFails_ReturnsBadRequest( + Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, OrganizationUser organizationUser, + SutProvider sutProvider) + { + organizationUser.OrganizationId = orgId; + sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns(organizationUser); + sutProvider.GetDependency() + .AuthorizeAsync( + Arg.Any(), + organizationUser, + Arg.Is>(x => x.SingleOrDefault() is RecoverAccountAuthorizationRequirement)) + .Returns(AuthorizationResult.Failed()); + + var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); + + Assert.IsType>(result); + } + + [Theory] + [BitAutoData] + public async Task PutResetPassword_WhenRecoverAccountSucceeds_ReturnsOk( + Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, OrganizationUser organizationUser, + SutProvider sutProvider) + { + organizationUser.OrganizationId = orgId; + sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns(organizationUser); + sutProvider.GetDependency() + .AuthorizeAsync( + Arg.Any(), + organizationUser, + Arg.Is>(x => x.SingleOrDefault() is RecoverAccountAuthorizationRequirement)) + .Returns(AuthorizationResult.Success()); + sutProvider.GetDependency() + .RecoverAccountAsync(orgId, organizationUser, model.NewMasterPasswordHash, model.Key) + .Returns(Microsoft.AspNetCore.Identity.IdentityResult.Success); + + var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); + + Assert.IsType(result); + await sutProvider.GetDependency().Received(1) + .RecoverAccountAsync(orgId, organizationUser, model.NewMasterPasswordHash, model.Key); + } + + [Theory] + [BitAutoData] + public async Task PutResetPassword_WhenRecoverAccountFails_ReturnsBadRequest( + Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, OrganizationUser organizationUser, + SutProvider sutProvider) + { + organizationUser.OrganizationId = orgId; + sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns(organizationUser); + sutProvider.GetDependency() + .AuthorizeAsync( + Arg.Any(), + organizationUser, + Arg.Is>(x => x.SingleOrDefault() is RecoverAccountAuthorizationRequirement)) + .Returns(AuthorizationResult.Success()); + sutProvider.GetDependency() + .RecoverAccountAsync(orgId, organizationUser, model.NewMasterPasswordHash, model.Key) + .Returns(Microsoft.AspNetCore.Identity.IdentityResult.Failed(new Microsoft.AspNetCore.Identity.IdentityError { Description = "Error message" })); + + var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); + + Assert.IsType>(result); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_UserIdNull_ReturnsUnauthorized( + Guid orgId, + Guid orgUserId, + OrganizationUserConfirmRequestModel model, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns((Guid?)null); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(orgId, orgUserId, model); + + // Assert + Assert.IsType(result); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_UserIdEmpty_ReturnsUnauthorized( + Guid orgId, + Guid orgUserId, + OrganizationUserConfirmRequestModel model, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(Guid.Empty); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(orgId, orgUserId, model); + + // Assert + Assert.IsType(result); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_Success_ReturnsOk( + Guid orgId, + Guid orgUserId, + Guid userId, + OrganizationUserConfirmRequestModel model, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(userId); + + sutProvider.GetDependency() + .OrganizationOwner(orgId) + .Returns(true); + + sutProvider.GetDependency() + .AutomaticallyConfirmOrganizationUserAsync(Arg.Any()) + .Returns(new CommandResult(new None())); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(orgId, orgUserId, model); + + // Assert + Assert.IsType(result); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_NotFoundError_ReturnsNotFound( + Guid orgId, + Guid orgUserId, + Guid userId, + OrganizationUserConfirmRequestModel model, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(userId); + + sutProvider.GetDependency() + .OrganizationOwner(orgId) + .Returns(false); + + var notFoundError = new OrganizationNotFound(); + sutProvider.GetDependency() + .AutomaticallyConfirmOrganizationUserAsync(Arg.Any()) + .Returns(new CommandResult(notFoundError)); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(orgId, orgUserId, model); + + // Assert + var notFoundResult = Assert.IsType>(result); + Assert.Equal(notFoundError.Message, notFoundResult.Value.Message); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_BadRequestError_ReturnsBadRequest( + Guid orgId, + Guid orgUserId, + Guid userId, + OrganizationUserConfirmRequestModel model, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(userId); + + sutProvider.GetDependency() + .OrganizationOwner(orgId) + .Returns(true); + + var badRequestError = new UserIsNotAccepted(); + sutProvider.GetDependency() + .AutomaticallyConfirmOrganizationUserAsync(Arg.Any()) + .Returns(new CommandResult(badRequestError)); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(orgId, orgUserId, model); + + // Assert + var badRequestResult = Assert.IsType>(result); + Assert.Equal(badRequestError.Message, badRequestResult.Value.Message); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_InternalError_ReturnsProblem( + Guid orgId, + Guid orgUserId, + Guid userId, + OrganizationUserConfirmRequestModel model, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(userId); + + sutProvider.GetDependency() + .OrganizationOwner(orgId) + .Returns(true); + + var internalError = new FailedToWriteToEventLog(); + sutProvider.GetDependency() + .AutomaticallyConfirmOrganizationUserAsync(Arg.Any()) + .Returns(new CommandResult(internalError)); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(orgId, orgUserId, model); + + // Assert + var problemResult = Assert.IsType>(result); + Assert.Equal(StatusCodes.Status500InternalServerError, problemResult.StatusCode); + } + + [Theory] + [BitAutoData] + public async Task BulkReinvite_WhenFeatureFlagEnabled_UsesBulkResendOrganizationInvitesCommand( + Guid organizationId, + OrganizationUserBulkRequestModel bulkRequestModel, + List organizationUsers, + Guid userId, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency().ManageUsers(organizationId).Returns(true); + sutProvider.GetDependency().GetProperUserId(Arg.Any()).Returns(userId); + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.IncreaseBulkReinviteLimitForCloud) + .Returns(true); + + var expectedResults = organizationUsers.Select(u => Tuple.Create(u, "")).ToList(); + sutProvider.GetDependency() + .BulkResendInvitesAsync(organizationId, userId, bulkRequestModel.Ids) + .Returns(expectedResults); + + // Act + var response = await sutProvider.Sut.BulkReinvite(organizationId, bulkRequestModel); + + // Assert + Assert.Equal(organizationUsers.Count, response.Data.Count()); + + await sutProvider.GetDependency() + .Received(1) + .BulkResendInvitesAsync(organizationId, userId, bulkRequestModel.Ids); + } + + [Theory] + [BitAutoData] + public async Task BulkReinvite_WhenFeatureFlagDisabled_UsesLegacyOrganizationService( + Guid organizationId, + OrganizationUserBulkRequestModel bulkRequestModel, + List organizationUsers, + Guid userId, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency().ManageUsers(organizationId).Returns(true); + sutProvider.GetDependency().GetProperUserId(Arg.Any()).Returns(userId); + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.IncreaseBulkReinviteLimitForCloud) + .Returns(false); + + var expectedResults = organizationUsers.Select(u => Tuple.Create(u, "")).ToList(); + sutProvider.GetDependency() + .ResendInvitesAsync(organizationId, userId, bulkRequestModel.Ids) + .Returns(expectedResults); + + // Act + var response = await sutProvider.Sut.BulkReinvite(organizationId, bulkRequestModel); + + // Assert + Assert.Equal(organizationUsers.Count, response.Data.Count()); + + await sutProvider.GetDependency() + .Received(1) + .ResendInvitesAsync(organizationId, userId, bulkRequestModel.Ids); + } } diff --git a/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs b/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs index 00fd3c3b4e..d87f035a13 100644 --- a/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs +++ b/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs @@ -1,5 +1,4 @@ using System.Security.Claims; -using AutoFixture.Xunit2; using Bit.Api.AdminConsole.Controllers; using Bit.Api.Auth.Models.Request.Accounts; using Bit.Api.Models.Request.Organizations; @@ -8,9 +7,6 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Models.Business; -using Bit.Core.AdminConsole.Models.Business.Tokenables; -using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationApiKeys.Interfaces; -using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; @@ -20,7 +16,6 @@ using Bit.Core.Auth.Entities; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Repositories; -using Bit.Core.Auth.Services; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Providers.Services; @@ -30,102 +25,24 @@ using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; -using Bit.Core.Tokens; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Infrastructure.EntityFramework.AdminConsole.Models.Provider; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -using GlobalSettings = Bit.Core.Settings.GlobalSettings; namespace Bit.Api.Test.AdminConsole.Controllers; -public class OrganizationsControllerTests : IDisposable +[ControllerCustomize(typeof(OrganizationsController))] +[SutProviderCustomize] +public class OrganizationsControllerTests { - private readonly GlobalSettings _globalSettings; - private readonly ICurrentContext _currentContext; - private readonly IOrganizationRepository _organizationRepository; - private readonly IOrganizationService _organizationService; - private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IPolicyRepository _policyRepository; - private readonly ISsoConfigRepository _ssoConfigRepository; - private readonly ISsoConfigService _ssoConfigService; - private readonly IUserService _userService; - private readonly IGetOrganizationApiKeyQuery _getOrganizationApiKeyQuery; - private readonly IRotateOrganizationApiKeyCommand _rotateOrganizationApiKeyCommand; - private readonly IOrganizationApiKeyRepository _organizationApiKeyRepository; - private readonly ICreateOrganizationApiKeyCommand _createOrganizationApiKeyCommand; - private readonly IFeatureService _featureService; - private readonly IProviderRepository _providerRepository; - private readonly IProviderBillingService _providerBillingService; - private readonly IDataProtectorTokenFactory _orgDeleteTokenDataFactory; - private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand; - private readonly ICloudOrganizationSignUpCommand _cloudOrganizationSignUpCommand; - private readonly IOrganizationDeleteCommand _organizationDeleteCommand; - private readonly IPolicyRequirementQuery _policyRequirementQuery; - private readonly IPricingClient _pricingClient; - private readonly IOrganizationUpdateKeysCommand _organizationUpdateKeysCommand; - private readonly OrganizationsController _sut; - - public OrganizationsControllerTests() - { - _currentContext = Substitute.For(); - _globalSettings = Substitute.For(); - _organizationRepository = Substitute.For(); - _organizationService = Substitute.For(); - _organizationUserRepository = Substitute.For(); - _policyRepository = Substitute.For(); - _ssoConfigRepository = Substitute.For(); - _ssoConfigService = Substitute.For(); - _getOrganizationApiKeyQuery = Substitute.For(); - _rotateOrganizationApiKeyCommand = Substitute.For(); - _organizationApiKeyRepository = Substitute.For(); - _userService = Substitute.For(); - _createOrganizationApiKeyCommand = Substitute.For(); - _featureService = Substitute.For(); - _providerRepository = Substitute.For(); - _providerBillingService = Substitute.For(); - _orgDeleteTokenDataFactory = Substitute.For>(); - _removeOrganizationUserCommand = Substitute.For(); - _cloudOrganizationSignUpCommand = Substitute.For(); - _organizationDeleteCommand = Substitute.For(); - _policyRequirementQuery = Substitute.For(); - _pricingClient = Substitute.For(); - _organizationUpdateKeysCommand = Substitute.For(); - - _sut = new OrganizationsController( - _organizationRepository, - _organizationUserRepository, - _policyRepository, - _organizationService, - _userService, - _currentContext, - _ssoConfigRepository, - _ssoConfigService, - _getOrganizationApiKeyQuery, - _rotateOrganizationApiKeyCommand, - _createOrganizationApiKeyCommand, - _organizationApiKeyRepository, - _featureService, - _globalSettings, - _providerRepository, - _providerBillingService, - _orgDeleteTokenDataFactory, - _removeOrganizationUserCommand, - _cloudOrganizationSignUpCommand, - _organizationDeleteCommand, - _policyRequirementQuery, - _pricingClient, - _organizationUpdateKeysCommand); - } - - public void Dispose() - { - _sut?.Dispose(); - } - - [Theory, AutoData] + [Theory, BitAutoData] public async Task OrganizationsController_UserCannotLeaveOrganizationThatProvidesKeyConnector( - Guid orgId, User user) + SutProvider sutProvider, + Guid orgId, + User user) { var ssoConfig = new SsoConfig { @@ -140,21 +57,24 @@ public class OrganizationsControllerTests : IDisposable user.UsesKeyConnector = true; - _currentContext.OrganizationUser(orgId).Returns(true); - _ssoConfigRepository.GetByOrganizationIdAsync(orgId).Returns(ssoConfig); - _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); - _userService.GetOrganizationsClaimingUserAsync(user.Id).Returns(new List { null }); - var exception = await Assert.ThrowsAsync(() => _sut.Leave(orgId)); + sutProvider.GetDependency().OrganizationUser(orgId).Returns(true); + sutProvider.GetDependency().GetByOrganizationIdAsync(orgId).Returns(ssoConfig); + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).Returns(user); + sutProvider.GetDependency().GetOrganizationsClaimingUserAsync(user.Id).Returns(new List { null }); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.Leave(orgId)); Assert.Contains("Your organization's Single Sign-On settings prevent you from leaving.", exception.Message); - await _removeOrganizationUserCommand.DidNotReceiveWithAnyArgs().UserLeaveAsync(default, default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().UserLeaveAsync(default, default); } - [Theory, AutoData] + [Theory, BitAutoData] public async Task OrganizationsController_UserCannotLeaveOrganizationThatManagesUser( - Guid orgId, User user) + SutProvider sutProvider, + Guid orgId, + User user) { var ssoConfig = new SsoConfig { @@ -166,27 +86,34 @@ public class OrganizationsControllerTests : IDisposable Enabled = true, OrganizationId = orgId, }; - var foundOrg = new Organization(); - foundOrg.Id = orgId; + var foundOrg = new Organization + { + Id = orgId + }; - _currentContext.OrganizationUser(orgId).Returns(true); - _ssoConfigRepository.GetByOrganizationIdAsync(orgId).Returns(ssoConfig); - _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); - _userService.GetOrganizationsClaimingUserAsync(user.Id).Returns(new List { { foundOrg } }); - var exception = await Assert.ThrowsAsync(() => _sut.Leave(orgId)); + sutProvider.GetDependency().OrganizationUser(orgId).Returns(true); + sutProvider.GetDependency().GetByOrganizationIdAsync(orgId).Returns(ssoConfig); + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).Returns(user); + sutProvider.GetDependency().GetOrganizationsClaimingUserAsync(user.Id).Returns(new List { foundOrg }); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.Leave(orgId)); Assert.Contains("Claimed user account cannot leave claiming organization. Contact your organization administrator for additional details.", exception.Message); - await _removeOrganizationUserCommand.DidNotReceiveWithAnyArgs().RemoveUserAsync(default, default); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().RemoveUserAsync(default, default); } [Theory] - [InlineAutoData(true, false)] - [InlineAutoData(false, true)] - [InlineAutoData(false, false)] + [BitAutoData(true, false)] + [BitAutoData(false, true)] + [BitAutoData(false, false)] public async Task OrganizationsController_UserCanLeaveOrganizationThatDoesntProvideKeyConnector( - bool keyConnectorEnabled, bool userUsesKeyConnector, Guid orgId, User user) + bool keyConnectorEnabled, + bool userUsesKeyConnector, + SutProvider sutProvider, + Guid orgId, + User user) { var ssoConfig = new SsoConfig { @@ -203,18 +130,19 @@ public class OrganizationsControllerTests : IDisposable user.UsesKeyConnector = userUsesKeyConnector; - _currentContext.OrganizationUser(orgId).Returns(true); - _ssoConfigRepository.GetByOrganizationIdAsync(orgId).Returns(ssoConfig); - _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); - _userService.GetOrganizationsClaimingUserAsync(user.Id).Returns(new List()); + sutProvider.GetDependency().OrganizationUser(orgId).Returns(true); + sutProvider.GetDependency().GetByOrganizationIdAsync(orgId).Returns(ssoConfig); + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).Returns(user); + sutProvider.GetDependency().GetOrganizationsClaimingUserAsync(user.Id).Returns(new List()); - await _sut.Leave(orgId); + await sutProvider.Sut.Leave(orgId); - await _removeOrganizationUserCommand.Received(1).UserLeaveAsync(orgId, user.Id); + await sutProvider.GetDependency().Received(1).UserLeaveAsync(orgId, user.Id); } - [Theory, AutoData] + [Theory, BitAutoData] public async Task Delete_OrganizationIsConsolidatedBillingClient_ScalesProvidersSeats( + SutProvider sutProvider, Provider provider, Organization organization, User user, @@ -228,87 +156,89 @@ public class OrganizationsControllerTests : IDisposable provider.Type = ProviderType.Msp; provider.Status = ProviderStatusType.Billable; - _currentContext.OrganizationOwner(organizationId).Returns(true); + sutProvider.GetDependency().OrganizationOwner(organizationId).Returns(true); + sutProvider.GetDependency().GetByIdAsync(organizationId).Returns(organization); + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).Returns(user); + sutProvider.GetDependency().VerifySecretAsync(user, requestModel.Secret).Returns(true); + sutProvider.GetDependency().GetByOrganizationIdAsync(organization.Id).Returns(provider); - _organizationRepository.GetByIdAsync(organizationId).Returns(organization); + await sutProvider.Sut.Delete(organizationId.ToString(), requestModel); - _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); - - _userService.VerifySecretAsync(user, requestModel.Secret).Returns(true); - - _providerRepository.GetByOrganizationIdAsync(organization.Id).Returns(provider); - - await _sut.Delete(organizationId.ToString(), requestModel); - - await _providerBillingService.Received(1) + await sutProvider.GetDependency().Received(1) .ScaleSeats(provider, organization.PlanType, -organization.Seats.Value); - await _organizationDeleteCommand.Received(1).DeleteAsync(organization); + await sutProvider.GetDependency().Received(1).DeleteAsync(organization); } - [Theory, AutoData] + [Theory, BitAutoData] public async Task GetAutoEnrollStatus_WithPolicyRequirementsEnabled_ReturnsOrganizationAutoEnrollStatus_WithResetPasswordEnabledTrue( + SutProvider sutProvider, User user, Organization organization, - OrganizationUser organizationUser - ) + OrganizationUser organizationUser) { - var policyRequirement = new ResetPasswordPolicyRequirement() { AutoEnrollOrganizations = [organization.Id] }; + var policyRequirement = new ResetPasswordPolicyRequirement { AutoEnrollOrganizations = [organization.Id] }; - _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); - _organizationRepository.GetByIdentifierAsync(organization.Id.ToString()).Returns(organization); - _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true); - _organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id).Returns(organizationUser); - _policyRequirementQuery.GetAsync(user.Id).Returns(policyRequirement); + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).Returns(user); + sutProvider.GetDependency().GetByIdentifierAsync(organization.Id.ToString()).Returns(organization); + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true); + sutProvider.GetDependency().GetByOrganizationAsync(organization.Id, user.Id).Returns(organizationUser); + sutProvider.GetDependency().GetAsync(user.Id).Returns(policyRequirement); - var result = await _sut.GetAutoEnrollStatus(organization.Id.ToString()); + var result = await sutProvider.Sut.GetAutoEnrollStatus(organization.Id.ToString()); - await _userService.Received(1).GetUserByPrincipalAsync(Arg.Any()); - await _organizationRepository.Received(1).GetByIdentifierAsync(organization.Id.ToString()); - await _policyRequirementQuery.Received(1).GetAsync(user.Id); + await sutProvider.GetDependency().Received(1).GetUserByPrincipalAsync(Arg.Any()); + await sutProvider.GetDependency().Received(1).GetByIdentifierAsync(organization.Id.ToString()); + await sutProvider.GetDependency().Received(1).GetAsync(user.Id); Assert.True(result.ResetPasswordEnabled); Assert.Equal(result.Id, organization.Id); } - [Theory, AutoData] + [Theory, BitAutoData] public async Task GetAutoEnrollStatus_WithPolicyRequirementsDisabled_ReturnsOrganizationAutoEnrollStatus_WithResetPasswordEnabledTrue( - User user, - Organization organization, - OrganizationUser organizationUser -) + SutProvider sutProvider, + User user, + Organization organization, + OrganizationUser organizationUser) { + var policy = new Policy + { + Type = PolicyType.ResetPassword, + Enabled = true, + Data = "{\"AutoEnrollEnabled\": true}", + OrganizationId = organization.Id + }; - var policy = new Policy() { Type = PolicyType.ResetPassword, Enabled = true, Data = "{\"AutoEnrollEnabled\": true}", OrganizationId = organization.Id }; + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).Returns(user); + sutProvider.GetDependency().GetByIdentifierAsync(organization.Id.ToString()).Returns(organization); + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(false); + sutProvider.GetDependency().GetByOrganizationAsync(organization.Id, user.Id).Returns(organizationUser); + sutProvider.GetDependency().GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword).Returns(policy); - _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); - _organizationRepository.GetByIdentifierAsync(organization.Id.ToString()).Returns(organization); - _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(false); - _organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id).Returns(organizationUser); - _policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword).Returns(policy); + var result = await sutProvider.Sut.GetAutoEnrollStatus(organization.Id.ToString()); - var result = await _sut.GetAutoEnrollStatus(organization.Id.ToString()); - - await _userService.Received(1).GetUserByPrincipalAsync(Arg.Any()); - await _organizationRepository.Received(1).GetByIdentifierAsync(organization.Id.ToString()); - await _policyRequirementQuery.Received(0).GetAsync(user.Id); - await _policyRepository.Received(1).GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword); + await sutProvider.GetDependency().Received(1).GetUserByPrincipalAsync(Arg.Any()); + await sutProvider.GetDependency().Received(1).GetByIdentifierAsync(organization.Id.ToString()); + await sutProvider.GetDependency().Received(0).GetAsync(user.Id); + await sutProvider.GetDependency().Received(1).GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword); Assert.True(result.ResetPasswordEnabled); } - [Theory, AutoData] + [Theory, BitAutoData] public async Task PutCollectionManagement_ValidRequest_Success( + SutProvider sutProvider, Organization organization, OrganizationCollectionManagementUpdateRequestModel model) { // Arrange - _currentContext.OrganizationOwner(organization.Id).Returns(true); + sutProvider.GetDependency().OrganizationOwner(organization.Id).Returns(true); - var plan = StaticStore.GetPlan(PlanType.EnterpriseAnnually); - _pricingClient.GetPlan(Arg.Any()).Returns(plan); + var plan = MockPlans.Get(PlanType.EnterpriseAnnually); + sutProvider.GetDependency().GetPlan(Arg.Any()).Returns(plan); - _organizationService + sutProvider.GetDependency() .UpdateCollectionManagementSettingsAsync( organization.Id, Arg.Is(s => @@ -319,10 +249,10 @@ public class OrganizationsControllerTests : IDisposable .Returns(organization); // Act - await _sut.PutCollectionManagement(organization.Id, model); + await sutProvider.Sut.PutCollectionManagement(organization.Id, model); // Assert - await _organizationService + await sutProvider.GetDependency() .Received(1) .UpdateCollectionManagementSettingsAsync( organization.Id, diff --git a/test/Api.Test/AdminConsole/Controllers/ProviderClientsControllerTests.cs b/test/Api.Test/AdminConsole/Controllers/ProviderClientsControllerTests.cs index c7c749effd..259797dfb3 100644 --- a/test/Api.Test/AdminConsole/Controllers/ProviderClientsControllerTests.cs +++ b/test/Api.Test/AdminConsole/Controllers/ProviderClientsControllerTests.cs @@ -66,8 +66,8 @@ public class ProviderClientsControllerTests signup.Plan == requestBody.PlanType && signup.AdditionalSeats == requestBody.Seats && signup.OwnerKey == requestBody.Key && - signup.PublicKey == requestBody.KeyPair.PublicKey && - signup.PrivateKey == requestBody.KeyPair.EncryptedPrivateKey && + signup.Keys.PublicKey == requestBody.KeyPair.PublicKey && + signup.Keys.WrappedPrivateKey == requestBody.KeyPair.EncryptedPrivateKey && signup.CollectionName == requestBody.CollectionName), requestBody.OwnerEmail, user) diff --git a/test/Api.Test/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModelTests.cs b/test/Api.Test/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModelTests.cs deleted file mode 100644 index 74fe75a9d7..0000000000 --- a/test/Api.Test/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModelTests.cs +++ /dev/null @@ -1,266 +0,0 @@ -using System.Text.Json; -using Bit.Api.AdminConsole.Models.Request.Organizations; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; -using Bit.Core.Enums; -using Xunit; - -namespace Bit.Api.Test.AdminConsole.Models.Request.Organizations; - -public class OrganizationIntegrationConfigurationRequestModelTests -{ - [Fact] - public void IsValidForType_CloudBillingSyncIntegration_ReturnsFalse() - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = "{}", - Template = "template" - }; - - Assert.False(condition: model.IsValidForType(IntegrationType.CloudBillingSync)); - } - - [Theory] - [InlineData(data: null)] - [InlineData(data: "")] - [InlineData(data: " ")] - public void IsValidForType_EmptyConfiguration_ReturnsFalse(string? config) - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Template = "template" - }; - - Assert.False(condition: model.IsValidForType(IntegrationType.Slack)); - Assert.False(condition: model.IsValidForType(IntegrationType.Webhook)); - } - - [Theory] - [InlineData(data: "")] - [InlineData(data: " ")] - public void IsValidForType_EmptyNonNullHecConfiguration_ReturnsFalse(string? config) - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Template = "template" - }; - - Assert.False(condition: model.IsValidForType(IntegrationType.Hec)); - } - - [Fact] - public void IsValidForType_NullHecConfiguration_ReturnsTrue() - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = null, - Template = "template" - }; - - Assert.True(condition: model.IsValidForType(IntegrationType.Hec)); - } - - [Theory] - [InlineData(data: "")] - [InlineData(data: " ")] - public void IsValidForType_EmptyNonNullDatadogConfiguration_ReturnsFalse(string? config) - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Template = "template" - }; - - Assert.False(condition: model.IsValidForType(IntegrationType.Datadog)); - } - - [Fact] - public void IsValidForType_NullDatadogConfiguration_ReturnsTrue() - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = null, - Template = "template" - }; - - Assert.True(condition: model.IsValidForType(IntegrationType.Datadog)); - } - - [Theory] - [InlineData(data: null)] - [InlineData(data: "")] - [InlineData(data: " ")] - public void IsValidForType_EmptyTemplate_ReturnsFalse(string? template) - { - var config = JsonSerializer.Serialize(value: new WebhookIntegrationConfiguration( - Uri: new Uri("https://localhost"), - Scheme: "Bearer", - Token: "AUTH-TOKEN")); - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Template = template - }; - - Assert.False(condition: model.IsValidForType(IntegrationType.Slack)); - Assert.False(condition: model.IsValidForType(IntegrationType.Webhook)); - Assert.False(condition: model.IsValidForType(IntegrationType.Hec)); - } - - [Fact] - public void IsValidForType_InvalidJsonConfiguration_ReturnsFalse() - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = "{not valid json}", - Template = "template" - }; - - Assert.False(condition: model.IsValidForType(IntegrationType.Slack)); - Assert.False(condition: model.IsValidForType(IntegrationType.Webhook)); - Assert.False(condition: model.IsValidForType(IntegrationType.Hec)); - } - - - [Fact] - public void IsValidForType_InvalidJsonFilters_ReturnsFalse() - { - var config = JsonSerializer.Serialize(new WebhookIntegrationConfiguration(Uri: new Uri("https://example.com"))); - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Filters = "{Not valid json", - Template = "template" - }; - - Assert.False(model.IsValidForType(IntegrationType.Webhook)); - } - - [Fact] - public void IsValidForType_ScimIntegration_ReturnsFalse() - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = "{}", - Template = "template" - }; - - Assert.False(condition: model.IsValidForType(IntegrationType.Scim)); - } - - [Fact] - public void IsValidForType_ValidSlackConfiguration_ReturnsTrue() - { - var config = JsonSerializer.Serialize(value: new SlackIntegrationConfiguration(ChannelId: "C12345")); - - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Template = "template" - }; - - Assert.True(condition: model.IsValidForType(IntegrationType.Slack)); - } - - [Fact] - public void IsValidForType_ValidSlackConfigurationWithFilters_ReturnsTrue() - { - var config = JsonSerializer.Serialize(new SlackIntegrationConfiguration("C12345")); - var filters = JsonSerializer.Serialize(new IntegrationFilterGroup() - { - AndOperator = true, - Rules = [ - new IntegrationFilterRule() - { - Operation = IntegrationFilterOperation.Equals, - Property = "CollectionId", - Value = Guid.NewGuid() - } - ], - Groups = [] - }); - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Filters = filters, - Template = "template" - }; - - Assert.True(model.IsValidForType(IntegrationType.Slack)); - } - - [Fact] - public void IsValidForType_ValidNoAuthWebhookConfiguration_ReturnsTrue() - { - var config = JsonSerializer.Serialize(value: new WebhookIntegrationConfiguration(Uri: new Uri("https://localhost"))); - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Template = "template" - }; - - Assert.True(condition: model.IsValidForType(IntegrationType.Webhook)); - } - - [Fact] - public void IsValidForType_ValidWebhookConfiguration_ReturnsTrue() - { - var config = JsonSerializer.Serialize(value: new WebhookIntegrationConfiguration( - Uri: new Uri("https://localhost"), - Scheme: "Bearer", - Token: "AUTH-TOKEN")); - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Template = "template" - }; - - Assert.True(condition: model.IsValidForType(IntegrationType.Webhook)); - } - - [Fact] - public void IsValidForType_ValidWebhookConfigurationWithFilters_ReturnsTrue() - { - var config = JsonSerializer.Serialize(new WebhookIntegrationConfiguration( - Uri: new Uri("https://example.com"), - Scheme: "Bearer", - Token: "AUTH-TOKEN")); - var filters = JsonSerializer.Serialize(new IntegrationFilterGroup() - { - AndOperator = true, - Rules = [ - new IntegrationFilterRule() - { - Operation = IntegrationFilterOperation.Equals, - Property = "CollectionId", - Value = Guid.NewGuid() - } - ], - Groups = [] - }); - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Filters = filters, - Template = "template" - }; - - Assert.True(model.IsValidForType(IntegrationType.Webhook)); - } - - [Fact] - public void IsValidForType_UnknownIntegrationType_ReturnsFalse() - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = "{}", - Template = "template" - }; - - var unknownType = (IntegrationType)999; - - Assert.False(condition: model.IsValidForType(unknownType)); - } -} diff --git a/test/Api.Test/AdminConsole/Models/Request/SavePolicyRequestTests.cs b/test/Api.Test/AdminConsole/Models/Request/SavePolicyRequestTests.cs index 057680425a..163d66aeb4 100644 --- a/test/Api.Test/AdminConsole/Models/Request/SavePolicyRequestTests.cs +++ b/test/Api.Test/AdminConsole/Models/Request/SavePolicyRequestTests.cs @@ -24,11 +24,11 @@ public class SavePolicyRequestTests currentContext.OrganizationOwner(organizationId).Returns(true); var testData = new Dictionary { { "test", "value" } }; + var policyType = PolicyType.TwoFactorAuthentication; var model = new SavePolicyRequest { Policy = new PolicyRequestModel { - Type = PolicyType.TwoFactorAuthentication, Enabled = true, Data = testData }, @@ -36,7 +36,7 @@ public class SavePolicyRequestTests }; // Act - var result = await model.ToSavePolicyModelAsync(organizationId, currentContext); + var result = await model.ToSavePolicyModelAsync(organizationId, policyType, currentContext); // Assert Assert.Equal(PolicyType.TwoFactorAuthentication, result.PolicyUpdate.Type); @@ -54,7 +54,7 @@ public class SavePolicyRequestTests } [Theory, BitAutoData] - public async Task ToSavePolicyModelAsync_WithNullData_HandlesCorrectly( + public async Task ToSavePolicyModelAsync_WithEmptyData_HandlesCorrectly( Guid organizationId, Guid userId) { @@ -63,19 +63,17 @@ public class SavePolicyRequestTests currentContext.UserId.Returns(userId); currentContext.OrganizationOwner(organizationId).Returns(false); + var policyType = PolicyType.SingleOrg; var model = new SavePolicyRequest { Policy = new PolicyRequestModel { - Type = PolicyType.SingleOrg, - Enabled = false, - Data = null - }, - Metadata = null + Enabled = false + } }; // Act - var result = await model.ToSavePolicyModelAsync(organizationId, currentContext); + var result = await model.ToSavePolicyModelAsync(organizationId, policyType, currentContext); // Assert Assert.Null(result.PolicyUpdate.Data); @@ -95,19 +93,17 @@ public class SavePolicyRequestTests currentContext.UserId.Returns(userId); currentContext.OrganizationOwner(organizationId).Returns(true); + var policyType = PolicyType.SingleOrg; var model = new SavePolicyRequest { Policy = new PolicyRequestModel { - Type = PolicyType.SingleOrg, - Enabled = false, - Data = null - }, - Metadata = null + Enabled = false + } }; // Act - var result = await model.ToSavePolicyModelAsync(organizationId, currentContext); + var result = await model.ToSavePolicyModelAsync(organizationId, policyType, currentContext); // Assert Assert.Null(result.PolicyUpdate.Data); @@ -128,13 +124,12 @@ public class SavePolicyRequestTests currentContext.UserId.Returns(userId); currentContext.OrganizationOwner(organizationId).Returns(true); + var policyType = PolicyType.OrganizationDataOwnership; var model = new SavePolicyRequest { Policy = new PolicyRequestModel { - Type = PolicyType.OrganizationDataOwnership, - Enabled = true, - Data = null + Enabled = true }, Metadata = new Dictionary { @@ -143,7 +138,7 @@ public class SavePolicyRequestTests }; // Act - var result = await model.ToSavePolicyModelAsync(organizationId, currentContext); + var result = await model.ToSavePolicyModelAsync(organizationId, policyType, currentContext); // Assert Assert.IsType(result.Metadata); @@ -152,7 +147,7 @@ public class SavePolicyRequestTests } [Theory, BitAutoData] - public async Task ToSavePolicyModelAsync_OrganizationDataOwnership_WithNullMetadata_ReturnsEmptyMetadata( + public async Task ToSavePolicyModelAsync_OrganizationDataOwnership_WithEmptyMetadata_ReturnsEmptyMetadata( Guid organizationId, Guid userId) { @@ -161,19 +156,17 @@ public class SavePolicyRequestTests currentContext.UserId.Returns(userId); currentContext.OrganizationOwner(organizationId).Returns(true); + var policyType = PolicyType.OrganizationDataOwnership; var model = new SavePolicyRequest { Policy = new PolicyRequestModel { - Type = PolicyType.OrganizationDataOwnership, - Enabled = true, - Data = null - }, - Metadata = null + Enabled = true + } }; // Act - var result = await model.ToSavePolicyModelAsync(organizationId, currentContext); + var result = await model.ToSavePolicyModelAsync(organizationId, policyType, currentContext); // Assert Assert.NotNull(result); @@ -200,12 +193,11 @@ public class SavePolicyRequestTests currentContext.UserId.Returns(userId); currentContext.OrganizationOwner(organizationId).Returns(true); - + var policyType = PolicyType.ResetPassword; var model = new SavePolicyRequest { Policy = new PolicyRequestModel { - Type = PolicyType.ResetPassword, Enabled = true, Data = _complexData }, @@ -213,7 +205,7 @@ public class SavePolicyRequestTests }; // Act - var result = await model.ToSavePolicyModelAsync(organizationId, currentContext); + var result = await model.ToSavePolicyModelAsync(organizationId, policyType, currentContext); // Assert var deserializedData = JsonSerializer.Deserialize>(result.PolicyUpdate.Data); @@ -241,13 +233,12 @@ public class SavePolicyRequestTests currentContext.UserId.Returns(userId); currentContext.OrganizationOwner(organizationId).Returns(true); + var policyType = PolicyType.MaximumVaultTimeout; var model = new SavePolicyRequest { Policy = new PolicyRequestModel { - Type = PolicyType.MaximumVaultTimeout, - Enabled = true, - Data = null + Enabled = true }, Metadata = new Dictionary { @@ -256,7 +247,7 @@ public class SavePolicyRequestTests }; // Act - var result = await model.ToSavePolicyModelAsync(organizationId, currentContext); + var result = await model.ToSavePolicyModelAsync(organizationId, policyType, currentContext); // Assert Assert.NotNull(result); @@ -274,20 +265,18 @@ public class SavePolicyRequestTests currentContext.OrganizationOwner(organizationId).Returns(true); var errorDictionary = BuildErrorDictionary(); - + var policyType = PolicyType.OrganizationDataOwnership; var model = new SavePolicyRequest { Policy = new PolicyRequestModel { - Type = PolicyType.OrganizationDataOwnership, - Enabled = true, - Data = null + Enabled = true }, Metadata = errorDictionary }; // Act - var result = await model.ToSavePolicyModelAsync(organizationId, currentContext); + var result = await model.ToSavePolicyModelAsync(organizationId, policyType, currentContext); // Assert Assert.NotNull(result); diff --git a/test/Api.Test/AdminConsole/Models/Response/ProfileOrganizationResponseModelTests.cs b/test/Api.Test/AdminConsole/Models/Response/ProfileOrganizationResponseModelTests.cs new file mode 100644 index 0000000000..30b0ccc272 --- /dev/null +++ b/test/Api.Test/AdminConsole/Models/Response/ProfileOrganizationResponseModelTests.cs @@ -0,0 +1,151 @@ +using Bit.Api.AdminConsole.Models.Response; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.Auth.Enums; +using Bit.Core.Auth.Models.Data; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Extensions; +using Bit.Core.Enums; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; +using Bit.Core.Utilities; +using Bit.Test.Common.AutoFixture.Attributes; +using Xunit; + +namespace Bit.Api.Test.AdminConsole.Models.Response; + +public class ProfileOrganizationResponseModelTests +{ + [Theory, BitAutoData] + public void Constructor_ShouldPopulatePropertiesCorrectly(Organization organization) + { + var userId = Guid.NewGuid(); + var organizationUserId = Guid.NewGuid(); + var providerId = Guid.NewGuid(); + var organizationIdsClaimingUser = new[] { organization.Id }; + + var organizationDetails = new OrganizationUserOrganizationDetails + { + OrganizationId = organization.Id, + UserId = userId, + OrganizationUserId = organizationUserId, + Name = organization.Name, + Enabled = organization.Enabled, + Identifier = organization.Identifier, + PlanType = organization.PlanType, + UsePolicies = organization.UsePolicies, + UseSso = organization.UseSso, + UseKeyConnector = organization.UseKeyConnector, + UseScim = organization.UseScim, + UseGroups = organization.UseGroups, + UseDirectory = organization.UseDirectory, + UseEvents = organization.UseEvents, + UseTotp = organization.UseTotp, + Use2fa = organization.Use2fa, + UseApi = organization.UseApi, + UseResetPassword = organization.UseResetPassword, + UseSecretsManager = organization.UseSecretsManager, + UsePasswordManager = organization.UsePasswordManager, + UsersGetPremium = organization.UsersGetPremium, + UseCustomPermissions = organization.UseCustomPermissions, + UseRiskInsights = organization.UseRiskInsights, + UsePhishingBlocker = organization.UsePhishingBlocker, + UseOrganizationDomains = organization.UseOrganizationDomains, + UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies, + UseAutomaticUserConfirmation = organization.UseAutomaticUserConfirmation, + SelfHost = organization.SelfHost, + Seats = organization.Seats, + MaxCollections = organization.MaxCollections, + MaxStorageGb = organization.MaxStorageGb, + Key = "organization-key", + PublicKey = "public-key", + PrivateKey = "private-key", + LimitCollectionCreation = organization.LimitCollectionCreation, + LimitCollectionDeletion = organization.LimitCollectionDeletion, + LimitItemDeletion = organization.LimitItemDeletion, + AllowAdminAccessToAllCollectionItems = organization.AllowAdminAccessToAllCollectionItems, + ProviderId = providerId, + ProviderName = "Test Provider", + ProviderType = ProviderType.Msp, + SsoEnabled = true, + SsoConfig = new SsoConfigurationData + { + MemberDecryptionType = MemberDecryptionType.KeyConnector, + KeyConnectorUrl = "https://keyconnector.example.com" + }.Serialize(), + SsoExternalId = "external-sso-id", + Permissions = CoreHelpers.ClassToJsonData(new Core.Models.Data.Permissions { ManageUsers = true }), + ResetPasswordKey = "reset-password-key", + FamilySponsorshipFriendlyName = "Family Sponsorship", + FamilySponsorshipLastSyncDate = DateTime.UtcNow.AddDays(-1), + FamilySponsorshipToDelete = false, + FamilySponsorshipValidUntil = DateTime.UtcNow.AddYears(1), + IsAdminInitiated = true, + Status = OrganizationUserStatusType.Confirmed, + Type = OrganizationUserType.Owner, + AccessSecretsManager = true, + SmSeats = 5, + SmServiceAccounts = 10 + }; + + var result = new ProfileOrganizationResponseModel(organizationDetails, organizationIdsClaimingUser); + + Assert.Equal("profileOrganization", result.Object); + Assert.Equal(organization.Id, result.Id); + Assert.Equal(userId, result.UserId); + Assert.Equal(organization.Name, result.Name); + Assert.Equal(organization.Enabled, result.Enabled); + Assert.Equal(organization.Identifier, result.Identifier); + Assert.Equal(organization.PlanType.GetProductTier(), result.ProductTierType); + Assert.Equal(organization.UsePolicies, result.UsePolicies); + Assert.Equal(organization.UseSso, result.UseSso); + Assert.Equal(organization.UseKeyConnector, result.UseKeyConnector); + Assert.Equal(organization.UseScim, result.UseScim); + Assert.Equal(organization.UseGroups, result.UseGroups); + Assert.Equal(organization.UseDirectory, result.UseDirectory); + Assert.Equal(organization.UseEvents, result.UseEvents); + Assert.Equal(organization.UseTotp, result.UseTotp); + Assert.Equal(organization.Use2fa, result.Use2fa); + Assert.Equal(organization.UseApi, result.UseApi); + Assert.Equal(organization.UseResetPassword, result.UseResetPassword); + Assert.Equal(organization.UseSecretsManager, result.UseSecretsManager); + Assert.Equal(organization.UsePasswordManager, result.UsePasswordManager); + Assert.Equal(organization.UsersGetPremium, result.UsersGetPremium); + Assert.Equal(organization.UseCustomPermissions, result.UseCustomPermissions); + Assert.Equal(organization.PlanType.GetProductTier() == ProductTierType.Enterprise, result.UseActivateAutofillPolicy); + Assert.Equal(organization.UseRiskInsights, result.UseRiskInsights); + Assert.Equal(organization.UseOrganizationDomains, result.UseOrganizationDomains); + Assert.Equal(organization.UseAdminSponsoredFamilies, result.UseAdminSponsoredFamilies); + Assert.Equal(organization.UseAutomaticUserConfirmation, result.UseAutomaticUserConfirmation); + Assert.Equal(organization.SelfHost, result.SelfHost); + Assert.Equal(organization.Seats, result.Seats); + Assert.Equal(organization.MaxCollections, result.MaxCollections); + Assert.Equal(organization.MaxStorageGb, result.MaxStorageGb); + Assert.Equal(organizationDetails.Key, result.Key); + Assert.True(result.HasPublicAndPrivateKeys); + Assert.Equal(organization.LimitCollectionCreation, result.LimitCollectionCreation); + Assert.Equal(organization.LimitCollectionDeletion, result.LimitCollectionDeletion); + Assert.Equal(organization.LimitItemDeletion, result.LimitItemDeletion); + Assert.Equal(organization.AllowAdminAccessToAllCollectionItems, result.AllowAdminAccessToAllCollectionItems); + Assert.Equal(organizationDetails.ProviderId, result.ProviderId); + Assert.Equal(organizationDetails.ProviderName, result.ProviderName); + Assert.Equal(organizationDetails.ProviderType, result.ProviderType); + Assert.Equal(organizationDetails.SsoEnabled, result.SsoEnabled); + Assert.True(result.KeyConnectorEnabled); + Assert.Equal("https://keyconnector.example.com", result.KeyConnectorUrl); + Assert.Equal(MemberDecryptionType.KeyConnector, result.SsoMemberDecryptionType); + Assert.True(result.SsoBound); + Assert.Equal(organizationDetails.Status, result.Status); + Assert.Equal(organizationDetails.Type, result.Type); + Assert.Equal(organizationDetails.OrganizationUserId, result.OrganizationUserId); + Assert.True(result.UserIsClaimedByOrganization); + Assert.NotNull(result.Permissions); + Assert.True(result.ResetPasswordEnrolled); + Assert.Equal(organizationDetails.AccessSecretsManager, result.AccessSecretsManager); + Assert.Equal(organizationDetails.FamilySponsorshipFriendlyName, result.FamilySponsorshipFriendlyName); + Assert.Equal(organizationDetails.FamilySponsorshipLastSyncDate, result.FamilySponsorshipLastSyncDate); + Assert.Equal(organizationDetails.FamilySponsorshipToDelete, result.FamilySponsorshipToDelete); + Assert.Equal(organizationDetails.FamilySponsorshipValidUntil, result.FamilySponsorshipValidUntil); + Assert.True(result.IsAdminInitiated); + Assert.False(result.FamilySponsorshipAvailable); + } +} diff --git a/test/Api.Test/AdminConsole/Models/Response/ProfileProviderOrganizationResponseModelTests.cs b/test/Api.Test/AdminConsole/Models/Response/ProfileProviderOrganizationResponseModelTests.cs new file mode 100644 index 0000000000..1757f9d983 --- /dev/null +++ b/test/Api.Test/AdminConsole/Models/Response/ProfileProviderOrganizationResponseModelTests.cs @@ -0,0 +1,130 @@ +using Bit.Api.AdminConsole.Models.Response; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Models.Data.Provider; +using Bit.Core.Auth.Enums; +using Bit.Core.Auth.Models.Data; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Extensions; +using Bit.Core.Enums; +using Bit.Test.Common.AutoFixture.Attributes; +using Xunit; + +namespace Bit.Api.Test.AdminConsole.Models.Response; + +public class ProfileProviderOrganizationResponseModelTests +{ + [Theory, BitAutoData] + public void Constructor_ShouldPopulatePropertiesCorrectly(Organization organization) + { + var userId = Guid.NewGuid(); + var providerId = Guid.NewGuid(); + var providerUserId = Guid.NewGuid(); + + var organizationDetails = new ProviderUserOrganizationDetails + { + OrganizationId = organization.Id, + UserId = userId, + Name = organization.Name, + Enabled = organization.Enabled, + Identifier = organization.Identifier, + PlanType = organization.PlanType, + UsePolicies = organization.UsePolicies, + UseSso = organization.UseSso, + UseKeyConnector = organization.UseKeyConnector, + UseScim = organization.UseScim, + UseGroups = organization.UseGroups, + UseDirectory = organization.UseDirectory, + UseEvents = organization.UseEvents, + UseTotp = organization.UseTotp, + Use2fa = organization.Use2fa, + UseApi = organization.UseApi, + UseResetPassword = organization.UseResetPassword, + UseSecretsManager = organization.UseSecretsManager, + UsePasswordManager = organization.UsePasswordManager, + UsersGetPremium = organization.UsersGetPremium, + UseCustomPermissions = organization.UseCustomPermissions, + UseRiskInsights = organization.UseRiskInsights, + UsePhishingBlocker = organization.UsePhishingBlocker, + UseOrganizationDomains = organization.UseOrganizationDomains, + UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies, + UseAutomaticUserConfirmation = organization.UseAutomaticUserConfirmation, + SelfHost = organization.SelfHost, + Seats = organization.Seats, + MaxCollections = organization.MaxCollections, + MaxStorageGb = organization.MaxStorageGb, + Key = "provider-org-key", + PublicKey = "public-key", + PrivateKey = "private-key", + LimitCollectionCreation = organization.LimitCollectionCreation, + LimitCollectionDeletion = organization.LimitCollectionDeletion, + LimitItemDeletion = organization.LimitItemDeletion, + AllowAdminAccessToAllCollectionItems = organization.AllowAdminAccessToAllCollectionItems, + ProviderId = providerId, + ProviderName = "Test MSP Provider", + ProviderType = ProviderType.Msp, + SsoEnabled = true, + SsoConfig = new SsoConfigurationData + { + MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption + }.Serialize(), + Status = ProviderUserStatusType.Confirmed, + Type = ProviderUserType.ProviderAdmin, + ProviderUserId = providerUserId + }; + + var result = new ProfileProviderOrganizationResponseModel(organizationDetails); + + Assert.Equal("profileProviderOrganization", result.Object); + Assert.Equal(organization.Id, result.Id); + Assert.Equal(userId, result.UserId); + Assert.Equal(organization.Name, result.Name); + Assert.Equal(organization.Enabled, result.Enabled); + Assert.Equal(organization.Identifier, result.Identifier); + Assert.Equal(organization.PlanType.GetProductTier(), result.ProductTierType); + Assert.Equal(organization.UsePolicies, result.UsePolicies); + Assert.Equal(organization.UseSso, result.UseSso); + Assert.Equal(organization.UseKeyConnector, result.UseKeyConnector); + Assert.Equal(organization.UseScim, result.UseScim); + Assert.Equal(organization.UseGroups, result.UseGroups); + Assert.Equal(organization.UseDirectory, result.UseDirectory); + Assert.Equal(organization.UseEvents, result.UseEvents); + Assert.Equal(organization.UseTotp, result.UseTotp); + Assert.Equal(organization.Use2fa, result.Use2fa); + Assert.Equal(organization.UseApi, result.UseApi); + Assert.Equal(organization.UseResetPassword, result.UseResetPassword); + Assert.Equal(organization.UseSecretsManager, result.UseSecretsManager); + Assert.Equal(organization.UsePasswordManager, result.UsePasswordManager); + Assert.Equal(organization.UsersGetPremium, result.UsersGetPremium); + Assert.Equal(organization.UseCustomPermissions, result.UseCustomPermissions); + Assert.Equal(organization.PlanType.GetProductTier() == ProductTierType.Enterprise, result.UseActivateAutofillPolicy); + Assert.Equal(organization.UseRiskInsights, result.UseRiskInsights); + Assert.Equal(organization.UseOrganizationDomains, result.UseOrganizationDomains); + Assert.Equal(organization.UseAdminSponsoredFamilies, result.UseAdminSponsoredFamilies); + Assert.Equal(organization.UseAutomaticUserConfirmation, result.UseAutomaticUserConfirmation); + Assert.Equal(organization.SelfHost, result.SelfHost); + Assert.Equal(organization.Seats, result.Seats); + Assert.Equal(organization.MaxCollections, result.MaxCollections); + Assert.Equal(organization.MaxStorageGb, result.MaxStorageGb); + Assert.Equal(organizationDetails.Key, result.Key); + Assert.True(result.HasPublicAndPrivateKeys); + Assert.Equal(organization.LimitCollectionCreation, result.LimitCollectionCreation); + Assert.Equal(organization.LimitCollectionDeletion, result.LimitCollectionDeletion); + Assert.Equal(organization.LimitItemDeletion, result.LimitItemDeletion); + Assert.Equal(organization.AllowAdminAccessToAllCollectionItems, result.AllowAdminAccessToAllCollectionItems); + Assert.Equal(organizationDetails.ProviderId, result.ProviderId); + Assert.Equal(organizationDetails.ProviderName, result.ProviderName); + Assert.Equal(organizationDetails.ProviderType, result.ProviderType); + Assert.Equal(OrganizationUserStatusType.Confirmed, result.Status); + Assert.Equal(OrganizationUserType.Owner, result.Type); + Assert.Equal(organizationDetails.SsoEnabled, result.SsoEnabled); + Assert.False(result.KeyConnectorEnabled); + Assert.Null(result.KeyConnectorUrl); + Assert.Equal(MemberDecryptionType.TrustedDeviceEncryption, result.SsoMemberDecryptionType); + Assert.False(result.SsoBound); + Assert.NotNull(result.Permissions); + Assert.False(result.Permissions.ManageUsers); + Assert.False(result.ResetPasswordEnrolled); + Assert.False(result.AccessSecretsManager); + } +} diff --git a/test/Api.Test/AdminConsole/Public/Controllers/PoliciesControllerTests.cs b/test/Api.Test/AdminConsole/Public/Controllers/PoliciesControllerTests.cs new file mode 100644 index 0000000000..bd10eab617 --- /dev/null +++ b/test/Api.Test/AdminConsole/Public/Controllers/PoliciesControllerTests.cs @@ -0,0 +1,49 @@ +using Bit.Api.AdminConsole.Public.Controllers; +using Bit.Api.AdminConsole.Public.Models.Request; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; +using Bit.Core.Context; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Api.Test.AdminConsole.Public.Controllers; + +[ControllerCustomize(typeof(PoliciesController))] +[SutProviderCustomize] +public class PoliciesControllerTests +{ + [Theory] + [BitAutoData] + public async Task Put_UsesVNextSavePolicyCommand( + Guid organizationId, + PolicyType policyType, + PolicyUpdateRequestModel model, + Policy policy, + SutProvider sutProvider) + { + // Arrange + policy.Data = null; + sutProvider.GetDependency() + .OrganizationId.Returns(organizationId); + sutProvider.GetDependency() + .SaveAsync(Arg.Any()) + .Returns(policy); + + // Act + await sutProvider.Sut.Put(policyType, model); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .SaveAsync(Arg.Is(m => + m.PolicyUpdate.OrganizationId == organizationId && + m.PolicyUpdate.Type == policyType && + m.PolicyUpdate.Enabled == model.Enabled.GetValueOrDefault() && + m.PerformedBy is SystemUser)); + } +} diff --git a/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs b/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs index e81d51281d..300a4d823d 100644 --- a/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs +++ b/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs @@ -11,6 +11,9 @@ using Bit.Core.Auth.UserFeatures.UserMasterPassword.Interfaces; using Bit.Core.Entities; using Bit.Core.Exceptions; using Bit.Core.KeyManagement.Kdf; +using Bit.Core.KeyManagement.Models.Api.Request; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Queries.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Test.Common.AutoFixture.Attributes; @@ -33,9 +36,10 @@ public class AccountsControllerTests : IDisposable private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; private readonly ITdeOffboardingPasswordCommand _tdeOffboardingPasswordCommand; private readonly IFeatureService _featureService; + private readonly IUserAccountKeysQuery _userAccountKeysQuery; private readonly ITwoFactorEmailService _twoFactorEmailService; private readonly IChangeKdfCommand _changeKdfCommand; - + private readonly IUserRepository _userRepository; public AccountsControllerTests() { @@ -48,8 +52,10 @@ public class AccountsControllerTests : IDisposable _twoFactorIsEnabledQuery = Substitute.For(); _tdeOffboardingPasswordCommand = Substitute.For(); _featureService = Substitute.For(); + _userAccountKeysQuery = Substitute.For(); _twoFactorEmailService = Substitute.For(); _changeKdfCommand = Substitute.For(); + _userRepository = Substitute.For(); _sut = new AccountsController( _organizationService, @@ -61,8 +67,10 @@ public class AccountsControllerTests : IDisposable _tdeOffboardingPasswordCommand, _twoFactorIsEnabledQuery, _featureService, + _userAccountKeysQuery, _twoFactorEmailService, - _changeKdfCommand + _changeKdfCommand, + _userRepository ); } @@ -614,6 +622,16 @@ public class AccountsControllerTests : IDisposable await _twoFactorEmailService.Received(1).SendNewDeviceVerificationEmailAsync(user); } + [Theory] + [BitAutoData] + public async Task PostKdf_UserNotFound_ShouldFail(PasswordRequestModel model) + { + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(Task.FromResult(null)); + + // Act + await Assert.ThrowsAsync(() => _sut.PostKdf(model)); + } + [Theory] [BitAutoData] public async Task PostKdf_WithNullAuthenticationData_ShouldFail( @@ -623,7 +641,9 @@ public class AccountsControllerTests : IDisposable model.AuthenticationData = null; // Act - await Assert.ThrowsAsync(() => _sut.PostKdf(model)); + var exception = await Assert.ThrowsAsync(() => _sut.PostKdf(model)); + + Assert.Contains("AuthenticationData and UnlockData must be provided.", exception.Message); } [Theory] @@ -635,7 +655,72 @@ public class AccountsControllerTests : IDisposable model.UnlockData = null; // Act - await Assert.ThrowsAsync(() => _sut.PostKdf(model)); + var exception = await Assert.ThrowsAsync(() => _sut.PostKdf(model)); + + Assert.Contains("AuthenticationData and UnlockData must be provided.", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task PostKdf_ChangeKdfFailed_ShouldFail( + User user, PasswordRequestModel model) + { + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(Task.FromResult(user)); + _changeKdfCommand.ChangeKdfAsync(Arg.Any(), Arg.Any(), + Arg.Any(), Arg.Any()) + .Returns(Task.FromResult(IdentityResult.Failed(new IdentityError { Description = "Change KDF failed" }))); + + // Act + var exception = await Assert.ThrowsAsync(() => _sut.PostKdf(model)); + + Assert.NotNull(exception.ModelState); + Assert.Contains("Change KDF failed", + exception.ModelState.Values.SelectMany(x => x.Errors).Select(x => x.ErrorMessage)); + } + + [Theory] + [BitAutoData] + public async Task PostKdf_ChangeKdfSuccess_NoError( + User user, PasswordRequestModel model) + { + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(Task.FromResult(user)); + _changeKdfCommand.ChangeKdfAsync(Arg.Any(), Arg.Any(), + Arg.Any(), Arg.Any()) + .Returns(Task.FromResult(IdentityResult.Success)); + + // Act + await _sut.PostKdf(model); + } + + [Theory] + [BitAutoData] + public async Task PostKeys_NoUser_Errors(KeysRequestModel model) + { + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(Task.FromResult(null)); + + await Assert.ThrowsAsync(() => _sut.PostKeys(model)); + } + + [Theory] + [BitAutoData("existing", "existing")] + [BitAutoData((string)null, "existing")] + [BitAutoData("", "existing")] + [BitAutoData(" ", "existing")] + [BitAutoData("existing", null)] + [BitAutoData("existing", "")] + [BitAutoData("existing", " ")] + public async Task PostKeys_UserAlreadyHasKeys_Errors(string? existingPrivateKey, string? existingPublicKey, + KeysRequestModel model) + { + var user = GenerateExampleUser(); + user.PrivateKey = existingPrivateKey; + user.PublicKey = existingPublicKey; + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(Task.FromResult(user)); + + var exception = await Assert.ThrowsAsync(() => _sut.PostKeys(model)); + + Assert.NotNull(exception.Message); + Assert.Contains("User has existing keypair", exception.Message); } // Below are helper functions that currently belong to this @@ -688,5 +773,77 @@ public class AccountsControllerTests : IDisposable _userService.GetUserByIdAsync(Arg.Any()) .Returns(Task.FromResult((User)null)); } + + [Theory, BitAutoData] + public async Task PostKeys_WithAccountKeys_CallsSetV2AccountCryptographicState( + User user, + KeysRequestModel model) + { + // Arrange + user.PublicKey = null; + user.PrivateKey = null; + model.AccountKeys = new AccountKeysRequestModel + { + UserKeyEncryptedAccountPrivateKey = "wrapped-private-key", + AccountPublicKey = "public-key", + PublicKeyEncryptionKeyPair = new PublicKeyEncryptionKeyPairRequestModel + { + PublicKey = "public-key", + WrappedPrivateKey = "wrapped-private-key", + SignedPublicKey = "signed-public-key" + }, + SignatureKeyPair = new SignatureKeyPairRequestModel + { + VerifyingKey = "verifying-key", + SignatureAlgorithm = "ed25519", + WrappedSigningKey = "wrapped-signing-key" + }, + SecurityState = new SecurityStateModel + { + SecurityState = "security-state", + SecurityVersion = 2 + } + }; + + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); + + // Act + var result = await _sut.PostKeys(model); + + // Assert + await _userRepository.Received(1).SetV2AccountCryptographicStateAsync( + user.Id, + Arg.Any()); + await _userService.DidNotReceiveWithAnyArgs().SaveUserAsync(Arg.Any()); + Assert.NotNull(result); + Assert.Equal("keys", result.Object); + } + + [Theory, BitAutoData] + public async Task PostKeys_WithoutAccountKeys_CallsSaveUser( + User user, + KeysRequestModel model) + { + // Arrange + user.PublicKey = null; + user.PrivateKey = null; + model.AccountKeys = null; + model.PublicKey = "public-key"; + model.EncryptedPrivateKey = "encrypted-private-key"; + + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); + + // Act + var result = await _sut.PostKeys(model); + + // Assert + await _userService.Received(1).SaveUserAsync(Arg.Is(u => + u.PublicKey == model.PublicKey && + u.PrivateKey == model.EncryptedPrivateKey)); + await _userRepository.DidNotReceiveWithAnyArgs() + .SetV2AccountCryptographicStateAsync(Arg.Any(), Arg.Any()); + Assert.NotNull(result); + Assert.Equal("keys", result.Object); + } } diff --git a/test/Api.Test/Billing/Controllers/AccountsControllerTests.cs b/test/Api.Test/Billing/Controllers/AccountsControllerTests.cs new file mode 100644 index 0000000000..16b9b26436 --- /dev/null +++ b/test/Api.Test/Billing/Controllers/AccountsControllerTests.cs @@ -0,0 +1,804 @@ +using System.Security.Claims; +using Bit.Api.Billing.Controllers; +using Bit.Core; +using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Models.Business; +using Bit.Core.Billing.Services; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.KeyManagement.Queries.Interfaces; +using Bit.Core.Models.Business; +using Bit.Core.Services; +using Bit.Core.Settings; +using Bit.Core.Test.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using NSubstitute; +using Stripe; +using Xunit; + +namespace Bit.Api.Test.Billing.Controllers; + +[SubscriptionInfoCustomize] +public class AccountsControllerTests : IDisposable +{ + private const string TestMilestone2CouponId = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount; + + private readonly IUserService _userService; + private readonly IFeatureService _featureService; + private readonly IStripePaymentService _paymentService; + private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; + private readonly IUserAccountKeysQuery _userAccountKeysQuery; + private readonly ILicensingService _licensingService; + private readonly GlobalSettings _globalSettings; + private readonly AccountsController _sut; + + public AccountsControllerTests() + { + _userService = Substitute.For(); + _featureService = Substitute.For(); + _paymentService = Substitute.For(); + _twoFactorIsEnabledQuery = Substitute.For(); + _userAccountKeysQuery = Substitute.For(); + _licensingService = Substitute.For(); + _globalSettings = new GlobalSettings { SelfHosted = false }; + + _sut = new AccountsController( + _userService, + _twoFactorIsEnabledQuery, + _userAccountKeysQuery, + _featureService, + _licensingService + ); + } + + public void Dispose() + { + _sut?.Dispose(); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_WhenFeatureFlagEnabled_IncludesDiscount( + User user, + SubscriptionInfo subscriptionInfo, + UserLicense license) + { + // Arrange + subscriptionInfo.CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount + { + Id = TestMilestone2CouponId, + Active = true, + PercentOff = 20m, + AmountOff = null, + AppliesTo = new List { "product1" } + }; + + var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity()); + _sut.ControllerContext = new ControllerContext + { + HttpContext = new DefaultHttpContext { User = claimsPrincipal } + }; + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); + _featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true); + _paymentService.GetSubscriptionAsync(user).Returns(subscriptionInfo); + _userService.GenerateLicenseAsync(user, subscriptionInfo).Returns(license); + + user.Gateway = GatewayType.Stripe; // User has payment gateway + + // Act + var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService); + + // Assert + Assert.NotNull(result); + Assert.NotNull(result.CustomerDiscount); + Assert.Equal(StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, result.CustomerDiscount.Id); + Assert.Equal(20m, result.CustomerDiscount.PercentOff); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_WhenFeatureFlagDisabled_ExcludesDiscount( + User user, + SubscriptionInfo subscriptionInfo, + UserLicense license) + { + // Arrange + subscriptionInfo.CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount + { + Id = TestMilestone2CouponId, + Active = true, + PercentOff = 20m, + AmountOff = null, + AppliesTo = new List { "product1" } + }; + + var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity()); + _sut.ControllerContext = new ControllerContext + { + HttpContext = new DefaultHttpContext { User = claimsPrincipal } + }; + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); + _featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(false); + _paymentService.GetSubscriptionAsync(user).Returns(subscriptionInfo); + _userService.GenerateLicenseAsync(user, subscriptionInfo).Returns(license); + + user.Gateway = GatewayType.Stripe; // User has payment gateway + + // Act + var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService); + + // Assert + Assert.NotNull(result); + Assert.Null(result.CustomerDiscount); // Should be null when feature flag is disabled + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_WithNonMatchingCouponId_ExcludesDiscount( + User user, + SubscriptionInfo subscriptionInfo, + UserLicense license) + { + // Arrange + subscriptionInfo.CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount + { + Id = "different-coupon-id", // Non-matching coupon ID + Active = true, + PercentOff = 20m, + AmountOff = null, + AppliesTo = new List { "product1" } + }; + + var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity()); + _sut.ControllerContext = new ControllerContext + { + HttpContext = new DefaultHttpContext { User = claimsPrincipal } + }; + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); + _featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true); + _paymentService.GetSubscriptionAsync(user).Returns(subscriptionInfo); + _userService.GenerateLicenseAsync(user, subscriptionInfo).Returns(license); + + user.Gateway = GatewayType.Stripe; // User has payment gateway + + // Act + var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService); + + // Assert + Assert.NotNull(result); + Assert.Null(result.CustomerDiscount); // Should be null when coupon ID doesn't match + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_WhenSelfHosted_ReturnsBasicResponse(User user) + { + // Arrange + var selfHostedSettings = new GlobalSettings { SelfHosted = true }; + var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity()); + _sut.ControllerContext = new ControllerContext + { + HttpContext = new DefaultHttpContext { User = claimsPrincipal } + }; + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); + + // Act + var result = await _sut.GetSubscriptionAsync(selfHostedSettings, _paymentService); + + // Assert + Assert.NotNull(result); + Assert.Null(result.CustomerDiscount); + await _paymentService.DidNotReceive().GetSubscriptionAsync(Arg.Any()); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_WhenNoGateway_ExcludesDiscount(User user, UserLicense license) + { + // Arrange + user.Gateway = null; // No gateway configured + var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity()); + _sut.ControllerContext = new ControllerContext + { + HttpContext = new DefaultHttpContext { User = claimsPrincipal } + }; + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); + _userService.GenerateLicenseAsync(user).Returns(license); + + // Act + var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService); + + // Assert + Assert.NotNull(result); + Assert.Null(result.CustomerDiscount); // Should be null when no gateway + await _paymentService.DidNotReceive().GetSubscriptionAsync(Arg.Any()); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_WithInactiveDiscount_ExcludesDiscount( + User user, + SubscriptionInfo subscriptionInfo, + UserLicense license) + { + // Arrange + subscriptionInfo.CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount + { + Id = TestMilestone2CouponId, + Active = false, // Inactive discount + PercentOff = 20m, + AmountOff = null, + AppliesTo = new List { "product1" } + }; + + var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity()); + _sut.ControllerContext = new ControllerContext + { + HttpContext = new DefaultHttpContext { User = claimsPrincipal } + }; + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); + _featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true); + _paymentService.GetSubscriptionAsync(user).Returns(subscriptionInfo); + _userService.GenerateLicenseAsync(user, subscriptionInfo).Returns(license); + + user.Gateway = GatewayType.Stripe; // User has payment gateway + + // Act + var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService); + + // Assert + Assert.NotNull(result); + Assert.Null(result.CustomerDiscount); // Should be null when discount is inactive + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_FullPipeline_ConvertsStripeDiscountToApiResponse( + User user, + UserLicense license) + { + // Arrange - Create a Stripe Discount object with real structure + var stripeDiscount = new Discount + { + Coupon = new Coupon + { + Id = TestMilestone2CouponId, + PercentOff = 25m, + AmountOff = 1400, // 1400 cents = $14.00 + AppliesTo = new CouponAppliesTo + { + Products = new List { "prod_premium", "prod_families" } + } + }, + End = null // Active discount + }; + + // Convert Stripe Discount to BillingCustomerDiscount (simulating what StripePaymentService does) + var billingDiscount = new SubscriptionInfo.BillingCustomerDiscount(stripeDiscount); + + var subscriptionInfo = new SubscriptionInfo + { + CustomerDiscount = billingDiscount + }; + + var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity()); + _sut.ControllerContext = new ControllerContext + { + HttpContext = new DefaultHttpContext { User = claimsPrincipal } + }; + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); + _featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true); + _paymentService.GetSubscriptionAsync(user).Returns(subscriptionInfo); + _userService.GenerateLicenseAsync(user, subscriptionInfo).Returns(license); + + user.Gateway = GatewayType.Stripe; + + // Act + var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService); + + // Assert - Verify full pipeline conversion + Assert.NotNull(result); + Assert.NotNull(result.CustomerDiscount); + + // Verify Stripe data correctly converted to API response + Assert.Equal(StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, result.CustomerDiscount.Id); + Assert.True(result.CustomerDiscount.Active); + Assert.Equal(25m, result.CustomerDiscount.PercentOff); + + // Verify cents-to-dollars conversion (1400 cents -> $14.00) + Assert.Equal(14.00m, result.CustomerDiscount.AmountOff); + + // Verify AppliesTo products are preserved + Assert.NotNull(result.CustomerDiscount.AppliesTo); + Assert.Equal(2, result.CustomerDiscount.AppliesTo.Count()); + Assert.Contains("prod_premium", result.CustomerDiscount.AppliesTo); + Assert.Contains("prod_families", result.CustomerDiscount.AppliesTo); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_FullPipeline_WithFeatureFlagToggle_ControlsVisibility( + User user, + UserLicense license) + { + // Arrange - Create Stripe Discount + var stripeDiscount = new Discount + { + Coupon = new Coupon + { + Id = TestMilestone2CouponId, + PercentOff = 20m + }, + End = null + }; + + var billingDiscount = new SubscriptionInfo.BillingCustomerDiscount(stripeDiscount); + var subscriptionInfo = new SubscriptionInfo + { + CustomerDiscount = billingDiscount + }; + + var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity()); + _sut.ControllerContext = new ControllerContext + { + HttpContext = new DefaultHttpContext { User = claimsPrincipal } + }; + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); + _paymentService.GetSubscriptionAsync(user).Returns(subscriptionInfo); + _userService.GenerateLicenseAsync(user, subscriptionInfo).Returns(license); + user.Gateway = GatewayType.Stripe; + + // Act & Assert - Feature flag ENABLED + _featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true); + var resultWithFlag = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService); + Assert.NotNull(resultWithFlag.CustomerDiscount); + + // Act & Assert - Feature flag DISABLED + _featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(false); + var resultWithoutFlag = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService); + Assert.Null(resultWithoutFlag.CustomerDiscount); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_IntegrationTest_CompletePipelineFromStripeToApiResponse( + User user, + UserLicense license) + { + // Arrange - Create a real Stripe Discount object as it would come from Stripe API + var stripeDiscount = new Discount + { + Coupon = new Coupon + { + Id = TestMilestone2CouponId, + PercentOff = 30m, + AmountOff = 2000, // 2000 cents = $20.00 + AppliesTo = new CouponAppliesTo + { + Products = new List { "prod_premium", "prod_families", "prod_teams" } + } + }, + End = null // Active discount (no end date) + }; + + // Step 1: Map Stripe Discount through SubscriptionInfo.BillingCustomerDiscount + // This simulates what StripePaymentService.GetSubscriptionAsync does + var billingCustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount(stripeDiscount); + + // Verify the mapping worked correctly + Assert.Equal(TestMilestone2CouponId, billingCustomerDiscount.Id); + Assert.True(billingCustomerDiscount.Active); + Assert.Equal(30m, billingCustomerDiscount.PercentOff); + Assert.Equal(20.00m, billingCustomerDiscount.AmountOff); // Converted from cents + Assert.NotNull(billingCustomerDiscount.AppliesTo); + Assert.Equal(3, billingCustomerDiscount.AppliesTo.Count); + + // Step 2: Create SubscriptionInfo with the mapped discount + // This simulates what StripePaymentService returns + var subscriptionInfo = new SubscriptionInfo + { + CustomerDiscount = billingCustomerDiscount + }; + + // Step 3: Set up controller dependencies + var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity()); + _sut.ControllerContext = new ControllerContext + { + HttpContext = new DefaultHttpContext { User = claimsPrincipal } + }; + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); + _featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true); + _paymentService.GetSubscriptionAsync(user).Returns(subscriptionInfo); + _userService.GenerateLicenseAsync(user, subscriptionInfo).Returns(license); + user.Gateway = GatewayType.Stripe; + + // Act - Step 4: Call AccountsController.GetSubscriptionAsync + // This exercises the complete pipeline: + // - Retrieves subscriptionInfo from paymentService (with discount from Stripe) + // - Maps through SubscriptionInfo.BillingCustomerDiscount (already done above) + // - Filters in SubscriptionResponseModel constructor (based on feature flag, coupon ID, active status) + // - Returns via AccountsController + var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService); + + // Assert - Verify the complete pipeline worked end-to-end + Assert.NotNull(result); + Assert.NotNull(result.CustomerDiscount); + + // Verify Stripe Discount → SubscriptionInfo.BillingCustomerDiscount mapping + // (verified above, but confirming it made it through) + + // Verify SubscriptionInfo.BillingCustomerDiscount → SubscriptionResponseModel.BillingCustomerDiscount filtering + // The filter should pass because: + // - includeMilestone2Discount = true (feature flag enabled) + // - subscription.CustomerDiscount != null + // - subscription.CustomerDiscount.Id == Milestone2SubscriptionDiscount + // - subscription.CustomerDiscount.Active = true + Assert.Equal(TestMilestone2CouponId, result.CustomerDiscount.Id); + Assert.True(result.CustomerDiscount.Active); + Assert.Equal(30m, result.CustomerDiscount.PercentOff); + Assert.Equal(20.00m, result.CustomerDiscount.AmountOff); // Verify cents-to-dollars conversion + + // Verify AppliesTo products are preserved through the entire pipeline + Assert.NotNull(result.CustomerDiscount.AppliesTo); + Assert.Equal(3, result.CustomerDiscount.AppliesTo.Count()); + Assert.Contains("prod_premium", result.CustomerDiscount.AppliesTo); + Assert.Contains("prod_families", result.CustomerDiscount.AppliesTo); + Assert.Contains("prod_teams", result.CustomerDiscount.AppliesTo); + + // Verify the payment service was called correctly + await _paymentService.Received(1).GetSubscriptionAsync(user); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_IntegrationTest_MultipleDiscountsInSubscription_PrefersCustomerDiscount( + User user, + UserLicense license) + { + // Arrange - Create Stripe subscription with multiple discounts + // Customer discount should be preferred over subscription discounts + var customerDiscount = new Discount + { + Coupon = new Coupon + { + Id = TestMilestone2CouponId, + PercentOff = 30m, + AmountOff = null + }, + End = null + }; + + var subscriptionDiscount1 = new Discount + { + Coupon = new Coupon + { + Id = "other-coupon-1", + PercentOff = 10m + }, + End = null + }; + + var subscriptionDiscount2 = new Discount + { + Coupon = new Coupon + { + Id = "other-coupon-2", + PercentOff = 15m + }, + End = null + }; + + // Map through SubscriptionInfo.BillingCustomerDiscount + var billingCustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount(customerDiscount); + var subscriptionInfo = new SubscriptionInfo + { + CustomerDiscount = billingCustomerDiscount + }; + + var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity()); + _sut.ControllerContext = new ControllerContext + { + HttpContext = new DefaultHttpContext { User = claimsPrincipal } + }; + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); + _featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true); + _paymentService.GetSubscriptionAsync(user).Returns(subscriptionInfo); + _userService.GenerateLicenseAsync(user, subscriptionInfo).Returns(license); + user.Gateway = GatewayType.Stripe; + + // Act + var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService); + + // Assert - Should use customer discount, not subscription discounts + Assert.NotNull(result); + Assert.NotNull(result.CustomerDiscount); + Assert.Equal(TestMilestone2CouponId, result.CustomerDiscount.Id); + Assert.Equal(30m, result.CustomerDiscount.PercentOff); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_IntegrationTest_BothPercentOffAndAmountOffPresent_HandlesEdgeCase( + User user, + UserLicense license) + { + // Arrange - Edge case: Stripe coupon with both PercentOff and AmountOff + // This tests the scenario mentioned in BillingCustomerDiscountTests.cs line 212-232 + var stripeDiscount = new Discount + { + Coupon = new Coupon + { + Id = TestMilestone2CouponId, + PercentOff = 25m, + AmountOff = 2000, // 2000 cents = $20.00 + AppliesTo = new CouponAppliesTo + { + Products = new List { "prod_premium" } + } + }, + End = null + }; + + // Map through SubscriptionInfo.BillingCustomerDiscount + var billingCustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount(stripeDiscount); + var subscriptionInfo = new SubscriptionInfo + { + CustomerDiscount = billingCustomerDiscount + }; + + var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity()); + _sut.ControllerContext = new ControllerContext + { + HttpContext = new DefaultHttpContext { User = claimsPrincipal } + }; + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); + _featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true); + _paymentService.GetSubscriptionAsync(user).Returns(subscriptionInfo); + _userService.GenerateLicenseAsync(user, subscriptionInfo).Returns(license); + user.Gateway = GatewayType.Stripe; + + // Act + var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService); + + // Assert - Both values should be preserved through the pipeline + Assert.NotNull(result); + Assert.NotNull(result.CustomerDiscount); + Assert.Equal(TestMilestone2CouponId, result.CustomerDiscount.Id); + Assert.Equal(25m, result.CustomerDiscount.PercentOff); + Assert.Equal(20.00m, result.CustomerDiscount.AmountOff); // Converted from cents + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_IntegrationTest_BillingSubscriptionMapsThroughPipeline( + User user, + UserLicense license) + { + // Arrange - Create Stripe subscription with subscription details + var stripeSubscription = new Subscription + { + Id = "sub_test123", + Status = "active", + TrialStart = DateTime.UtcNow.AddDays(-30), + TrialEnd = DateTime.UtcNow.AddDays(-20), + CanceledAt = null, + CancelAtPeriodEnd = false, + CollectionMethod = "charge_automatically" + }; + + // Map through SubscriptionInfo.BillingSubscription + var billingSubscription = new SubscriptionInfo.BillingSubscription(stripeSubscription); + var subscriptionInfo = new SubscriptionInfo + { + Subscription = billingSubscription, + CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount + { + Id = TestMilestone2CouponId, + Active = true, + PercentOff = 20m + } + }; + + var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity()); + _sut.ControllerContext = new ControllerContext + { + HttpContext = new DefaultHttpContext { User = claimsPrincipal } + }; + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); + _featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true); + _paymentService.GetSubscriptionAsync(user).Returns(subscriptionInfo); + _userService.GenerateLicenseAsync(user, subscriptionInfo).Returns(license); + user.Gateway = GatewayType.Stripe; + + // Act + var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService); + + // Assert - Verify BillingSubscription mapped through pipeline + Assert.NotNull(result); + Assert.NotNull(result.Subscription); + Assert.Equal("active", result.Subscription.Status); + Assert.Equal(14, result.Subscription.GracePeriod); // charge_automatically = 14 days + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_IntegrationTest_BillingUpcomingInvoiceMapsThroughPipeline( + User user, + UserLicense license) + { + // Arrange - Create Stripe invoice for upcoming invoice + var stripeInvoice = new Invoice + { + AmountDue = 2000, // 2000 cents = $20.00 + Created = DateTime.UtcNow.AddDays(1) + }; + + // Map through SubscriptionInfo.BillingUpcomingInvoice + var billingUpcomingInvoice = new SubscriptionInfo.BillingUpcomingInvoice(stripeInvoice); + var subscriptionInfo = new SubscriptionInfo + { + UpcomingInvoice = billingUpcomingInvoice, + CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount + { + Id = TestMilestone2CouponId, + Active = true, + PercentOff = 20m + } + }; + + var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity()); + _sut.ControllerContext = new ControllerContext + { + HttpContext = new DefaultHttpContext { User = claimsPrincipal } + }; + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); + _featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true); + _paymentService.GetSubscriptionAsync(user).Returns(subscriptionInfo); + _userService.GenerateLicenseAsync(user, subscriptionInfo).Returns(license); + user.Gateway = GatewayType.Stripe; + + // Act + var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService); + + // Assert - Verify BillingUpcomingInvoice mapped through pipeline + Assert.NotNull(result); + Assert.NotNull(result.UpcomingInvoice); + Assert.Equal(20.00m, result.UpcomingInvoice.Amount); // Converted from cents + Assert.NotNull(result.UpcomingInvoice.Date); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_IntegrationTest_CompletePipelineWithAllComponents( + User user, + UserLicense license) + { + // Arrange - Complete Stripe objects for full pipeline test + var stripeDiscount = new Discount + { + Coupon = new Coupon + { + Id = TestMilestone2CouponId, + PercentOff = 20m, + AmountOff = 1000, // $10.00 + AppliesTo = new CouponAppliesTo + { + Products = new List { "prod_premium", "prod_families" } + } + }, + End = null + }; + + var stripeSubscription = new Subscription + { + Id = "sub_test123", + Status = "active", + CollectionMethod = "charge_automatically" + }; + + var stripeInvoice = new Invoice + { + AmountDue = 1500, // $15.00 + Created = DateTime.UtcNow.AddDays(7) + }; + + // Map through SubscriptionInfo (simulating StripePaymentService) + var billingCustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount(stripeDiscount); + var billingSubscription = new SubscriptionInfo.BillingSubscription(stripeSubscription); + var billingUpcomingInvoice = new SubscriptionInfo.BillingUpcomingInvoice(stripeInvoice); + + var subscriptionInfo = new SubscriptionInfo + { + CustomerDiscount = billingCustomerDiscount, + Subscription = billingSubscription, + UpcomingInvoice = billingUpcomingInvoice + }; + + var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity()); + _sut.ControllerContext = new ControllerContext + { + HttpContext = new DefaultHttpContext { User = claimsPrincipal } + }; + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); + _featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true); + _paymentService.GetSubscriptionAsync(user).Returns(subscriptionInfo); + _userService.GenerateLicenseAsync(user, subscriptionInfo).Returns(license); + user.Gateway = GatewayType.Stripe; + + // Act - Full pipeline: Stripe → SubscriptionInfo → SubscriptionResponseModel → API response + var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService); + + // Assert - Verify all components mapped correctly through the pipeline + Assert.NotNull(result); + + // Verify discount + Assert.NotNull(result.CustomerDiscount); + Assert.Equal(TestMilestone2CouponId, result.CustomerDiscount.Id); + Assert.Equal(20m, result.CustomerDiscount.PercentOff); + Assert.Equal(10.00m, result.CustomerDiscount.AmountOff); + Assert.NotNull(result.CustomerDiscount.AppliesTo); + Assert.Equal(2, result.CustomerDiscount.AppliesTo.Count()); + + // Verify subscription + Assert.NotNull(result.Subscription); + Assert.Equal("active", result.Subscription.Status); + Assert.Equal(14, result.Subscription.GracePeriod); + + // Verify upcoming invoice + Assert.NotNull(result.UpcomingInvoice); + Assert.Equal(15.00m, result.UpcomingInvoice.Amount); + Assert.NotNull(result.UpcomingInvoice.Date); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_SelfHosted_WithDiscountFlagEnabled_NeverIncludesDiscount(User user) + { + // Arrange - Self-hosted user with discount flag enabled (should still return null) + var selfHostedSettings = new GlobalSettings { SelfHosted = true }; + var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity()); + _sut.ControllerContext = new ControllerContext + { + HttpContext = new DefaultHttpContext { User = claimsPrincipal } + }; + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); + _featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true); // Flag enabled + + // Act + var result = await _sut.GetSubscriptionAsync(selfHostedSettings, _paymentService); + + // Assert - Should never include discount for self-hosted, even with flag enabled + Assert.NotNull(result); + Assert.Null(result.CustomerDiscount); + await _paymentService.DidNotReceive().GetSubscriptionAsync(Arg.Any()); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_NullGateway_WithDiscountFlagEnabled_NeverIncludesDiscount( + User user, + UserLicense license) + { + // Arrange - User with null gateway and discount flag enabled (should still return null) + user.Gateway = null; // No gateway configured + var claimsPrincipal = new ClaimsPrincipal(new ClaimsIdentity()); + _sut.ControllerContext = new ControllerContext + { + HttpContext = new DefaultHttpContext { User = claimsPrincipal } + }; + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(user); + _userService.GenerateLicenseAsync(user).Returns(license); + _featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true); // Flag enabled + + // Act + var result = await _sut.GetSubscriptionAsync(_globalSettings, _paymentService); + + // Assert - Should never include discount when no gateway, even with flag enabled + Assert.NotNull(result); + Assert.Null(result.CustomerDiscount); + await _paymentService.DidNotReceive().GetSubscriptionAsync(Arg.Any()); + } +} diff --git a/test/Api.Test/Billing/Controllers/OrganizationBillingControllerTests.cs b/test/Api.Test/Billing/Controllers/OrganizationBillingControllerTests.cs index 51866320ee..ee0bdc61e4 100644 --- a/test/Api.Test/Billing/Controllers/OrganizationBillingControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/OrganizationBillingControllerTests.cs @@ -1,12 +1,11 @@ using Bit.Api.Billing.Controllers; -using Bit.Api.Billing.Models.Responses; using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Models; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Organizations.Services; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.AspNetCore.Http.HttpResults; @@ -53,19 +52,16 @@ public class OrganizationBillingControllerTests { sutProvider.GetDependency().OrganizationUser(organizationId).Returns(true); sutProvider.GetDependency().GetMetadata(organizationId) - .Returns(new OrganizationMetadata(true, true, true, true, true, true, true, null, null, null, 0)); + .Returns(new OrganizationMetadata(true, 10)); var result = await sutProvider.Sut.GetMetadataAsync(organizationId); - Assert.IsType>(result); + Assert.IsType>(result); - var response = ((Ok)result).Value; + var response = ((Ok)result).Value; - Assert.True(response.IsEligibleForSelfHost); - Assert.True(response.IsManaged); Assert.True(response.IsOnSecretsManagerStandalone); - Assert.True(response.IsSubscriptionUnpaid); - Assert.True(response.HasSubscription); + Assert.Equal(10, response.OrganizationOccupiedSeats); } [Theory, BitAutoData] @@ -107,7 +103,7 @@ public class OrganizationBillingControllerTests // Manually create a BillingHistoryInfo object to avoid requiring AutoFixture to create HttpResponseHeaders var billingInfo = new BillingHistoryInfo(); - sutProvider.GetDependency().GetBillingHistoryAsync(organization).Returns(billingInfo); + sutProvider.GetDependency().GetBillingHistoryAsync(organization).Returns(billingInfo); // Act var result = await sutProvider.Sut.GetHistoryAsync(organizationId); diff --git a/test/Api.Test/Billing/Controllers/OrganizationSponsorshipsControllerTests.cs b/test/Api.Test/Billing/Controllers/OrganizationSponsorshipsControllerTests.cs index 2ad7686c30..87334dc085 100644 --- a/test/Api.Test/Billing/Controllers/OrganizationSponsorshipsControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/OrganizationSponsorshipsControllerTests.cs @@ -10,7 +10,7 @@ using Bit.Core.Models.Data; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -24,11 +24,11 @@ namespace Bit.Api.Test.Billing.Controllers; public class OrganizationSponsorshipsControllerTests { public static IEnumerable EnterprisePlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).ProductTier == ProductTierType.Enterprise).Select(p => new object[] { p }); + Enum.GetValues().Where(p => MockPlans.Get(p).ProductTier == ProductTierType.Enterprise).Select(p => new object[] { p }); public static IEnumerable NonEnterprisePlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).ProductTier != ProductTierType.Enterprise).Select(p => new object[] { p }); + Enum.GetValues().Where(p => MockPlans.Get(p).ProductTier != ProductTierType.Enterprise).Select(p => new object[] { p }); public static IEnumerable NonFamiliesPlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).ProductTier != ProductTierType.Families).Select(p => new object[] { p }); + Enum.GetValues().Where(p => MockPlans.Get(p).ProductTier != ProductTierType.Families).Select(p => new object[] { p }); public static IEnumerable NonConfirmedOrganizationUsersStatuses => Enum.GetValues() diff --git a/test/Api.Test/Billing/Controllers/OrganizationsControllerTests.cs b/test/Api.Test/Billing/Controllers/OrganizationsControllerTests.cs index a776bbea22..9a3f57c3dc 100644 --- a/test/Api.Test/Billing/Controllers/OrganizationsControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/OrganizationsControllerTests.cs @@ -37,7 +37,7 @@ public class OrganizationsControllerTests : IDisposable private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationService _organizationService; private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IPaymentService _paymentService; + private readonly IStripePaymentService _paymentService; private readonly ISsoConfigRepository _ssoConfigRepository; private readonly IUserService _userService; private readonly IGetCloudOrganizationLicenseQuery _getCloudOrganizationLicenseQuery; @@ -59,7 +59,7 @@ public class OrganizationsControllerTests : IDisposable _organizationRepository = Substitute.For(); _organizationService = Substitute.For(); _organizationUserRepository = Substitute.For(); - _paymentService = Substitute.For(); + _paymentService = Substitute.For(); Substitute.For(); _ssoConfigRepository = Substitute.For(); Substitute.For(); diff --git a/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs b/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs index 8c1dd60fb9..652e82c801 100644 --- a/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/ProviderBillingControllerTests.cs @@ -1,22 +1,20 @@ using Bit.Api.Billing.Controllers; -using Bit.Api.Billing.Models.Requests; using Bit.Api.Billing.Models.Responses; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Providers.Entities; using Bit.Core.Billing.Providers.Repositories; using Bit.Core.Billing.Providers.Services; using Bit.Core.Billing.Services; -using Bit.Core.Billing.Tax.Models; using Bit.Core.Context; using Bit.Core.Models.Api; using Bit.Core.Models.BitStripe; -using Bit.Core.Services; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.AspNetCore.Http; @@ -123,7 +121,7 @@ public class ProviderBillingControllerTests } }; - sutProvider.GetDependency().InvoiceListAsync(Arg.Is( + sutProvider.GetDependency().ListInvoicesAsync(Arg.Is( options => options.Customer == provider.GatewayCustomerId)).Returns(invoices); @@ -270,7 +268,6 @@ public class ProviderBillingControllerTests var subscription = new Subscription { CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically, - CurrentPeriodEnd = new DateTime(now.Year, now.Month, daysInThisMonth), Customer = new Customer { Address = new Address @@ -291,20 +288,23 @@ public class ProviderBillingControllerTests Data = [ new SubscriptionItem { + CurrentPeriodEnd = new DateTime(now.Year, now.Month, daysInThisMonth), Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Enterprise } }, new SubscriptionItem { + CurrentPeriodEnd = new DateTime(now.Year, now.Month, daysInThisMonth), Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Teams } } ] }, - Status = "unpaid", + Status = "unpaid" }; - stripeAdapter.SubscriptionGetAsync(provider.GatewaySubscriptionId, Arg.Is( + stripeAdapter.GetSubscriptionAsync(provider.GatewaySubscriptionId, Arg.Is( options => options.Expand.Contains("customer.tax_ids") && + options.Expand.Contains("discounts") && options.Expand.Contains("test_clock"))).Returns(subscription); var daysInLastMonth = DateTime.DaysInMonth(oneMonthAgo.Year, oneMonthAgo.Month); @@ -318,7 +318,7 @@ public class ProviderBillingControllerTests Attempted = true }; - stripeAdapter.InvoiceSearchAsync(Arg.Is( + stripeAdapter.SearchInvoiceAsync(Arg.Is( options => options.Query == $"subscription:'{subscription.Id}' status:'open'")) .Returns([overdueInvoice]); @@ -348,10 +348,10 @@ public class ProviderBillingControllerTests foreach (var providerPlan in providerPlans) { - var plan = StaticStore.GetPlan(providerPlan.PlanType); + var plan = MockPlans.Get(providerPlan.PlanType); sutProvider.GetDependency().GetPlanOrThrow(providerPlan.PlanType).Returns(plan); var priceId = ProviderPriceAdapter.GetPriceId(provider, subscription, providerPlan.PlanType); - sutProvider.GetDependency().PriceGetAsync(priceId) + sutProvider.GetDependency().GetPriceAsync(priceId) .Returns(new Price { UnitAmountDecimal = plan.PasswordManager.ProviderPortalSeatPrice * 100 @@ -365,11 +365,11 @@ public class ProviderBillingControllerTests var response = ((Ok)result).Value; Assert.Equal(subscription.Status, response.Status); - Assert.Equal(subscription.CurrentPeriodEnd, response.CurrentPeriodEndDate); + Assert.Equal(subscription.GetCurrentPeriodEnd(), response.CurrentPeriodEndDate); Assert.Equal(subscription.Customer!.Discount!.Coupon!.PercentOff, response.DiscountPercentage); Assert.Equal(subscription.CollectionMethod, response.CollectionMethod); - var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var teamsPlan = MockPlans.Get(PlanType.TeamsMonthly); var providerTeamsPlan = response.Plans.FirstOrDefault(plan => plan.PlanName == teamsPlan.Name); Assert.NotNull(providerTeamsPlan); Assert.Equal(50, providerTeamsPlan.SeatMinimum); @@ -378,7 +378,7 @@ public class ProviderBillingControllerTests Assert.Equal(60 * teamsPlan.PasswordManager.ProviderPortalSeatPrice, providerTeamsPlan.Cost); Assert.Equal("Monthly", providerTeamsPlan.Cadence); - var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); + var enterprisePlan = MockPlans.Get(PlanType.EnterpriseMonthly); var providerEnterprisePlan = response.Plans.FirstOrDefault(plan => plan.PlanName == enterprisePlan.Name); Assert.NotNull(providerEnterprisePlan); Assert.Equal(100, providerEnterprisePlan.SeatMinimum); @@ -405,49 +405,116 @@ public class ProviderBillingControllerTests Assert.Equal(14, response.Suspension.GracePeriod); } - #endregion - - #region UpdateTaxInformationAsync - [Theory, BitAutoData] - public async Task UpdateTaxInformation_NoCountry_BadRequest( + public async Task GetSubscriptionAsync_SubscriptionLevelDiscount_Ok( Provider provider, - TaxInformationRequestBody requestBody, SutProvider sutProvider) { - ConfigureStableProviderAdminInputs(provider, sutProvider); + ConfigureStableProviderServiceUserInputs(provider, sutProvider); - requestBody.Country = null; + var stripeAdapter = sutProvider.GetDependency(); - var result = await sutProvider.Sut.UpdateTaxInformationAsync(provider.Id, requestBody); + var now = DateTime.UtcNow; + var oneMonthAgo = now.AddMonths(-1); - Assert.IsType>(result); + var daysInThisMonth = DateTime.DaysInMonth(now.Year, now.Month); - var response = (BadRequest)result; + var subscription = new Subscription + { + CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically, + Customer = new Customer + { + Address = new Address + { + Country = "US", + PostalCode = "12345", + Line1 = "123 Example St.", + Line2 = "Unit 1", + City = "Example Town", + State = "NY" + }, + Balance = -100000, + Discount = null, // No customer-level discount + TaxIds = new StripeList { Data = [new TaxId { Value = "123456789" }] } + }, + Discounts = + [ + new Discount { Coupon = new Coupon { PercentOff = 15 } } // Subscription-level discount + ], + Items = new StripeList + { + Data = [ + new SubscriptionItem + { + CurrentPeriodEnd = new DateTime(now.Year, now.Month, daysInThisMonth), + Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Enterprise } + }, + new SubscriptionItem + { + CurrentPeriodEnd = new DateTime(now.Year, now.Month, daysInThisMonth), + Price = new Price { Id = ProviderPriceAdapter.MSP.Active.Teams } + } + ] + }, + Status = "active" + }; - Assert.Equal("Country and postal code are required to update your tax information.", response.Value.Message); - } + stripeAdapter.GetSubscriptionAsync(provider.GatewaySubscriptionId, Arg.Is( + options => + options.Expand.Contains("customer.tax_ids") && + options.Expand.Contains("discounts") && + options.Expand.Contains("test_clock"))).Returns(subscription); - [Theory, BitAutoData] - public async Task UpdateTaxInformation_Ok( - Provider provider, - TaxInformationRequestBody requestBody, - SutProvider sutProvider) - { - ConfigureStableProviderAdminInputs(provider, sutProvider); + stripeAdapter.SearchInvoiceAsync(Arg.Is( + options => options.Query == $"subscription:'{subscription.Id}' status:'open'")) + .Returns([]); - await sutProvider.Sut.UpdateTaxInformationAsync(provider.Id, requestBody); + var providerPlans = new List + { + new () + { + Id = Guid.NewGuid(), + ProviderId = provider.Id, + PlanType = PlanType.TeamsMonthly, + SeatMinimum = 50, + PurchasedSeats = 10, + AllocatedSeats = 60 + }, + new () + { + Id = Guid.NewGuid(), + ProviderId = provider.Id, + PlanType = PlanType.EnterpriseMonthly, + SeatMinimum = 100, + PurchasedSeats = 0, + AllocatedSeats = 90 + } + }; - await sutProvider.GetDependency().Received(1).UpdateTaxInformation( - provider, Arg.Is( - options => - options.Country == requestBody.Country && - options.PostalCode == requestBody.PostalCode && - options.TaxId == requestBody.TaxId && - options.Line1 == requestBody.Line1 && - options.Line2 == requestBody.Line2 && - options.City == requestBody.City && - options.State == requestBody.State)); + sutProvider.GetDependency().GetByProviderId(provider.Id).Returns(providerPlans); + + foreach (var providerPlan in providerPlans) + { + var plan = MockPlans.Get(providerPlan.PlanType); + sutProvider.GetDependency().GetPlanOrThrow(providerPlan.PlanType).Returns(plan); + var priceId = ProviderPriceAdapter.GetPriceId(provider, subscription, providerPlan.PlanType); + sutProvider.GetDependency().GetPriceAsync(priceId) + .Returns(new Price + { + UnitAmountDecimal = plan.PasswordManager.ProviderPortalSeatPrice * 100 + }); + } + + var result = await sutProvider.Sut.GetSubscriptionAsync(provider.Id); + + Assert.IsType>(result); + + var response = ((Ok)result).Value; + + Assert.Equal(subscription.Status, response.Status); + Assert.Equal(subscription.GetCurrentPeriodEnd(), response.CurrentPeriodEndDate); + Assert.Equal(15, response.DiscountPercentage); // Verify subscription-level discount is used + Assert.Equal(subscription.CollectionMethod, response.CollectionMethod); } #endregion diff --git a/test/Api.Test/Controllers/PoliciesControllerTests.cs b/test/Api.Test/Controllers/PoliciesControllerTests.cs index f5f3eddd3b..efb9f7aaa9 100644 --- a/test/Api.Test/Controllers/PoliciesControllerTests.cs +++ b/test/Api.Test/Controllers/PoliciesControllerTests.cs @@ -1,10 +1,14 @@ using System.Security.Claims; using System.Text.Json; using Bit.Api.AdminConsole.Controllers; +using Bit.Api.AdminConsole.Models.Request; using Bit.Api.AdminConsole.Models.Response.Organizations; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Context; @@ -286,7 +290,7 @@ public class PoliciesControllerTests string token, string email, Organization organization - ) + ) { // Arrange organization.UsePolicies = true; @@ -297,14 +301,15 @@ public class PoliciesControllerTests var decryptedToken = Substitute.For(); decryptedToken.Valid.Returns(false); - var orgUserInviteTokenDataFactory = sutProvider.GetDependency>(); + var orgUserInviteTokenDataFactory = + sutProvider.GetDependency>(); orgUserInviteTokenDataFactory.TryUnprotect(token, out Arg.Any()) .Returns(x => - { - x[1] = decryptedToken; - return true; - }); + { + x[1] = decryptedToken; + return true; + }); // Act & Assert await Assert.ThrowsAsync(() => @@ -320,7 +325,7 @@ public class PoliciesControllerTests string token, string email, Organization organization - ) + ) { // Arrange organization.UsePolicies = true; @@ -333,14 +338,15 @@ public class PoliciesControllerTests decryptedToken.OrgUserId = organizationUserId; decryptedToken.OrgUserEmail = email; - var orgUserInviteTokenDataFactory = sutProvider.GetDependency>(); + var orgUserInviteTokenDataFactory = + sutProvider.GetDependency>(); orgUserInviteTokenDataFactory.TryUnprotect(token, out Arg.Any()) .Returns(x => - { - x[1] = decryptedToken; - return true; - }); + { + x[1] = decryptedToken; + return true; + }); sutProvider.GetDependency() .GetByIdAsync(organizationUserId) @@ -361,7 +367,7 @@ public class PoliciesControllerTests string email, OrganizationUser orgUser, Organization organization - ) + ) { // Arrange organization.UsePolicies = true; @@ -374,14 +380,15 @@ public class PoliciesControllerTests decryptedToken.OrgUserId = organizationUserId; decryptedToken.OrgUserEmail = email; - var orgUserInviteTokenDataFactory = sutProvider.GetDependency>(); + var orgUserInviteTokenDataFactory = + sutProvider.GetDependency>(); orgUserInviteTokenDataFactory.TryUnprotect(token, out Arg.Any()) .Returns(x => - { - x[1] = decryptedToken; - return true; - }); + { + x[1] = decryptedToken; + return true; + }); orgUser.OrganizationId = Guid.Empty; @@ -404,7 +411,7 @@ public class PoliciesControllerTests string email, OrganizationUser orgUser, Organization organization - ) + ) { // Arrange organization.UsePolicies = true; @@ -417,14 +424,15 @@ public class PoliciesControllerTests decryptedToken.OrgUserId = organizationUserId; decryptedToken.OrgUserEmail = email; - var orgUserInviteTokenDataFactory = sutProvider.GetDependency>(); + var orgUserInviteTokenDataFactory = + sutProvider.GetDependency>(); orgUserInviteTokenDataFactory.TryUnprotect(token, out Arg.Any()) .Returns(x => - { - x[1] = decryptedToken; - return true; - }); + { + x[1] = decryptedToken; + return true; + }); orgUser.OrganizationId = orgId; sutProvider.GetDependency() @@ -455,4 +463,46 @@ public class PoliciesControllerTests Assert.Equal(enabledPolicy.Type, expectedPolicy.Type); Assert.Equal(enabledPolicy.Enabled, expectedPolicy.Enabled); } + + [Theory] + [BitAutoData] + public async Task PutVNext_UsesVNextSavePolicyCommand( + SutProvider sutProvider, Guid orgId, + SavePolicyRequest model, Policy policy, Guid userId) + { + // Arrange + policy.Data = null; + + sutProvider.GetDependency() + .UserId + .Returns(userId); + + sutProvider.GetDependency() + .OrganizationOwner(orgId) + .Returns(true); + + sutProvider.GetDependency() + .SaveAsync(Arg.Any()) + .Returns(policy); + + // Act + var result = await sutProvider.Sut.PutVNext(orgId, policy.Type, model); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .SaveAsync(Arg.Is(m => m.PolicyUpdate.OrganizationId == orgId && + m.PolicyUpdate.Type == policy.Type && + m.PolicyUpdate.Enabled == model.Policy.Enabled && + m.PerformedBy.UserId == userId && + m.PerformedBy.IsOrganizationOwnerOrProvider == true)); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .VNextSaveAsync(default); + + Assert.NotNull(result); + Assert.Equal(policy.Id, result.Id); + Assert.Equal(policy.Type, result.Type); + } } diff --git a/test/Api.Test/Dirt/Controllers/OrganizationIntegrationControllerTests.cs b/test/Api.Test/Dirt/Controllers/OrganizationIntegrationControllerTests.cs new file mode 100644 index 0000000000..85f4e7ca7f --- /dev/null +++ b/test/Api.Test/Dirt/Controllers/OrganizationIntegrationControllerTests.cs @@ -0,0 +1,211 @@ +using Bit.Api.Dirt.Controllers; +using Bit.Api.Dirt.Models.Request; +using Bit.Api.Dirt.Models.Response; +using Bit.Core.Context; +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrations.Interfaces; +using Bit.Core.Exceptions; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.AspNetCore.Mvc; +using NSubstitute; +using Xunit; + +namespace Bit.Api.Test.Dirt.Controllers; + +[ControllerCustomize(typeof(OrganizationIntegrationController))] +[SutProviderCustomize] +public class OrganizationIntegrationControllerTests +{ + private readonly OrganizationIntegrationRequestModel _webhookRequestModel = new() + { + Configuration = null, + Type = IntegrationType.Webhook + }; + + [Theory, BitAutoData] + public async Task GetAsync_UserIsNotOrganizationAdmin_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(false); + + await Assert.ThrowsAsync(() => sutProvider.Sut.GetAsync(organizationId)); + } + + [Theory, BitAutoData] + public async Task GetAsync_IntegrationsExist_ReturnsIntegrations( + SutProvider sutProvider, + Guid organizationId, + List integrations) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + sutProvider.GetDependency() + .GetManyByOrganizationAsync(organizationId) + .Returns(integrations); + + var result = await sutProvider.Sut.GetAsync(organizationId); + + await sutProvider.GetDependency().Received(1) + .GetManyByOrganizationAsync(organizationId); + + Assert.Equal(integrations.Count, result.Count); + Assert.All(result, r => Assert.IsType(r)); + } + + [Theory, BitAutoData] + public async Task GetAsync_NoIntegrations_ReturnsEmptyList( + SutProvider sutProvider, + Guid organizationId) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + sutProvider.GetDependency() + .GetManyByOrganizationAsync(organizationId) + .Returns([]); + + var result = await sutProvider.Sut.GetAsync(organizationId); + + Assert.Empty(result); + } + + [Theory, BitAutoData] + public async Task CreateAsync_AllParamsProvided_Succeeds( + SutProvider sutProvider, + Guid organizationId, + OrganizationIntegration integration) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + sutProvider.GetDependency() + .CreateAsync(Arg.Any()) + .Returns(integration); + + var response = await sutProvider.Sut.CreateAsync(organizationId, _webhookRequestModel); + + await sutProvider.GetDependency().Received(1) + .CreateAsync(Arg.Is(i => + i.OrganizationId == organizationId && + i.Type == IntegrationType.Webhook)); + Assert.IsType(response); + } + + [Theory, BitAutoData] + public async Task CreateAsync_UserIsNotOrganizationAdmin_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(false); + + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.CreateAsync(organizationId, _webhookRequestModel)); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_AllParamsProvided_Succeeds( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + + await sutProvider.Sut.DeleteAsync(organizationId, integrationId); + + await sutProvider.GetDependency().Received(1) + .DeleteAsync(organizationId, integrationId); + } + + [Theory, BitAutoData] + [Obsolete("Obsolete")] + public async Task PostDeleteAsync_AllParamsProvided_Succeeds( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + + await sutProvider.Sut.PostDeleteAsync(organizationId, integrationId); + + await sutProvider.GetDependency().Received(1) + .DeleteAsync(organizationId, integrationId); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_UserIsNotOrganizationAdmin_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(false); + + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.DeleteAsync(organizationId, integrationId)); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_AllParamsProvided_Succeeds( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration) + { + integration.OrganizationId = organizationId; + integration.Id = integrationId; + integration.Type = IntegrationType.Webhook; + + sutProvider.Sut.Url = Substitute.For(); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + sutProvider.GetDependency() + .UpdateAsync(organizationId, integrationId, Arg.Any()) + .Returns(integration); + + var response = await sutProvider.Sut.UpdateAsync(organizationId, integrationId, _webhookRequestModel); + + await sutProvider.GetDependency().Received(1) + .UpdateAsync(organizationId, integrationId, Arg.Is(i => + i.OrganizationId == organizationId && + i.Type == IntegrationType.Webhook)); + Assert.IsType(response); + Assert.Equal(IntegrationType.Webhook, response.Type); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_UserIsNotOrganizationAdmin_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(false); + + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.UpdateAsync(organizationId, integrationId, _webhookRequestModel)); + } +} diff --git a/test/Api.Test/Dirt/Controllers/OrganizationIntegrationsConfigurationControllerTests.cs b/test/Api.Test/Dirt/Controllers/OrganizationIntegrationsConfigurationControllerTests.cs new file mode 100644 index 0000000000..ec8e5c3e36 --- /dev/null +++ b/test/Api.Test/Dirt/Controllers/OrganizationIntegrationsConfigurationControllerTests.cs @@ -0,0 +1,211 @@ +using Bit.Api.Dirt.Controllers; +using Bit.Api.Dirt.Models.Request; +using Bit.Api.Dirt.Models.Response; +using Bit.Core.Context; +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; +using Bit.Core.Exceptions; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.AspNetCore.Mvc; +using NSubstitute; +using Xunit; + +namespace Bit.Api.Test.Dirt.Controllers; + +[ControllerCustomize(typeof(OrganizationIntegrationConfigurationController))] +[SutProviderCustomize] +public class OrganizationIntegrationsConfigurationControllerTests +{ + [Theory, BitAutoData] + public async Task DeleteAsync_AllParamsProvided_Succeeds( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + + await sutProvider.Sut.DeleteAsync(organizationId, integrationId, configurationId); + + await sutProvider.GetDependency().Received(1) + .DeleteAsync(organizationId, integrationId, configurationId); + } + + [Theory, BitAutoData] + [Obsolete("Obsolete")] + public async Task PostDeleteAsync_AllParamsProvided_Succeeds( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + + await sutProvider.Sut.PostDeleteAsync(organizationId, integrationId, configurationId); + + await sutProvider.GetDependency().Received(1) + .DeleteAsync(organizationId, integrationId, configurationId); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_UserIsNotOrganizationAdmin_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(false); + + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.DeleteAsync(organizationId, integrationId, configurationId)); + } + + [Theory, BitAutoData] + public async Task GetAsync_ConfigurationsExist_Succeeds( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + List configurations) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + sutProvider.GetDependency() + .GetManyByIntegrationAsync(organizationId, integrationId) + .Returns(configurations); + + var result = await sutProvider.Sut.GetAsync(organizationId, integrationId); + + Assert.NotNull(result); + Assert.Equal(configurations.Count, result.Count); + Assert.All(result, r => Assert.IsType(r)); + await sutProvider.GetDependency().Received(1) + .GetManyByIntegrationAsync(organizationId, integrationId); + } + + [Theory, BitAutoData] + public async Task GetAsync_NoConfigurationsExist_ReturnsEmptyList( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + sutProvider.GetDependency() + .GetManyByIntegrationAsync(organizationId, integrationId) + .Returns([]); + + var result = await sutProvider.Sut.GetAsync(organizationId, integrationId); + + Assert.NotNull(result); + Assert.Empty(result); + await sutProvider.GetDependency().Received(1) + .GetManyByIntegrationAsync(organizationId, integrationId); + } + + [Theory, BitAutoData] + public async Task GetAsync_UserIsNotOrganizationAdmin_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(false); + + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.GetAsync(organizationId, integrationId)); + } + + [Theory, BitAutoData] + public async Task PostAsync_AllParamsProvided_Succeeds( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegrationConfiguration configuration, + OrganizationIntegrationConfigurationRequestModel model) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + sutProvider.GetDependency() + .CreateAsync(organizationId, integrationId, Arg.Any()) + .Returns(configuration); + + var createResponse = await sutProvider.Sut.CreateAsync(organizationId, integrationId, model); + + await sutProvider.GetDependency().Received(1) + .CreateAsync(organizationId, integrationId, Arg.Any()); + Assert.IsType(createResponse); + } + + [Theory, BitAutoData] + public async Task PostAsync_UserIsNotOrganizationAdmin_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(false); + + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.CreateAsync(organizationId, integrationId, new OrganizationIntegrationConfigurationRequestModel())); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_AllParamsProvided_Succeeds( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegrationConfiguration configuration, + OrganizationIntegrationConfigurationRequestModel model) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + sutProvider.GetDependency() + .UpdateAsync(organizationId, integrationId, configurationId, Arg.Any()) + .Returns(configuration); + + var updateResponse = await sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, model); + + await sutProvider.GetDependency().Received(1) + .UpdateAsync(organizationId, integrationId, configurationId, Arg.Any()); + Assert.IsType(updateResponse); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_UserIsNotOrganizationAdmin_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(false); + + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, new OrganizationIntegrationConfigurationRequestModel())); + } +} diff --git a/test/Api.Test/AdminConsole/Controllers/SlackIntegrationControllerTests.cs b/test/Api.Test/Dirt/Controllers/SlackIntegrationControllerTests.cs similarity index 85% rename from test/Api.Test/AdminConsole/Controllers/SlackIntegrationControllerTests.cs rename to test/Api.Test/Dirt/Controllers/SlackIntegrationControllerTests.cs index 376fb01493..a8dcfc3395 100644 --- a/test/Api.Test/AdminConsole/Controllers/SlackIntegrationControllerTests.cs +++ b/test/Api.Test/Dirt/Controllers/SlackIntegrationControllerTests.cs @@ -1,13 +1,13 @@ #nullable enable -using Bit.Api.AdminConsole.Controllers; -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Api.Dirt.Controllers; using Bit.Core.Context; -using Bit.Core.Enums; +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Repositories; +using Bit.Core.Dirt.Services; using Bit.Core.Exceptions; -using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.AspNetCore.Mvc; @@ -16,7 +16,7 @@ using Microsoft.Extensions.Time.Testing; using NSubstitute; using Xunit; -namespace Bit.Api.Test.AdminConsole.Controllers; +namespace Bit.Api.Test.Dirt.Controllers; [ControllerCustomize(typeof(SlackIntegrationController))] [SutProviderCustomize] @@ -34,7 +34,7 @@ public class SlackIntegrationControllerTests integration.Configuration = null; sutProvider.Sut.Url = Substitute.For(); sutProvider.Sut.Url - .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) .Returns("https://localhost"); sutProvider.GetDependency() .ObtainTokenViaOAuth(_validSlackCode, Arg.Any()) @@ -60,7 +60,7 @@ public class SlackIntegrationControllerTests integration.Configuration = null; sutProvider.Sut.Url = Substitute.For(); sutProvider.Sut.Url - .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) .Returns("https://localhost"); sutProvider.GetDependency() .GetByIdAsync(integration.Id) @@ -71,6 +71,26 @@ public class SlackIntegrationControllerTests await sutProvider.Sut.CreateAsync(string.Empty, state.ToString())); } + [Theory, BitAutoData] + public async Task CreateAsync_CallbackUrlIsEmpty_ThrowsBadRequest( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Type = IntegrationType.Slack; + integration.Configuration = null; + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) + .Returns((string?)null); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(integration); + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.CreateAsync(_validSlackCode, state.ToString())); + } + [Theory, BitAutoData] public async Task CreateAsync_SlackServiceReturnsEmpty_ThrowsBadRequest( SutProvider sutProvider, @@ -80,7 +100,7 @@ public class SlackIntegrationControllerTests integration.Configuration = null; sutProvider.Sut.Url = Substitute.For(); sutProvider.Sut.Url - .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) .Returns("https://localhost"); sutProvider.GetDependency() .GetByIdAsync(integration.Id) @@ -99,13 +119,13 @@ public class SlackIntegrationControllerTests { sutProvider.Sut.Url = Substitute.For(); sutProvider.Sut.Url - .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) .Returns("https://localhost"); sutProvider.GetDependency() .ObtainTokenViaOAuth(_validSlackCode, Arg.Any()) .Returns(_slackToken); - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validSlackCode, String.Empty)); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validSlackCode, string.Empty)); } [Theory, BitAutoData] @@ -116,7 +136,7 @@ public class SlackIntegrationControllerTests var timeProvider = new FakeTimeProvider(new DateTime(2024, 4, 3, 2, 1, 0, DateTimeKind.Utc)); sutProvider.Sut.Url = Substitute.For(); sutProvider.Sut.Url - .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) .Returns("https://localhost"); sutProvider.GetDependency() .ObtainTokenViaOAuth(_validSlackCode, Arg.Any()) @@ -135,7 +155,7 @@ public class SlackIntegrationControllerTests { sutProvider.Sut.Url = Substitute.For(); sutProvider.Sut.Url - .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) .Returns("https://localhost"); sutProvider.GetDependency() .ObtainTokenViaOAuth(_validSlackCode, Arg.Any()) @@ -147,16 +167,18 @@ public class SlackIntegrationControllerTests } [Theory, BitAutoData] - public async Task CreateAsync_StateHasWrongOgranizationHash_ThrowsNotFound( + public async Task CreateAsync_StateHasWrongOrganizationHash_ThrowsNotFound( SutProvider sutProvider, OrganizationIntegration integration, OrganizationIntegration wrongOrgIntegration) { wrongOrgIntegration.Id = integration.Id; + wrongOrgIntegration.Type = IntegrationType.Slack; + wrongOrgIntegration.Configuration = null; sutProvider.Sut.Url = Substitute.For(); sutProvider.Sut.Url - .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) .Returns("https://localhost"); sutProvider.GetDependency() .ObtainTokenViaOAuth(_validSlackCode, Arg.Any()) @@ -179,7 +201,7 @@ public class SlackIntegrationControllerTests integration.Configuration = "{}"; sutProvider.Sut.Url = Substitute.For(); sutProvider.Sut.Url - .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) .Returns("https://localhost"); sutProvider.GetDependency() .ObtainTokenViaOAuth(_validSlackCode, Arg.Any()) @@ -201,7 +223,7 @@ public class SlackIntegrationControllerTests integration.Configuration = null; sutProvider.Sut.Url = Substitute.For(); sutProvider.Sut.Url - .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) .Returns("https://localhost"); sutProvider.GetDependency() .ObtainTokenViaOAuth(_validSlackCode, Arg.Any()) @@ -224,7 +246,7 @@ public class SlackIntegrationControllerTests sutProvider.Sut.Url = Substitute.For(); sutProvider.Sut.Url - .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) .Returns(expectedUrl); sutProvider.GetDependency() .OrganizationOwner(integration.OrganizationId) @@ -260,7 +282,7 @@ public class SlackIntegrationControllerTests sutProvider.Sut.Url = Substitute.For(); sutProvider.Sut.Url - .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) .Returns(expectedUrl); sutProvider.GetDependency() .OrganizationOwner(organizationId) @@ -291,7 +313,7 @@ public class SlackIntegrationControllerTests sutProvider.Sut.Url = Substitute.For(); sutProvider.Sut.Url - .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) .Returns(expectedUrl); sutProvider.GetDependency() .OrganizationOwner(organizationId) @@ -304,6 +326,22 @@ public class SlackIntegrationControllerTests await Assert.ThrowsAsync(async () => await sutProvider.Sut.RedirectAsync(organizationId)); } + [Theory, BitAutoData] + public async Task RedirectAsync_CallbackUrlReturnsEmpty_ThrowsBadRequest( + SutProvider sutProvider, + Guid organizationId) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) + .Returns((string?)null); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.RedirectAsync(organizationId)); + } + [Theory, BitAutoData] public async Task RedirectAsync_SlackServiceReturnsEmpty_ThrowsNotFound( SutProvider sutProvider, @@ -316,7 +354,7 @@ public class SlackIntegrationControllerTests sutProvider.Sut.Url = Substitute.For(); sutProvider.Sut.Url - .RouteUrl(Arg.Is(c => c.RouteName == nameof(SlackIntegrationController.CreateAsync))) + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) .Returns(expectedUrl); sutProvider.GetDependency() .OrganizationOwner(organizationId) diff --git a/test/Api.Test/Dirt/Controllers/TeamsIntegrationControllerTests.cs b/test/Api.Test/Dirt/Controllers/TeamsIntegrationControllerTests.cs new file mode 100644 index 0000000000..b7e778339b --- /dev/null +++ b/test/Api.Test/Dirt/Controllers/TeamsIntegrationControllerTests.cs @@ -0,0 +1,436 @@ +#nullable enable + +using Bit.Api.Dirt.Controllers; +using Bit.Core.Context; +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Models.Data.Teams; +using Bit.Core.Dirt.Repositories; +using Bit.Core.Dirt.Services; +using Bit.Core.Exceptions; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.Routing; +using Microsoft.Bot.Builder; +using Microsoft.Bot.Builder.Integration.AspNet.Core; +using Microsoft.Extensions.Time.Testing; +using NSubstitute; +using Xunit; + +namespace Bit.Api.Test.Dirt.Controllers; + +[ControllerCustomize(typeof(TeamsIntegrationController))] +[SutProviderCustomize] +public class TeamsIntegrationControllerTests +{ + private const string _teamsToken = "test-token"; + private const string _validTeamsCode = "A_test_code"; + + [Theory, BitAutoData] + public async Task CreateAsync_AllParamsProvided_Succeeds( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Type = IntegrationType.Teams; + integration.Configuration = null; + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validTeamsCode, Arg.Any()) + .Returns(_teamsToken); + sutProvider.GetDependency() + .GetJoinedTeamsAsync(_teamsToken) + .Returns([ + new TeamInfo() { DisplayName = "Test Team", Id = Guid.NewGuid().ToString(), TenantId = Guid.NewGuid().ToString() } + ]); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(integration); + + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + var requestAction = await sutProvider.Sut.CreateAsync(_validTeamsCode, state.ToString()); + + await sutProvider.GetDependency().Received(1) + .UpsertAsync(Arg.Any()); + Assert.IsType(requestAction); + } + + [Theory, BitAutoData] + public async Task CreateAsync_CallbackUrlIsEmpty_ThrowsBadRequest( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Type = IntegrationType.Teams; + integration.Configuration = null; + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns((string?)null); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(integration); + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.CreateAsync(_validTeamsCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_CodeIsEmpty_ThrowsBadRequest( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Type = IntegrationType.Teams; + integration.Configuration = null; + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(integration); + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.CreateAsync(string.Empty, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_NoTeamsFound_ThrowsBadRequest( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Type = IntegrationType.Teams; + integration.Configuration = null; + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validTeamsCode, Arg.Any()) + .Returns(_teamsToken); + sutProvider.GetDependency() + .GetJoinedTeamsAsync(_teamsToken) + .Returns([]); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(integration); + + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validTeamsCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_TeamsServiceReturnsEmptyToken_ThrowsBadRequest( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Type = IntegrationType.Teams; + integration.Configuration = null; + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(integration); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validTeamsCode, Arg.Any()) + .Returns(string.Empty); + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validTeamsCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_StateEmpty_ThrowsNotFound( + SutProvider sutProvider) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validTeamsCode, Arg.Any()) + .Returns(_teamsToken); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validTeamsCode, string.Empty)); + } + + [Theory, BitAutoData] + public async Task CreateAsync_StateExpired_ThrowsNotFound( + SutProvider sutProvider, + OrganizationIntegration integration) + { + var timeProvider = new FakeTimeProvider(new DateTime(2024, 4, 3, 2, 1, 0, DateTimeKind.Utc)); + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validTeamsCode, Arg.Any()) + .Returns(_teamsToken); + var state = IntegrationOAuthState.FromIntegration(integration, timeProvider); + timeProvider.Advance(TimeSpan.FromMinutes(30)); + + sutProvider.SetDependency(timeProvider); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validTeamsCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_StateHasNonexistentIntegration_ThrowsNotFound( + SutProvider sutProvider, + OrganizationIntegration integration) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validTeamsCode, Arg.Any()) + .Returns(_teamsToken); + + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validTeamsCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_StateHasWrongOrganizationHash_ThrowsNotFound( + SutProvider sutProvider, + OrganizationIntegration integration, + OrganizationIntegration wrongOrgIntegration) + { + wrongOrgIntegration.Id = integration.Id; + wrongOrgIntegration.Type = IntegrationType.Teams; + wrongOrgIntegration.Configuration = null; + + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validTeamsCode, Arg.Any()) + .Returns(_teamsToken); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(wrongOrgIntegration); + + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validTeamsCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_StateHasNonEmptyIntegration_ThrowsNotFound( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Type = IntegrationType.Teams; + integration.Configuration = "{}"; + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validTeamsCode, Arg.Any()) + .Returns(_teamsToken); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(integration); + + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validTeamsCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_StateHasNonTeamsIntegration_ThrowsNotFound( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Type = IntegrationType.Hec; + integration.Configuration = null; + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validTeamsCode, Arg.Any()) + .Returns(_teamsToken); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(integration); + + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validTeamsCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task RedirectAsync_Success( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Configuration = null; + var expectedUrl = "https://localhost/"; + + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns(expectedUrl); + sutProvider.GetDependency() + .OrganizationOwner(integration.OrganizationId) + .Returns(true); + sutProvider.GetDependency() + .GetManyByOrganizationAsync(integration.OrganizationId) + .Returns([]); + sutProvider.GetDependency() + .CreateAsync(Arg.Any()) + .Returns(integration); + sutProvider.GetDependency().GetRedirectUrl(Arg.Any(), Arg.Any()).Returns(expectedUrl); + + var expectedState = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + var requestAction = await sutProvider.Sut.RedirectAsync(integration.OrganizationId); + + Assert.IsType(requestAction); + await sutProvider.GetDependency().Received(1) + .CreateAsync(Arg.Any()); + sutProvider.GetDependency().Received(1).GetRedirectUrl(Arg.Any(), expectedState.ToString()); + } + + [Theory, BitAutoData] + public async Task RedirectAsync_IntegrationAlreadyExistsWithNullConfig_Success( + SutProvider sutProvider, + Guid organizationId, + OrganizationIntegration integration) + { + integration.OrganizationId = organizationId; + integration.Configuration = null; + integration.Type = IntegrationType.Teams; + var expectedUrl = "https://localhost/"; + + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns(expectedUrl); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + sutProvider.GetDependency() + .GetManyByOrganizationAsync(organizationId) + .Returns([integration]); + sutProvider.GetDependency().GetRedirectUrl(Arg.Any(), Arg.Any()).Returns(expectedUrl); + + var requestAction = await sutProvider.Sut.RedirectAsync(organizationId); + + var expectedState = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + Assert.IsType(requestAction); + sutProvider.GetDependency().Received(1).GetRedirectUrl(Arg.Any(), expectedState.ToString()); + } + + [Theory, BitAutoData] + public async Task RedirectAsync_CallbackUrlIsEmpty_ThrowsBadRequest( + SutProvider sutProvider, + Guid organizationId, + OrganizationIntegration integration) + { + integration.OrganizationId = organizationId; + integration.Configuration = null; + integration.Type = IntegrationType.Teams; + + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns((string?)null); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + sutProvider.GetDependency() + .GetManyByOrganizationAsync(organizationId) + .Returns([integration]); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.RedirectAsync(organizationId)); + } + + [Theory, BitAutoData] + public async Task RedirectAsync_IntegrationAlreadyExistsWithConfig_ThrowsBadRequest( + SutProvider sutProvider, + Guid organizationId, + OrganizationIntegration integration) + { + integration.OrganizationId = organizationId; + integration.Configuration = "{}"; + integration.Type = IntegrationType.Teams; + var expectedUrl = "https://localhost/"; + + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns(expectedUrl); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + sutProvider.GetDependency() + .GetManyByOrganizationAsync(organizationId) + .Returns([integration]); + sutProvider.GetDependency().GetRedirectUrl(Arg.Any(), Arg.Any()).Returns(expectedUrl); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.RedirectAsync(organizationId)); + } + + [Theory, BitAutoData] + public async Task RedirectAsync_TeamsServiceReturnsEmpty_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + OrganizationIntegration integration) + { + integration.OrganizationId = organizationId; + integration.Configuration = null; + var expectedUrl = "https://localhost/"; + + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns(expectedUrl); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + sutProvider.GetDependency() + .GetManyByOrganizationAsync(organizationId) + .Returns([]); + sutProvider.GetDependency() + .CreateAsync(Arg.Any()) + .Returns(integration); + sutProvider.GetDependency().GetRedirectUrl(Arg.Any(), Arg.Any()).Returns(string.Empty); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.RedirectAsync(organizationId)); + } + + [Theory, BitAutoData] + public async Task RedirectAsync_UserIsNotOrganizationAdmin_ThrowsNotFound(SutProvider sutProvider, + Guid organizationId) + { + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(false); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.RedirectAsync(organizationId)); + } + + [Theory, BitAutoData] + public async Task IncomingPostAsync_ForwardsToBot(SutProvider sutProvider) + { + var adapter = sutProvider.GetDependency(); + var bot = sutProvider.GetDependency(); + + await sutProvider.Sut.IncomingPostAsync(); + await adapter.Received(1).ProcessAsync(Arg.Any(), Arg.Any(), bot); + } +} diff --git a/test/Api.Test/Dirt/HibpControllerTests.cs b/test/Api.Test/Dirt/HibpControllerTests.cs new file mode 100644 index 0000000000..9be8d56eae --- /dev/null +++ b/test/Api.Test/Dirt/HibpControllerTests.cs @@ -0,0 +1,292 @@ +using System.Net; +using System.Reflection; +using Bit.Api.Dirt.Controllers; +using Bit.Core.Entities; +using Bit.Core.Exceptions; +using Bit.Core.Services; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.AspNetCore.Mvc; +using NSubstitute; +using Xunit; +using GlobalSettings = Bit.Core.Settings.GlobalSettings; + +namespace Bit.Api.Test.Dirt; + +[ControllerCustomize(typeof(HibpController))] +[SutProviderCustomize] +public class HibpControllerTests : IDisposable +{ + private readonly HttpClient _originalHttpClient; + private readonly FieldInfo _httpClientField; + + public HibpControllerTests() + { + // Store original HttpClient for restoration + _httpClientField = typeof(HibpController).GetField("_httpClient", BindingFlags.Static | BindingFlags.NonPublic); + _originalHttpClient = (HttpClient)_httpClientField?.GetValue(null); + } + + public void Dispose() + { + // Restore original HttpClient after tests + _httpClientField?.SetValue(null, _originalHttpClient); + } + + [Theory, BitAutoData] + public async Task Get_WithMissingApiKey_ThrowsBadRequestException( + SutProvider sutProvider, + string username) + { + // Arrange + sutProvider.GetDependency().HibpApiKey = null; + + // Act & Assert + var exception = await Assert.ThrowsAsync( + async () => await sutProvider.Sut.Get(username)); + Assert.Equal("HaveIBeenPwned API key not set.", exception.Message); + } + + [Theory, BitAutoData] + public async Task Get_WithValidApiKeyAndNoBreaches_Returns200WithEmptyArray( + SutProvider sutProvider, + string username, + Guid userId) + { + // Arrange + sutProvider.GetDependency().HibpApiKey = "test-api-key"; + var user = new User { Id = userId }; + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(userId); + + // Mock HttpClient to return 404 (no breaches found) + var mockHttpClient = CreateMockHttpClient(HttpStatusCode.NotFound, ""); + _httpClientField.SetValue(null, mockHttpClient); + + // Act + var result = await sutProvider.Sut.Get(username); + + // Assert + var contentResult = Assert.IsType(result); + Assert.Equal("[]", contentResult.Content); + Assert.Equal("application/json", contentResult.ContentType); + } + + [Theory, BitAutoData] + public async Task Get_WithValidApiKeyAndBreachesFound_Returns200WithBreachData( + SutProvider sutProvider, + string username, + Guid userId) + { + // Arrange + sutProvider.GetDependency().HibpApiKey = "test-api-key"; + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(userId); + + var breachData = "[{\"Name\":\"Adobe\",\"Title\":\"Adobe\",\"Domain\":\"adobe.com\"}]"; + var mockHttpClient = CreateMockHttpClient(HttpStatusCode.OK, breachData); + _httpClientField.SetValue(null, mockHttpClient); + + // Act + var result = await sutProvider.Sut.Get(username); + + // Assert + var contentResult = Assert.IsType(result); + Assert.Equal(breachData, contentResult.Content); + Assert.Equal("application/json", contentResult.ContentType); + } + + [Theory, BitAutoData] + public async Task Get_WithRateLimiting_RetriesWithDelay( + SutProvider sutProvider, + string username, + Guid userId) + { + // Arrange + sutProvider.GetDependency().HibpApiKey = "test-api-key"; + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(userId); + + // First response is rate limited, second is success + var requestCount = 0; + var mockHandler = new MockHttpMessageHandler((request, cancellationToken) => + { + requestCount++; + if (requestCount == 1) + { + var response = new HttpResponseMessage(HttpStatusCode.TooManyRequests); + response.Headers.Add("retry-after", "1"); + return Task.FromResult(response); + } + else + { + return Task.FromResult(new HttpResponseMessage(HttpStatusCode.NotFound) + { + Content = new StringContent("") + }); + } + }); + + var mockHttpClient = new HttpClient(mockHandler); + _httpClientField.SetValue(null, mockHttpClient); + + // Act + var result = await sutProvider.Sut.Get(username); + + // Assert + Assert.Equal(2, requestCount); // Verify retry happened + var contentResult = Assert.IsType(result); + Assert.Equal("[]", contentResult.Content); + } + + [Theory, BitAutoData] + public async Task Get_WithServerError_ThrowsBadRequestException( + SutProvider sutProvider, + string username, + Guid userId) + { + // Arrange + sutProvider.GetDependency().HibpApiKey = "test-api-key"; + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(userId); + + var mockHttpClient = CreateMockHttpClient(HttpStatusCode.InternalServerError, ""); + _httpClientField.SetValue(null, mockHttpClient); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + async () => await sutProvider.Sut.Get(username)); + Assert.Contains("Request failed. Status code:", exception.Message); + } + + [Theory, BitAutoData] + public async Task Get_WithBadRequest_ThrowsBadRequestException( + SutProvider sutProvider, + string username, + Guid userId) + { + // Arrange + sutProvider.GetDependency().HibpApiKey = "test-api-key"; + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(userId); + + var mockHttpClient = CreateMockHttpClient(HttpStatusCode.BadRequest, ""); + _httpClientField.SetValue(null, mockHttpClient); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + async () => await sutProvider.Sut.Get(username)); + Assert.Contains("Request failed. Status code:", exception.Message); + } + + [Theory, BitAutoData] + public async Task Get_EncodesUsernameCorrectly( + SutProvider sutProvider, + Guid userId) + { + // Arrange + var usernameWithSpecialChars = "test+user@example.com"; + sutProvider.GetDependency().HibpApiKey = "test-api-key"; + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(userId); + + string capturedUrl = null; + var mockHandler = new MockHttpMessageHandler((request, cancellationToken) => + { + capturedUrl = request.RequestUri.ToString(); + return Task.FromResult(new HttpResponseMessage(HttpStatusCode.NotFound) + { + Content = new StringContent("") + }); + }); + + var mockHttpClient = new HttpClient(mockHandler); + _httpClientField.SetValue(null, mockHttpClient); + + // Act + await sutProvider.Sut.Get(usernameWithSpecialChars); + + // Assert + Assert.NotNull(capturedUrl); + // Username should be URL encoded (+ becomes %2B, @ becomes %40) + Assert.Contains("test%2Buser%40example.com", capturedUrl); + } + + [Theory, BitAutoData] + public async Task SendAsync_IncludesRequiredHeaders( + SutProvider sutProvider, + string username, + Guid userId) + { + // Arrange + sutProvider.GetDependency().HibpApiKey = "test-api-key"; + sutProvider.GetDependency().SelfHosted = false; + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(userId); + + HttpRequestMessage capturedRequest = null; + var mockHandler = new MockHttpMessageHandler((request, cancellationToken) => + { + capturedRequest = request; + return Task.FromResult(new HttpResponseMessage(HttpStatusCode.NotFound) + { + Content = new StringContent("") + }); + }); + + var mockHttpClient = new HttpClient(mockHandler); + _httpClientField.SetValue(null, mockHttpClient); + + // Act + await sutProvider.Sut.Get(username); + + // Assert + Assert.NotNull(capturedRequest); + Assert.True(capturedRequest.Headers.Contains("hibp-api-key")); + Assert.True(capturedRequest.Headers.Contains("hibp-client-id")); + Assert.True(capturedRequest.Headers.Contains("User-Agent")); + Assert.Equal("Bitwarden", capturedRequest.Headers.GetValues("User-Agent").First()); + } + + /// + /// Helper to create a mock HttpClient that returns a specific status code and content + /// + private HttpClient CreateMockHttpClient(HttpStatusCode statusCode, string content) + { + var mockHandler = new MockHttpMessageHandler((request, cancellationToken) => + { + return Task.FromResult(new HttpResponseMessage(statusCode) + { + Content = new StringContent(content) + }); + }); + + return new HttpClient(mockHandler); + } +} + +/// +/// Mock HttpMessageHandler for testing HttpClient behavior +/// +public class MockHttpMessageHandler : HttpMessageHandler +{ + private readonly Func> _sendAsync; + + public MockHttpMessageHandler(Func> sendAsync) + { + _sendAsync = sendAsync; + } + + protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + return _sendAsync(request, cancellationToken); + } +} + diff --git a/test/Api.Test/AdminConsole/Models/Request/Organizations/OrganizationIntegrationRequestModelTests.cs b/test/Api.Test/Dirt/Models/Request/OrganizationIntegrationRequestModelTests.cs similarity index 76% rename from test/Api.Test/AdminConsole/Models/Request/Organizations/OrganizationIntegrationRequestModelTests.cs rename to test/Api.Test/Dirt/Models/Request/OrganizationIntegrationRequestModelTests.cs index 81927a1bfe..190eae260c 100644 --- a/test/Api.Test/AdminConsole/Models/Request/Organizations/OrganizationIntegrationRequestModelTests.cs +++ b/test/Api.Test/Dirt/Models/Request/OrganizationIntegrationRequestModelTests.cs @@ -1,14 +1,47 @@ using System.ComponentModel.DataAnnotations; using System.Text.Json; -using Bit.Api.AdminConsole.Models.Request.Organizations; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; -using Bit.Core.Enums; +using Bit.Api.Dirt.Models.Request; +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Test.Common.AutoFixture.Attributes; using Xunit; -namespace Bit.Api.Test.AdminConsole.Models.Request.Organizations; +namespace Bit.Api.Test.Dirt.Models.Request; public class OrganizationIntegrationRequestModelTests { + [Fact] + public void ToOrganizationIntegration_CreatesNewOrganizationIntegration() + { + var model = new OrganizationIntegrationRequestModel + { + Type = IntegrationType.Hec, + Configuration = JsonSerializer.Serialize(new HecIntegration(Uri: new Uri("http://localhost"), Scheme: "Bearer", Token: "Token")) + }; + + var organizationId = Guid.NewGuid(); + var organizationIntegration = model.ToOrganizationIntegration(organizationId); + + Assert.Equal(organizationIntegration.Type, model.Type); + Assert.Equal(organizationIntegration.Configuration, model.Configuration); + Assert.Equal(organizationIntegration.OrganizationId, organizationId); + } + + [Theory, BitAutoData] + public void ToOrganizationIntegration_UpdatesExistingOrganizationIntegration(OrganizationIntegration integration) + { + var model = new OrganizationIntegrationRequestModel + { + Type = IntegrationType.Hec, + Configuration = JsonSerializer.Serialize(new HecIntegration(Uri: new Uri("http://localhost"), Scheme: "Bearer", Token: "Token")) + }; + + var organizationIntegration = model.ToOrganizationIntegration(integration); + + Assert.Equal(organizationIntegration.Configuration, model.Configuration); + } + [Fact] public void Validate_CloudBillingSync_ReturnsNotYetSupportedError() { @@ -57,6 +90,22 @@ public class OrganizationIntegrationRequestModelTests Assert.Contains("cannot be created directly", results[0].ErrorMessage); } + [Fact] + public void Validate_Teams_ReturnsCannotBeCreatedDirectlyError() + { + var model = new OrganizationIntegrationRequestModel + { + Type = IntegrationType.Teams, + Configuration = null + }; + + var results = model.Validate(new ValidationContext(model)).ToList(); + + Assert.Single(results); + Assert.Contains(nameof(model.Type), results[0].MemberNames); + Assert.Contains("cannot be created directly", results[0].ErrorMessage); + } + [Fact] public void Validate_Webhook_WithNullConfiguration_ReturnsNoErrors() { diff --git a/test/Api.Test/AdminConsole/Models/Response/Organizations/OrganizationIntegrationResponseModelTests.cs b/test/Api.Test/Dirt/Models/Response/OrganizationIntegrationResponseModelTests.cs similarity index 67% rename from test/Api.Test/AdminConsole/Models/Response/Organizations/OrganizationIntegrationResponseModelTests.cs rename to test/Api.Test/Dirt/Models/Response/OrganizationIntegrationResponseModelTests.cs index babdf3894d..e6f8d5d756 100644 --- a/test/Api.Test/AdminConsole/Models/Response/Organizations/OrganizationIntegrationResponseModelTests.cs +++ b/test/Api.Test/Dirt/Models/Response/OrganizationIntegrationResponseModelTests.cs @@ -1,12 +1,15 @@ #nullable enable -using Bit.Api.AdminConsole.Models.Response.Organizations; -using Bit.Core.AdminConsole.Entities; -using Bit.Core.Enums; +using System.Text.Json; +using Bit.Api.Dirt.Models.Response; +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Models.Data.Teams; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; -namespace Bit.Api.Test.AdminConsole.Models.Response.Organizations; +namespace Bit.Api.Test.Dirt.Models.Response; public class OrganizationIntegrationResponseModelTests { @@ -58,6 +61,46 @@ public class OrganizationIntegrationResponseModelTests Assert.Equal(OrganizationIntegrationStatus.Completed, model.Status); } + [Theory, BitAutoData] + public void Status_Teams_NullConfig_ReturnsInitiated(OrganizationIntegration oi) + { + oi.Type = IntegrationType.Teams; + oi.Configuration = null; + + var model = new OrganizationIntegrationResponseModel(oi); + + Assert.Equal(OrganizationIntegrationStatus.Initiated, model.Status); + } + + [Theory, BitAutoData] + public void Status_Teams_WithTenantAndTeamsConfig_ReturnsInProgress(OrganizationIntegration oi) + { + oi.Type = IntegrationType.Teams; + oi.Configuration = JsonSerializer.Serialize(new TeamsIntegration( + TenantId: "tenant", Teams: [new TeamInfo() { DisplayName = "Team", Id = "TeamId", TenantId = "tenant" }] + )); + + var model = new OrganizationIntegrationResponseModel(oi); + + Assert.Equal(OrganizationIntegrationStatus.InProgress, model.Status); + } + + [Theory, BitAutoData] + public void Status_Teams_WithCompletedConfig_ReturnsCompleted(OrganizationIntegration oi) + { + oi.Type = IntegrationType.Teams; + oi.Configuration = JsonSerializer.Serialize(new TeamsIntegration( + TenantId: "tenant", + Teams: [new TeamInfo() { DisplayName = "Team", Id = "TeamId", TenantId = "tenant" }], + ServiceUrl: new Uri("https://example.com"), + ChannelId: "channellId" + )); + + var model = new OrganizationIntegrationResponseModel(oi); + + Assert.Equal(OrganizationIntegrationStatus.Completed, model.Status); + } + [Theory, BitAutoData] public void Status_Webhook_AlwaysCompleted(OrganizationIntegration oi) { diff --git a/test/Api.Test/Dirt/OrganizationReportsControllerTests.cs b/test/Api.Test/Dirt/OrganizationReportsControllerTests.cs index c786fd1c1b..880be1e4d9 100644 --- a/test/Api.Test/Dirt/OrganizationReportsControllerTests.cs +++ b/test/Api.Test/Dirt/OrganizationReportsControllerTests.cs @@ -1,4 +1,5 @@ using Bit.Api.Dirt.Controllers; +using Bit.Api.Dirt.Models.Response; using Bit.Core.Context; using Bit.Core.Dirt.Entities; using Bit.Core.Dirt.Models.Data; @@ -39,7 +40,8 @@ public class OrganizationReportControllerTests // Assert var okResult = Assert.IsType(result); - Assert.Equal(expectedReport, okResult.Value); + var expectedResponse = new OrganizationReportResponseModel(expectedReport); + Assert.Equivalent(expectedResponse, okResult.Value); } [Theory, BitAutoData] @@ -262,7 +264,8 @@ public class OrganizationReportControllerTests // Assert var okResult = Assert.IsType(result); - Assert.Equal(expectedReport, okResult.Value); + var expectedResponse = new OrganizationReportResponseModel(expectedReport); + Assert.Equivalent(expectedResponse, okResult.Value); } [Theory, BitAutoData] @@ -365,7 +368,8 @@ public class OrganizationReportControllerTests // Assert var okResult = Assert.IsType(result); - Assert.Equal(expectedReport, okResult.Value); + var expectedResponse = new OrganizationReportResponseModel(expectedReport); + Assert.Equivalent(expectedResponse, okResult.Value); } [Theory, BitAutoData] @@ -597,7 +601,8 @@ public class OrganizationReportControllerTests // Assert var okResult = Assert.IsType(result); - Assert.Equal(expectedReport, okResult.Value); + var expectedResponse = new OrganizationReportResponseModel(expectedReport); + Assert.Equivalent(expectedResponse, okResult.Value); } [Theory, BitAutoData] @@ -812,7 +817,8 @@ public class OrganizationReportControllerTests // Assert var okResult = Assert.IsType(result); - Assert.Equal(expectedReport, okResult.Value); + var expectedResponse = new OrganizationReportResponseModel(expectedReport); + Assert.Equivalent(expectedResponse, okResult.Value); } [Theory, BitAutoData] @@ -1050,7 +1056,8 @@ public class OrganizationReportControllerTests // Assert var okResult = Assert.IsType(result); - Assert.Equal(expectedReport, okResult.Value); + var expectedResponse = new OrganizationReportResponseModel(expectedReport); + Assert.Equivalent(expectedResponse, okResult.Value); } [Theory, BitAutoData] diff --git a/test/Api.Test/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs b/test/Api.Test/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs index 05b1aa5a4d..c843d24bc3 100644 --- a/test/Api.Test/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs +++ b/test/Api.Test/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs @@ -14,7 +14,9 @@ using Bit.Core.Auth.Models.Data; using Bit.Core.Entities; using Bit.Core.Exceptions; using Bit.Core.KeyManagement.Commands.Interfaces; +using Bit.Core.KeyManagement.Models.Api.Request; using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Queries.Interfaces; using Bit.Core.KeyManagement.UserKey; using Bit.Core.Repositories; using Bit.Core.Services; @@ -110,6 +112,7 @@ public class AccountsKeyManagementControllerTests public async Task RotateUserAccountKeysSuccess(SutProvider sutProvider, RotateUserAccountKeysAndDataRequestModel data, User user) { + data.AccountKeys.SignatureKeyPair = null; sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).Returns(user); sutProvider.GetDependency().RotateUserAccountKeysAsync(Arg.Any(), Arg.Any()) .Returns(IdentityResult.Success); @@ -142,8 +145,60 @@ public class AccountsKeyManagementControllerTests && d.MasterPasswordUnlockData.MasterKeyAuthenticationHash == data.AccountUnlockData.MasterPasswordUnlockData.MasterKeyAuthenticationHash && d.MasterPasswordUnlockData.MasterKeyEncryptedUserKey == data.AccountUnlockData.MasterPasswordUnlockData.MasterKeyEncryptedUserKey - && d.AccountPublicKey == data.AccountKeys.AccountPublicKey - && d.UserKeyEncryptedAccountPrivateKey == data.AccountKeys.UserKeyEncryptedAccountPrivateKey + && d.AccountKeys!.PublicKeyEncryptionKeyPairData.WrappedPrivateKey == data.AccountKeys.PublicKeyEncryptionKeyPair!.WrappedPrivateKey + && d.AccountKeys!.PublicKeyEncryptionKeyPairData.PublicKey == data.AccountKeys.PublicKeyEncryptionKeyPair!.PublicKey + )); + } + + [Theory] + [BitAutoData] + public async Task RotateUserAccountKeys_UserCryptoV2_Success_Async(SutProvider sutProvider, + RotateUserAccountKeysAndDataRequestModel data, User user) + { + data.AccountKeys.SignatureKeyPair = new SignatureKeyPairRequestModel + { + SignatureAlgorithm = "ed25519", + WrappedSigningKey = "wrappedSigningKey", + VerifyingKey = "verifyingKey" + }; + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).Returns(user); + sutProvider.GetDependency().RotateUserAccountKeysAsync(Arg.Any(), Arg.Any()) + .Returns(IdentityResult.Success); + await sutProvider.Sut.RotateUserAccountKeysAsync(data); + + await sutProvider.GetDependency, IEnumerable>>().Received(1) + .ValidateAsync(Arg.Any(), Arg.Is(data.AccountUnlockData.EmergencyAccessUnlockData)); + await sutProvider.GetDependency, IReadOnlyList>>().Received(1) + .ValidateAsync(Arg.Any(), Arg.Is(data.AccountUnlockData.OrganizationAccountRecoveryUnlockData)); + await sutProvider.GetDependency, IEnumerable>>().Received(1) + .ValidateAsync(Arg.Any(), Arg.Is(data.AccountUnlockData.PasskeyUnlockData)); + + await sutProvider.GetDependency, IEnumerable>>().Received(1) + .ValidateAsync(Arg.Any(), Arg.Is(data.AccountData.Ciphers)); + await sutProvider.GetDependency, IEnumerable>>().Received(1) + .ValidateAsync(Arg.Any(), Arg.Is(data.AccountData.Folders)); + await sutProvider.GetDependency, IReadOnlyList>>().Received(1) + .ValidateAsync(Arg.Any(), Arg.Is(data.AccountData.Sends)); + + await sutProvider.GetDependency().Received(1) + .RotateUserAccountKeysAsync(Arg.Is(user), Arg.Is(d => + d.OldMasterKeyAuthenticationHash == data.OldMasterKeyAuthenticationHash + + && d.MasterPasswordUnlockData.KdfType == data.AccountUnlockData.MasterPasswordUnlockData.KdfType + && d.MasterPasswordUnlockData.KdfIterations == data.AccountUnlockData.MasterPasswordUnlockData.KdfIterations + && d.MasterPasswordUnlockData.KdfMemory == data.AccountUnlockData.MasterPasswordUnlockData.KdfMemory + && d.MasterPasswordUnlockData.KdfParallelism == data.AccountUnlockData.MasterPasswordUnlockData.KdfParallelism + && d.MasterPasswordUnlockData.Email == data.AccountUnlockData.MasterPasswordUnlockData.Email + + && d.MasterPasswordUnlockData.MasterKeyAuthenticationHash == data.AccountUnlockData.MasterPasswordUnlockData.MasterKeyAuthenticationHash + && d.MasterPasswordUnlockData.MasterKeyEncryptedUserKey == data.AccountUnlockData.MasterPasswordUnlockData.MasterKeyEncryptedUserKey + + && d.AccountKeys!.PublicKeyEncryptionKeyPairData.WrappedPrivateKey == data.AccountKeys.PublicKeyEncryptionKeyPair!.WrappedPrivateKey + && d.AccountKeys!.PublicKeyEncryptionKeyPairData.PublicKey == data.AccountKeys.PublicKeyEncryptionKeyPair!.PublicKey + && d.AccountKeys!.PublicKeyEncryptionKeyPairData.SignedPublicKey == data.AccountKeys.PublicKeyEncryptionKeyPair!.SignedPublicKey + && d.AccountKeys!.SignatureKeyPairData!.SignatureAlgorithm == Core.KeyManagement.Enums.SignatureAlgorithm.Ed25519 + && d.AccountKeys!.SignatureKeyPairData.WrappedSigningKey == data.AccountKeys.SignatureKeyPair!.WrappedSigningKey + && d.AccountKeys!.SignatureKeyPairData.VerifyingKey == data.AccountKeys.SignatureKeyPair!.VerifyingKey )); } @@ -153,6 +208,7 @@ public class AccountsKeyManagementControllerTests public async Task RotateUserKeyNoUser_Throws(SutProvider sutProvider, RotateUserAccountKeysAndDataRequestModel data) { + data.AccountKeys.SignatureKeyPair = null; User? user = null; sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).Returns(user); sutProvider.GetDependency().RotateUserAccountKeysAsync(Arg.Any(), Arg.Any()) @@ -165,6 +221,7 @@ public class AccountsKeyManagementControllerTests public async Task RotateUserKeyWrongData_Throws(SutProvider sutProvider, RotateUserAccountKeysAndDataRequestModel data, User user, IdentityErrorDescriber _identityErrorDescriber) { + data.AccountKeys.SignatureKeyPair = null; sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).Returns(user); sutProvider.GetDependency().RotateUserAccountKeysAsync(Arg.Any(), Arg.Any()) .Returns(IdentityResult.Failed(_identityErrorDescriber.PasswordMismatch())); @@ -181,10 +238,13 @@ public class AccountsKeyManagementControllerTests [Theory] [BitAutoData] - public async Task PostSetKeyConnectorKeyAsync_UserNull_Throws( + public async Task PostSetKeyConnectorKeyAsync_V1_UserNull_Throws( SutProvider sutProvider, SetKeyConnectorKeyRequestModel data) { + data.KeyConnectorKeyWrappedUserKey = null; + data.AccountKeys = null; + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).ReturnsNull(); await Assert.ThrowsAsync(() => sutProvider.Sut.PostSetKeyConnectorKeyAsync(data)); @@ -195,10 +255,13 @@ public class AccountsKeyManagementControllerTests [Theory] [BitAutoData] - public async Task PostSetKeyConnectorKeyAsync_SetKeyConnectorKeyFails_ThrowsBadRequestWithErrorResponse( + public async Task PostSetKeyConnectorKeyAsync_V1_SetKeyConnectorKeyFails_ThrowsBadRequestWithErrorResponse( SutProvider sutProvider, SetKeyConnectorKeyRequestModel data, User expectedUser) { + data.KeyConnectorKeyWrappedUserKey = null; + data.AccountKeys = null; + expectedUser.PublicKey = null; expectedUser.PrivateKey = null; sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()) @@ -221,17 +284,20 @@ public class AccountsKeyManagementControllerTests Assert.Equal(data.KdfIterations, user.KdfIterations); Assert.Equal(data.KdfMemory, user.KdfMemory); Assert.Equal(data.KdfParallelism, user.KdfParallelism); - Assert.Equal(data.Keys.PublicKey, user.PublicKey); - Assert.Equal(data.Keys.EncryptedPrivateKey, user.PrivateKey); + Assert.Equal(data.Keys!.PublicKey, user.PublicKey); + Assert.Equal(data.Keys!.EncryptedPrivateKey, user.PrivateKey); }), Arg.Is(data.Key), Arg.Is(data.OrgIdentifier)); } [Theory] [BitAutoData] - public async Task PostSetKeyConnectorKeyAsync_SetKeyConnectorKeySucceeds_OkResponse( + public async Task PostSetKeyConnectorKeyAsync_V1_SetKeyConnectorKeySucceeds_OkResponse( SutProvider sutProvider, SetKeyConnectorKeyRequestModel data, User expectedUser) { + data.KeyConnectorKeyWrappedUserKey = null; + data.AccountKeys = null; + expectedUser.PublicKey = null; expectedUser.PrivateKey = null; sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()) @@ -251,11 +317,108 @@ public class AccountsKeyManagementControllerTests Assert.Equal(data.KdfIterations, user.KdfIterations); Assert.Equal(data.KdfMemory, user.KdfMemory); Assert.Equal(data.KdfParallelism, user.KdfParallelism); - Assert.Equal(data.Keys.PublicKey, user.PublicKey); - Assert.Equal(data.Keys.EncryptedPrivateKey, user.PrivateKey); + Assert.Equal(data.Keys!.PublicKey, user.PublicKey); + Assert.Equal(data.Keys!.EncryptedPrivateKey, user.PrivateKey); }), Arg.Is(data.Key), Arg.Is(data.OrgIdentifier)); } + [Theory] + [BitAutoData] + public async Task PostSetKeyConnectorKeyAsync_V2_UserNull_Throws( + SutProvider sutProvider) + { + var request = new SetKeyConnectorKeyRequestModel + { + KeyConnectorKeyWrappedUserKey = "wrapped-user-key", + AccountKeys = new AccountKeysRequestModel + { + AccountPublicKey = "public-key", + UserKeyEncryptedAccountPrivateKey = "encrypted-private-key" + }, + OrgIdentifier = "test-org" + }; + + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).ReturnsNull(); + + await Assert.ThrowsAsync(() => sutProvider.Sut.PostSetKeyConnectorKeyAsync(request)); + + await sutProvider.GetDependency().DidNotReceive() + .SetKeyConnectorKeyForUserAsync(Arg.Any(), Arg.Any()); + } + + [Theory] + [BitAutoData] + public async Task PostSetKeyConnectorKeyAsync_V2_Success( + SutProvider sutProvider, + User expectedUser) + { + var request = new SetKeyConnectorKeyRequestModel + { + KeyConnectorKeyWrappedUserKey = "wrapped-user-key", + AccountKeys = new AccountKeysRequestModel + { + AccountPublicKey = "public-key", + UserKeyEncryptedAccountPrivateKey = "encrypted-private-key" + }, + OrgIdentifier = "test-org" + }; + + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()) + .Returns(expectedUser); + + await sutProvider.Sut.PostSetKeyConnectorKeyAsync(request); + + await sutProvider.GetDependency().Received(1) + .SetKeyConnectorKeyForUserAsync(Arg.Is(expectedUser), + Arg.Do(data => + { + Assert.Equal(request.KeyConnectorKeyWrappedUserKey, data.KeyConnectorKeyWrappedUserKey); + Assert.Equal(request.AccountKeys.AccountPublicKey, data.AccountKeys.AccountPublicKey); + Assert.Equal(request.AccountKeys.UserKeyEncryptedAccountPrivateKey, + data.AccountKeys.UserKeyEncryptedAccountPrivateKey); + Assert.Equal(request.OrgIdentifier, data.OrgIdentifier); + })); + } + + [Theory] + [BitAutoData] + public async Task PostSetKeyConnectorKeyAsync_V2_CommandThrows_PropagatesException( + SutProvider sutProvider, + User expectedUser) + { + var request = new SetKeyConnectorKeyRequestModel + { + KeyConnectorKeyWrappedUserKey = "wrapped-user-key", + AccountKeys = new AccountKeysRequestModel + { + AccountPublicKey = "public-key", + UserKeyEncryptedAccountPrivateKey = "encrypted-private-key" + }, + OrgIdentifier = "test-org" + }; + + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()) + .Returns(expectedUser); + sutProvider.GetDependency() + .When(x => x.SetKeyConnectorKeyForUserAsync(Arg.Any(), Arg.Any())) + .Do(_ => throw new BadRequestException("Command failed")); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.PostSetKeyConnectorKeyAsync(request)); + + Assert.Equal("Command failed", exception.Message); + await sutProvider.GetDependency().Received(1) + .SetKeyConnectorKeyForUserAsync(Arg.Is(expectedUser), + Arg.Do(data => + { + Assert.Equal(request.KeyConnectorKeyWrappedUserKey, data.KeyConnectorKeyWrappedUserKey); + Assert.Equal(request.AccountKeys.AccountPublicKey, data.AccountKeys.AccountPublicKey); + Assert.Equal(request.AccountKeys.UserKeyEncryptedAccountPrivateKey, + data.AccountKeys.UserKeyEncryptedAccountPrivateKey); + Assert.Equal(request.OrgIdentifier, data.OrgIdentifier); + })); + } + [Theory] [BitAutoData] public async Task PostConvertToKeyConnectorAsync_UserNull_Throws( @@ -307,4 +470,39 @@ public class AccountsKeyManagementControllerTests await sutProvider.GetDependency().Received(1) .ConvertToKeyConnectorAsync(Arg.Is(expectedUser)); } + + [Theory] + [BitAutoData] + public async Task GetKeyConnectorConfirmationDetailsAsync_NoUser_Throws( + SutProvider sutProvider, string orgSsoIdentifier) + { + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()) + .ReturnsNull(); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.GetKeyConnectorConfirmationDetailsAsync(orgSsoIdentifier)); + + await sutProvider.GetDependency().ReceivedWithAnyArgs(0) + .Run(Arg.Any(), Arg.Any()); + } + + [Theory] + [BitAutoData] + public async Task GetKeyConnectorConfirmationDetailsAsync_Success( + SutProvider sutProvider, User expectedUser, string orgSsoIdentifier) + { + sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()) + .Returns(expectedUser); + sutProvider.GetDependency().Run(orgSsoIdentifier, expectedUser.Id) + .Returns( + new KeyConnectorConfirmationDetails { OrganizationName = "test" } + ); + + var result = await sutProvider.Sut.GetKeyConnectorConfirmationDetailsAsync(orgSsoIdentifier); + + Assert.NotNull(result); + Assert.Equal("test", result.OrganizationName); + await sutProvider.GetDependency().Received(1) + .Run(orgSsoIdentifier, expectedUser.Id); + } } diff --git a/test/Api.Test/KeyManagement/Controllers/UsersControllerTests.cs b/test/Api.Test/KeyManagement/Controllers/UsersControllerTests.cs new file mode 100644 index 0000000000..6e3094234b --- /dev/null +++ b/test/Api.Test/KeyManagement/Controllers/UsersControllerTests.cs @@ -0,0 +1,112 @@ +#nullable enable +using Bit.Api.KeyManagement.Controllers; +using Bit.Core.Entities; +using Bit.Core.Exceptions; +using Bit.Core.KeyManagement.Enums; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Queries.Interfaces; +using Bit.Core.Repositories; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using NSubstitute.ReturnsExtensions; +using Xunit; + +namespace Bit.Api.Test.KeyManagement.Controllers; + +[ControllerCustomize(typeof(UsersController))] +[SutProviderCustomize] +[JsonDocumentCustomize] +public class UsersControllerTests +{ + [Theory] + [BitAutoData] + public async Task GetPublicKey_NotFound_ThrowsNotFoundException( + SutProvider sutProvider) + { + sutProvider.GetDependency().GetPublicKeyAsync(Arg.Any()).ReturnsNull(); + await Assert.ThrowsAsync(() => sutProvider.Sut.GetPublicKeyAsync(new Guid())); + } + + [Theory] + [BitAutoData] + public async Task GetPublicKey_ReturnsUserKeyResponseModel( + SutProvider sutProvider, + Guid userId) + { + var publicKey = "publicKey"; + sutProvider.GetDependency().GetPublicKeyAsync(userId).Returns(publicKey); + + var result = await sutProvider.Sut.GetPublicKeyAsync(userId); + Assert.NotNull(result); + Assert.Equal(userId, result.UserId); + Assert.Equal(publicKey, result.PublicKey); + } + + [Theory] + [BitAutoData] + public async Task GetAccountKeys_UserNotFound_ThrowsNotFoundException( + SutProvider sutProvider) + { + sutProvider.GetDependency().GetByIdAsync(Arg.Any()).ReturnsNull(); + await Assert.ThrowsAsync(() => sutProvider.Sut.GetAccountKeysAsync(new Guid())); + } + + [Theory] + [BitAutoData] + public async Task GetAccountKeys_ReturnsPublicUserKeysResponseModel( + SutProvider sutProvider, + Guid userId) + { + var user = new User + { + Id = userId, + PublicKey = "publicKey", + SignedPublicKey = "signedPublicKey", + }; + + sutProvider.GetDependency().GetByIdAsync(userId).Returns(user); + sutProvider.GetDependency() + .Run(user) + .Returns(new UserAccountKeysData + { + PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData("wrappedPrivateKey", "publicKey", "signedPublicKey"), + SignatureKeyPairData = new SignatureKeyPairData(SignatureAlgorithm.Ed25519, "wrappedSigningKey", "verifyingKey"), + }); + + var result = await sutProvider.Sut.GetAccountKeysAsync(userId); + Assert.NotNull(result); + Assert.Equal("publicKey", result.PublicKey); + Assert.Equal("signedPublicKey", result.SignedPublicKey); + Assert.Equal("verifyingKey", result.VerifyingKey); + } + + [Theory] + [BitAutoData] + public async Task GetAccountKeys_ReturnsPublicUserKeysResponseModel_WithNullVerifyingKey( + SutProvider sutProvider, + Guid userId) + { + var user = new User + { + Id = userId, + PublicKey = "publicKey", + SignedPublicKey = null, + }; + + sutProvider.GetDependency().GetByIdAsync(userId).Returns(user); + sutProvider.GetDependency() + .Run(user) + .Returns(new UserAccountKeysData + { + PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData("wrappedPrivateKey", "publicKey", null), + SignatureKeyPairData = null, + }); + + var result = await sutProvider.Sut.GetAccountKeysAsync(userId); + Assert.NotNull(result); + Assert.Equal("publicKey", result.PublicKey); + Assert.Null(result.SignedPublicKey); + Assert.Null(result.VerifyingKey); + } +} diff --git a/test/Api.Test/KeyManagement/Models/Request/SetKeyConnectorKeyRequestModelTests.cs b/test/Api.Test/KeyManagement/Models/Request/SetKeyConnectorKeyRequestModelTests.cs new file mode 100644 index 0000000000..95ee743d02 --- /dev/null +++ b/test/Api.Test/KeyManagement/Models/Request/SetKeyConnectorKeyRequestModelTests.cs @@ -0,0 +1,333 @@ +using System.ComponentModel.DataAnnotations; +using Bit.Api.KeyManagement.Models.Requests; +using Bit.Core; +using Bit.Core.Auth.Models.Api.Request.Accounts; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.KeyManagement.Models.Api.Request; +using Xunit; + +namespace Bit.Api.Test.KeyManagement.Models.Request; + +public class SetKeyConnectorKeyRequestModelTests +{ + private const string _wrappedUserKey = "2.AOs41Hd8OQiCPXjyJKCiDA==|O6OHgt2U2hJGBSNGnimJmg==|iD33s8B69C8JhYYhSa4V1tArjvLr8eEaGqOV7BRo5Jk="; + private const string _publicKey = "public-key"; + private const string _privateKey = "private-key"; + private const string _userKey = "user-key"; + private const string _orgIdentifier = "org-identifier"; + + [Fact] + public void Validate_V2Registration_Valid() + { + // Arrange + var model = new SetKeyConnectorKeyRequestModel + { + KeyConnectorKeyWrappedUserKey = _wrappedUserKey, + AccountKeys = new AccountKeysRequestModel + { + AccountPublicKey = _publicKey, + UserKeyEncryptedAccountPrivateKey = _privateKey + }, + OrgIdentifier = _orgIdentifier + }; + + // Act + var results = Validate(model); + + // Assert + Assert.Empty(results); + } + + [Fact] + public void Validate_V2Registration_WrappedUserKeyNotEncryptedString_Invalid() + { + // Arrange + var model = new SetKeyConnectorKeyRequestModel + { + KeyConnectorKeyWrappedUserKey = "not-encrypted-string", + AccountKeys = new AccountKeysRequestModel + { + AccountPublicKey = _publicKey, + UserKeyEncryptedAccountPrivateKey = _privateKey + }, + OrgIdentifier = _orgIdentifier + }; + + // Act + var results = Validate(model); + + // Assert + Assert.Single(results); + Assert.Contains(results, + r => r.ErrorMessage == "KeyConnectorKeyWrappedUserKey is not a valid encrypted string."); + } + + [Fact] + public void Validate_V1Registration_Valid() + { + // Arrange + var model = new SetKeyConnectorKeyRequestModel + { + Key = _userKey, + Keys = new KeysRequestModel + { + PublicKey = _publicKey, + EncryptedPrivateKey = _privateKey + }, + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = AuthConstants.PBKDF2_ITERATIONS.Default, + OrgIdentifier = _orgIdentifier + }; + + // Act + var results = Validate(model); + + // Assert + Assert.Empty(results); + } + + [Fact] + public void Validate_V1Registration_MissingKey_Invalid() + { + // Arrange + var model = new SetKeyConnectorKeyRequestModel + { + Key = null, + Keys = new KeysRequestModel + { + PublicKey = _publicKey, + EncryptedPrivateKey = _privateKey + }, + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = AuthConstants.PBKDF2_ITERATIONS.Default, + OrgIdentifier = _orgIdentifier + }; + + // Act + var results = Validate(model); + + // Assert + Assert.Single(results); + Assert.Contains(results, r => r.ErrorMessage == "Key must be supplied."); + } + + [Fact] + public void Validate_V1Registration_MissingKeys_Invalid() + { + // Arrange + var model = new SetKeyConnectorKeyRequestModel + { + Key = _userKey, + Keys = null, + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = AuthConstants.PBKDF2_ITERATIONS.Default, + OrgIdentifier = _orgIdentifier + }; + + // Act + var results = Validate(model); + + // Assert + Assert.Single(results); + Assert.Contains(results, r => r.ErrorMessage == "Keys must be supplied."); + } + + [Fact] + public void Validate_V1Registration_MissingKdf_Invalid() + { + // Arrange + var model = new SetKeyConnectorKeyRequestModel + { + Key = _userKey, + Keys = new KeysRequestModel + { + PublicKey = _publicKey, + EncryptedPrivateKey = _privateKey + }, + Kdf = null, + KdfIterations = AuthConstants.PBKDF2_ITERATIONS.Default, + OrgIdentifier = _orgIdentifier + }; + + // Act + var results = Validate(model); + + // Assert + Assert.Single(results); + Assert.Contains(results, r => r.ErrorMessage == "Kdf must be supplied."); + } + + [Fact] + public void Validate_V1Registration_MissingKdfIterations_Invalid() + { + // Arrange + var model = new SetKeyConnectorKeyRequestModel + { + Key = _userKey, + Keys = new KeysRequestModel + { + PublicKey = _publicKey, + EncryptedPrivateKey = _privateKey + }, + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = null, + OrgIdentifier = _orgIdentifier + }; + + // Act + var results = Validate(model); + + // Assert + Assert.Single(results); + Assert.Contains(results, r => r.ErrorMessage == "KdfIterations must be supplied."); + } + + [Fact] + public void Validate_V1Registration_Argon2id_MissingKdfMemory_Invalid() + { + // Arrange + var model = new SetKeyConnectorKeyRequestModel + { + Key = _userKey, + Keys = new KeysRequestModel + { + PublicKey = _publicKey, + EncryptedPrivateKey = _privateKey + }, + Kdf = KdfType.Argon2id, + KdfIterations = AuthConstants.ARGON2_ITERATIONS.Default, + KdfMemory = null, + KdfParallelism = AuthConstants.ARGON2_PARALLELISM.Default, + OrgIdentifier = _orgIdentifier + }; + + // Act + var results = Validate(model); + + // Assert + Assert.Single(results); + Assert.Contains(results, r => r.ErrorMessage == "KdfMemory must be supplied when Kdf is Argon2id."); + } + + [Fact] + public void Validate_V1Registration_Argon2id_MissingKdfParallelism_Invalid() + { + // Arrange + var model = new SetKeyConnectorKeyRequestModel + { + Key = _userKey, + Keys = new KeysRequestModel + { + PublicKey = _publicKey, + EncryptedPrivateKey = _privateKey + }, + Kdf = KdfType.Argon2id, + KdfIterations = AuthConstants.ARGON2_ITERATIONS.Default, + KdfMemory = AuthConstants.ARGON2_MEMORY.Default, + KdfParallelism = null, + OrgIdentifier = _orgIdentifier + }; + + // Act + var results = Validate(model); + + // Assert + Assert.Single(results); + Assert.Contains(results, r => r.ErrorMessage == "KdfParallelism must be supplied when Kdf is Argon2id."); + } + + [Fact] + public void ToKeyConnectorKeysData_EmptyKeyConnectorKeyWrappedUserKey_ThrowsException() + { + // Arrange + var model = new SetKeyConnectorKeyRequestModel + { + KeyConnectorKeyWrappedUserKey = "", + AccountKeys = new AccountKeysRequestModel + { + AccountPublicKey = _publicKey, + UserKeyEncryptedAccountPrivateKey = _privateKey + }, + OrgIdentifier = _orgIdentifier + }; + + // Act + var exception = Assert.Throws(() => model.ToKeyConnectorKeysData()); + + // Assert + Assert.Equal("KeyConnectorKeyWrappedUserKey and AccountKeys must be supplied.", exception.Message); + } + + [Fact] + public void ToKeyConnectorKeysData_NullKeyConnectorKeyWrappedUserKey_ThrowsException() + { + // Arrange + var model = new SetKeyConnectorKeyRequestModel + { + KeyConnectorKeyWrappedUserKey = null, + AccountKeys = new AccountKeysRequestModel + { + AccountPublicKey = _publicKey, + UserKeyEncryptedAccountPrivateKey = _privateKey + }, + OrgIdentifier = _orgIdentifier + }; + + // Act + var exception = Assert.Throws(() => model.ToKeyConnectorKeysData()); + + // Assert + Assert.Equal("KeyConnectorKeyWrappedUserKey and AccountKeys must be supplied.", exception.Message); + } + + [Fact] + public void ToKeyConnectorKeysData_NullAccountKeys_ThrowsException() + { + // Arrange + var model = new SetKeyConnectorKeyRequestModel + { + KeyConnectorKeyWrappedUserKey = _wrappedUserKey, + AccountKeys = null, + OrgIdentifier = _orgIdentifier + }; + + // Act + var exception = Assert.Throws(() => model.ToKeyConnectorKeysData()); + + // Assert + Assert.Equal("KeyConnectorKeyWrappedUserKey and AccountKeys must be supplied.", exception.Message); + } + + [Fact] + public void ToKeyConnectorKeysData_Valid_Success() + { + // Arrange + var model = new SetKeyConnectorKeyRequestModel + { + KeyConnectorKeyWrappedUserKey = _wrappedUserKey, + AccountKeys = new AccountKeysRequestModel + { + AccountPublicKey = _publicKey, + UserKeyEncryptedAccountPrivateKey = _privateKey + }, + OrgIdentifier = _orgIdentifier + }; + + // Act + var data = model.ToKeyConnectorKeysData(); + + // Assert + Assert.Equal(_wrappedUserKey, data.KeyConnectorKeyWrappedUserKey); + Assert.Equal(_publicKey, data.AccountKeys.AccountPublicKey); + Assert.Equal(_privateKey, data.AccountKeys.UserKeyEncryptedAccountPrivateKey); + Assert.Equal(_orgIdentifier, data.OrgIdentifier); + } + + private static List Validate(SetKeyConnectorKeyRequestModel model) + { + var results = new List(); + Validator.TryValidateObject(model, new ValidationContext(model), results, true); + return results; + } +} diff --git a/test/Api.Test/KeyManagement/Models/Request/SignatureKeyPairRequestModel.cs b/test/Api.Test/KeyManagement/Models/Request/SignatureKeyPairRequestModel.cs new file mode 100644 index 0000000000..e1e97efce2 --- /dev/null +++ b/test/Api.Test/KeyManagement/Models/Request/SignatureKeyPairRequestModel.cs @@ -0,0 +1,22 @@ +#nullable enable + +using Bit.Core.KeyManagement.Models.Api.Request; +using Xunit; + +namespace Bit.Api.Test.KeyManagement.Models.Request; + +public class SignatureKeyPairRequestModelTests +{ + [Fact] + public void ToSignatureKeyPairData_WrongAlgorithm_Rejects() + { + var model = new SignatureKeyPairRequestModel + { + SignatureAlgorithm = "abc", + WrappedSigningKey = "wrappedKey", + VerifyingKey = "verifyingKey" + }; + + Assert.Throws(() => model.ToSignatureKeyPairData()); + } +} diff --git a/test/Api.Test/Models/Response/SubscriptionResponseModelTests.cs b/test/Api.Test/Models/Response/SubscriptionResponseModelTests.cs new file mode 100644 index 0000000000..051a66bbd3 --- /dev/null +++ b/test/Api.Test/Models/Response/SubscriptionResponseModelTests.cs @@ -0,0 +1,400 @@ +using Bit.Api.Models.Response; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Models.Business; +using Bit.Core.Entities; +using Bit.Core.Models.Business; +using Bit.Test.Common.AutoFixture.Attributes; +using Stripe; +using Xunit; + +namespace Bit.Api.Test.Models.Response; + +public class SubscriptionResponseModelTests +{ + [Theory] + [BitAutoData] + public void Constructor_IncludeMilestone2DiscountTrueMatchingCouponId_ReturnsDiscount( + User user, + UserLicense license) + { + // Arrange + var subscriptionInfo = new SubscriptionInfo + { + CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount + { + Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, // Matching coupon ID + Active = true, + PercentOff = 20m, + AmountOff = null, + AppliesTo = new List { "product1" } + } + }; + + // Act + var result = new SubscriptionResponseModel(user, subscriptionInfo, license, includeMilestone2Discount: true); + + // Assert + Assert.NotNull(result.CustomerDiscount); + Assert.Equal(StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, result.CustomerDiscount.Id); + Assert.True(result.CustomerDiscount.Active); + Assert.Equal(20m, result.CustomerDiscount.PercentOff); + Assert.Null(result.CustomerDiscount.AmountOff); + Assert.NotNull(result.CustomerDiscount.AppliesTo); + Assert.Single(result.CustomerDiscount.AppliesTo); + } + + [Theory] + [BitAutoData] + public void Constructor_IncludeMilestone2DiscountTrueNonMatchingCouponId_ReturnsNull( + User user, + UserLicense license) + { + // Arrange + var subscriptionInfo = new SubscriptionInfo + { + CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount + { + Id = "different-coupon-id", // Non-matching coupon ID + Active = true, + PercentOff = 20m, + AmountOff = null, + AppliesTo = new List { "product1" } + } + }; + + // Act + var result = new SubscriptionResponseModel(user, subscriptionInfo, license, includeMilestone2Discount: true); + + // Assert + Assert.Null(result.CustomerDiscount); + } + + [Theory] + [BitAutoData] + public void Constructor_IncludeMilestone2DiscountFalseMatchingCouponId_ReturnsNull( + User user, + UserLicense license) + { + // Arrange + var subscriptionInfo = new SubscriptionInfo + { + CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount + { + Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, // Matching coupon ID + Active = true, + PercentOff = 20m, + AmountOff = null, + AppliesTo = new List { "product1" } + } + }; + + // Act + var result = new SubscriptionResponseModel(user, subscriptionInfo, license, includeMilestone2Discount: false); + + // Assert - Should be null because includeMilestone2Discount is false + Assert.Null(result.CustomerDiscount); + } + + [Theory] + [BitAutoData] + public void Constructor_NullCustomerDiscount_ReturnsNull( + User user, + UserLicense license) + { + // Arrange + var subscriptionInfo = new SubscriptionInfo + { + CustomerDiscount = null + }; + + // Act + var result = new SubscriptionResponseModel(user, subscriptionInfo, license, includeMilestone2Discount: true); + + // Assert + Assert.Null(result.CustomerDiscount); + } + + [Theory] + [BitAutoData] + public void Constructor_AmountOffDiscountMatchingCouponId_ReturnsDiscount( + User user, + UserLicense license) + { + // Arrange + var subscriptionInfo = new SubscriptionInfo + { + CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount + { + Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, + Active = true, + PercentOff = null, + AmountOff = 14.00m, // Already converted from cents in BillingCustomerDiscount + AppliesTo = new List() + } + }; + + // Act + var result = new SubscriptionResponseModel(user, subscriptionInfo, license, includeMilestone2Discount: true); + + // Assert + Assert.NotNull(result.CustomerDiscount); + Assert.Equal(StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, result.CustomerDiscount.Id); + Assert.Null(result.CustomerDiscount.PercentOff); + Assert.Equal(14.00m, result.CustomerDiscount.AmountOff); + } + + [Theory] + [BitAutoData] + public void Constructor_DefaultIncludeMilestone2DiscountParameter_ReturnsNull( + User user, + UserLicense license) + { + // Arrange + var subscriptionInfo = new SubscriptionInfo + { + CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount + { + Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, + Active = true, + PercentOff = 20m + } + }; + + // Act - Using default parameter (includeMilestone2Discount defaults to false) + var result = new SubscriptionResponseModel(user, subscriptionInfo, license); + + // Assert + Assert.Null(result.CustomerDiscount); + } + + [Theory] + [BitAutoData] + public void Constructor_NullDiscountIdIncludeMilestone2DiscountTrue_ReturnsNull( + User user, + UserLicense license) + { + // Arrange + var subscriptionInfo = new SubscriptionInfo + { + CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount + { + Id = null, // Null discount ID + Active = true, + PercentOff = 20m, + AmountOff = null, + AppliesTo = new List { "product1" } + } + }; + + // Act + var result = new SubscriptionResponseModel(user, subscriptionInfo, license, includeMilestone2Discount: true); + + // Assert + Assert.Null(result.CustomerDiscount); + } + + [Theory] + [BitAutoData] + public void Constructor_MatchingCouponIdInactiveDiscount_ReturnsNull( + User user, + UserLicense license) + { + // Arrange + var subscriptionInfo = new SubscriptionInfo + { + CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount + { + Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, // Matching coupon ID + Active = false, // Inactive discount + PercentOff = 20m, + AmountOff = null, + AppliesTo = new List { "product1" } + } + }; + + // Act + var result = new SubscriptionResponseModel(user, subscriptionInfo, license, includeMilestone2Discount: true); + + // Assert + Assert.Null(result.CustomerDiscount); + } + + [Theory] + [BitAutoData] + public void Constructor_UserOnly_SetsBasicProperties(User user) + { + // Arrange + user.Storage = 5368709120; // 5 GB in bytes + user.MaxStorageGb = (short)10; + user.PremiumExpirationDate = DateTime.UtcNow.AddMonths(12); + + // Act + var result = new SubscriptionResponseModel(user); + + // Assert + Assert.NotNull(result.StorageName); + Assert.Equal(5.0, result.StorageGb); + Assert.Equal((short)10, result.MaxStorageGb); + Assert.Equal(user.PremiumExpirationDate, result.Expiration); + Assert.Null(result.License); + Assert.Null(result.CustomerDiscount); + } + + [Theory] + [BitAutoData] + public void Constructor_UserAndLicense_IncludesLicense(User user, UserLicense license) + { + // Arrange + user.Storage = 1073741824; // 1 GB in bytes + user.MaxStorageGb = (short)5; + + // Act + var result = new SubscriptionResponseModel(user, license); + + // Assert + Assert.NotNull(result.License); + Assert.Equal(license, result.License); + Assert.Equal(1.0, result.StorageGb); + Assert.Null(result.CustomerDiscount); + } + + [Theory] + [BitAutoData] + public void Constructor_NullStorage_SetsStorageToZero(User user) + { + // Arrange + user.Storage = null; + + // Act + var result = new SubscriptionResponseModel(user); + + // Assert + Assert.Null(result.StorageName); + Assert.Equal(0, result.StorageGb); + Assert.Null(result.CustomerDiscount); + } + + [Theory] + [BitAutoData] + public void Constructor_NullLicense_ExcludesLicense(User user) + { + // Act + var result = new SubscriptionResponseModel(user, null); + + // Assert + Assert.Null(result.License); + Assert.Null(result.CustomerDiscount); + } + + [Theory] + [BitAutoData] + public void Constructor_BothPercentOffAndAmountOffPresent_HandlesEdgeCase( + User user, + UserLicense license) + { + // Arrange - Edge case: Both PercentOff and AmountOff present + // This tests the scenario where Stripe coupon has both discount types + var subscriptionInfo = new SubscriptionInfo + { + CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount + { + Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, + Active = true, + PercentOff = 25m, + AmountOff = 20.00m, // Already converted from cents + AppliesTo = new List { "prod_premium" } + } + }; + + // Act + var result = new SubscriptionResponseModel(user, subscriptionInfo, license, includeMilestone2Discount: true); + + // Assert - Both values should be preserved + Assert.NotNull(result.CustomerDiscount); + Assert.Equal(StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, result.CustomerDiscount.Id); + Assert.Equal(25m, result.CustomerDiscount.PercentOff); + Assert.Equal(20.00m, result.CustomerDiscount.AmountOff); + Assert.NotNull(result.CustomerDiscount.AppliesTo); + Assert.Single(result.CustomerDiscount.AppliesTo); + } + + [Theory] + [BitAutoData] + public void Constructor_WithSubscriptionAndInvoice_MapsAllProperties( + User user, + UserLicense license) + { + // Arrange - Test with Subscription, UpcomingInvoice, and CustomerDiscount + var stripeSubscription = new Subscription + { + Id = "sub_test123", + Status = "active", + CollectionMethod = "charge_automatically" + }; + + var stripeInvoice = new Invoice + { + AmountDue = 1500, // 1500 cents = $15.00 + Created = DateTime.UtcNow.AddDays(7) + }; + + var subscriptionInfo = new SubscriptionInfo + { + Subscription = new SubscriptionInfo.BillingSubscription(stripeSubscription), + UpcomingInvoice = new SubscriptionInfo.BillingUpcomingInvoice(stripeInvoice), + CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount + { + Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, + Active = true, + PercentOff = 20m, + AmountOff = null, + AppliesTo = new List { "prod_premium" } + } + }; + + // Act + var result = new SubscriptionResponseModel(user, subscriptionInfo, license, includeMilestone2Discount: true); + + // Assert - Verify all properties are mapped correctly + Assert.NotNull(result.Subscription); + Assert.Equal("active", result.Subscription.Status); + Assert.Equal(14, result.Subscription.GracePeriod); // charge_automatically = 14 days + + Assert.NotNull(result.UpcomingInvoice); + Assert.Equal(15.00m, result.UpcomingInvoice.Amount); + Assert.NotNull(result.UpcomingInvoice.Date); + + Assert.NotNull(result.CustomerDiscount); + Assert.Equal(StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, result.CustomerDiscount.Id); + Assert.True(result.CustomerDiscount.Active); + Assert.Equal(20m, result.CustomerDiscount.PercentOff); + } + + [Theory] + [BitAutoData] + public void Constructor_WithNullSubscriptionAndInvoice_HandlesNullsGracefully( + User user, + UserLicense license) + { + // Arrange - Test with null Subscription and UpcomingInvoice + var subscriptionInfo = new SubscriptionInfo + { + Subscription = null, + UpcomingInvoice = null, + CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount + { + Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, + Active = true, + PercentOff = 20m + } + }; + + // Act + var result = new SubscriptionResponseModel(user, subscriptionInfo, license, includeMilestone2Discount: true); + + // Assert - Null Subscription and UpcomingInvoice should be handled gracefully + Assert.Null(result.Subscription); + Assert.Null(result.UpcomingInvoice); + Assert.NotNull(result.CustomerDiscount); + } +} diff --git a/test/Api.Test/SecretsManager/Controllers/SecretVersionsControllerTests.cs b/test/Api.Test/SecretsManager/Controllers/SecretVersionsControllerTests.cs new file mode 100644 index 0000000000..79a339fcba --- /dev/null +++ b/test/Api.Test/SecretsManager/Controllers/SecretVersionsControllerTests.cs @@ -0,0 +1,307 @@ +using Bit.Api.SecretsManager.Controllers; +using Bit.Api.SecretsManager.Models.Request; +using Bit.Core.Auth.Identity; +using Bit.Core.Context; +using Bit.Core.Entities; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.SecretsManager.Entities; +using Bit.Core.SecretsManager.Repositories; +using Bit.Core.Services; +using Bit.Core.Test.SecretsManager.AutoFixture.SecretsFixture; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Api.Test.SecretsManager.Controllers; + +[ControllerCustomize(typeof(SecretVersionsController))] +[SutProviderCustomize] +[SecretCustomize] +public class SecretVersionsControllerTests +{ + [Theory] + [BitAutoData] + public async Task GetVersionsBySecretId_SecretNotFound_Throws( + SutProvider sutProvider, + Guid secretId) + { + sutProvider.GetDependency().GetByIdAsync(secretId).Returns((Secret?)null); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.GetVersionsBySecretIdAsync(secretId)); + } + + [Theory] + [BitAutoData] + public async Task GetVersionsBySecretId_NoAccess_Throws( + SutProvider sutProvider, + Secret secret) + { + sutProvider.GetDependency().GetByIdAsync(secret.Id).Returns(secret); + sutProvider.GetDependency().AccessSecretsManager(secret.OrganizationId).Returns(false); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.GetVersionsBySecretIdAsync(secret.Id)); + } + + [Theory] + [BitAutoData] + public async Task GetVersionsBySecretId_NoReadAccess_Throws( + SutProvider sutProvider, + Secret secret, + Guid userId) + { + sutProvider.GetDependency().GetByIdAsync(secret.Id).Returns(secret); + sutProvider.GetDependency().AccessSecretsManager(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().OrganizationAdmin(secret.OrganizationId).Returns(false); + sutProvider.GetDependency().AccessToSecretAsync(secret.Id, userId, default) + .ReturnsForAnyArgs((false, false)); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.GetVersionsBySecretIdAsync(secret.Id)); + } + + [Theory] + [BitAutoData] + public async Task GetVersionsBySecretId_Success( + SutProvider sutProvider, + Secret secret, + List versions, + Guid userId) + { + sutProvider.GetDependency().GetByIdAsync(secret.Id).Returns(secret); + sutProvider.GetDependency().AccessSecretsManager(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().OrganizationAdmin(secret.OrganizationId).Returns(false); + sutProvider.GetDependency().AccessToSecretAsync(secret.Id, userId, default) + .ReturnsForAnyArgs((true, false)); + + foreach (var version in versions) + { + version.SecretId = secret.Id; + } + sutProvider.GetDependency().GetManyBySecretIdAsync(secret.Id).Returns(versions); + + var result = await sutProvider.Sut.GetVersionsBySecretIdAsync(secret.Id); + + Assert.Equal(versions.Count, result.Data.Count()); + await sutProvider.GetDependency().Received(1) + .GetManyBySecretIdAsync(Arg.Is(secret.Id)); + } + + [Theory] + [BitAutoData] + public async Task GetById_VersionNotFound_Throws( + SutProvider sutProvider, + Guid versionId) + { + sutProvider.GetDependency().GetByIdAsync(versionId).Returns((SecretVersion?)null); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.GetByIdAsync(versionId)); + } + + [Theory] + [BitAutoData] + public async Task GetById_Success( + SutProvider sutProvider, + SecretVersion version, + Secret secret, + Guid userId) + { + version.SecretId = secret.Id; + sutProvider.GetDependency().GetByIdAsync(version.Id).Returns(version); + sutProvider.GetDependency().GetByIdAsync(secret.Id).Returns(secret); + sutProvider.GetDependency().AccessSecretsManager(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().OrganizationAdmin(secret.OrganizationId).Returns(false); + sutProvider.GetDependency().AccessToSecretAsync(secret.Id, userId, default) + .ReturnsForAnyArgs((true, false)); + + var result = await sutProvider.Sut.GetByIdAsync(version.Id); + + Assert.Equal(version.Id, result.Id); + Assert.Equal(version.SecretId, result.SecretId); + } + + [Theory] + [BitAutoData] + public async Task RestoreVersion_NoWriteAccess_Throws( + SutProvider sutProvider, + Secret secret, + SecretVersion version, + RestoreSecretVersionRequestModel request, + Guid userId) + { + version.SecretId = secret.Id; + request.VersionId = version.Id; + + sutProvider.GetDependency().GetByIdAsync(secret.Id).Returns(secret); + sutProvider.GetDependency().AccessSecretsManager(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().OrganizationAdmin(secret.OrganizationId).Returns(false); + sutProvider.GetDependency().AccessToSecretAsync(secret.Id, userId, default) + .ReturnsForAnyArgs((true, false)); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.RestoreVersionAsync(secret.Id, request)); + } + + [Theory] + [BitAutoData] + public async Task RestoreVersion_VersionNotFound_Throws( + SutProvider sutProvider, + Secret secret, + RestoreSecretVersionRequestModel request, + Guid userId) + { + sutProvider.GetDependency().GetByIdAsync(secret.Id).Returns(secret); + sutProvider.GetDependency().AccessSecretsManager(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().OrganizationAdmin(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().AccessToSecretAsync(secret.Id, userId, default) + .ReturnsForAnyArgs((true, true)); + sutProvider.GetDependency().GetByIdAsync(request.VersionId).Returns((SecretVersion?)null); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.RestoreVersionAsync(secret.Id, request)); + } + + [Theory] + [BitAutoData] + public async Task RestoreVersion_VersionBelongsToDifferentSecret_Throws( + SutProvider sutProvider, + Secret secret, + SecretVersion version, + RestoreSecretVersionRequestModel request, + Guid userId) + { + version.SecretId = Guid.NewGuid(); // Different secret + request.VersionId = version.Id; + + sutProvider.GetDependency().GetByIdAsync(secret.Id).Returns(secret); + sutProvider.GetDependency().AccessSecretsManager(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().OrganizationAdmin(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().AccessToSecretAsync(secret.Id, userId, default) + .ReturnsForAnyArgs((true, true)); + sutProvider.GetDependency().GetByIdAsync(request.VersionId).Returns(version); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.RestoreVersionAsync(secret.Id, request)); + } + + [Theory] + [BitAutoData] + public async Task RestoreVersion_Success( + SutProvider sutProvider, + Secret secret, + SecretVersion version, + RestoreSecretVersionRequestModel request, + Guid userId, + OrganizationUser organizationUser) + { + version.SecretId = secret.Id; + request.VersionId = version.Id; + var versionValue = version.Value; + organizationUser.OrganizationId = secret.OrganizationId; + organizationUser.UserId = userId; + + sutProvider.GetDependency().GetByIdAsync(secret.Id).Returns(secret); + sutProvider.GetDependency().AccessSecretsManager(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().OrganizationAdmin(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().AccessToSecretAsync(secret.Id, userId, default) + .ReturnsForAnyArgs((true, true)); + sutProvider.GetDependency().GetByIdAsync(request.VersionId).Returns(version); + sutProvider.GetDependency() + .GetByOrganizationAsync(secret.OrganizationId, userId).Returns(organizationUser); + sutProvider.GetDependency().UpdateAsync(Arg.Any()).Returns(x => x.Arg()); + + var result = await sutProvider.Sut.RestoreVersionAsync(secret.Id, request); + + await sutProvider.GetDependency().Received(1) + .UpdateAsync(Arg.Is(s => s.Value == versionValue)); + } + + [Theory] + [BitAutoData] + public async Task BulkDelete_EmptyIds_Throws( + SutProvider sutProvider) + { + await Assert.ThrowsAsync(() => + sutProvider.Sut.BulkDeleteAsync(new List())); + } + + [Theory] + [BitAutoData] + public async Task BulkDelete_VersionNotFound_Throws( + SutProvider sutProvider, + List ids) + { + sutProvider.GetDependency().GetByIdAsync(ids[0]).Returns((SecretVersion?)null); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.BulkDeleteAsync(ids)); + } + + [Theory] + [BitAutoData] + public async Task BulkDelete_NoWriteAccess_Throws( + SutProvider sutProvider, + List versions, + Secret secret, + Guid userId) + { + var ids = versions.Select(v => v.Id).ToList(); + foreach (var version in versions) + { + version.SecretId = secret.Id; + sutProvider.GetDependency().GetByIdAsync(version.Id).Returns(version); + } + + sutProvider.GetDependency().GetManyByIds(Arg.Any>()) + .Returns(new List { secret }); + sutProvider.GetDependency().AccessSecretsManager(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().OrganizationAdmin(secret.OrganizationId).Returns(false); + sutProvider.GetDependency().AccessToSecretAsync(secret.Id, userId, default) + .ReturnsForAnyArgs((true, false)); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.BulkDeleteAsync(ids)); + } + + [Theory] + [BitAutoData] + public async Task BulkDelete_Success( + SutProvider sutProvider, + List versions, + Secret secret, + Guid userId) + { + var ids = versions.Select(v => v.Id).ToList(); + foreach (var version in versions) + { + version.SecretId = secret.Id; + } + + sutProvider.GetDependency().GetManyByIdsAsync(ids).Returns(versions); + sutProvider.GetDependency().GetManyByIds(Arg.Any>()) + .Returns(new List { secret }); + sutProvider.GetDependency().AccessSecretsManager(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().IdentityClientType.Returns(IdentityClientType.ServiceAccount); + sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); + sutProvider.GetDependency().OrganizationAdmin(secret.OrganizationId).Returns(true); + sutProvider.GetDependency().AccessToSecretAsync(secret.Id, userId, default) + .ReturnsForAnyArgs((true, true)); + + await sutProvider.Sut.BulkDeleteAsync(ids); + + await sutProvider.GetDependency().Received(1) + .DeleteManyByIdAsync(Arg.Is>(x => x.SequenceEqual(ids))); + } +} diff --git a/test/Api.Test/SecretsManager/Controllers/SecretsControllerTests.cs b/test/Api.Test/SecretsManager/Controllers/SecretsControllerTests.cs index 83a4229f39..51f61ad7c1 100644 --- a/test/Api.Test/SecretsManager/Controllers/SecretsControllerTests.cs +++ b/test/Api.Test/SecretsManager/Controllers/SecretsControllerTests.cs @@ -2,6 +2,7 @@ using Bit.Api.SecretsManager.Controllers; using Bit.Api.SecretsManager.Models.Request; using Bit.Api.Test.SecretsManager.Enums; +using Bit.Core.Auth.Identity; using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -244,6 +245,7 @@ public class SecretsControllerTests { data = SetupSecretUpdateRequest(data); SetControllerUser(sutProvider, new Guid()); + sutProvider.GetDependency().IdentityClientType.Returns(IdentityClientType.ServiceAccount); sutProvider.GetDependency() .AuthorizeAsync(Arg.Any(), Arg.Any(), Arg.Any>()).ReturnsForAnyArgs(AuthorizationResult.Success()); @@ -602,6 +604,7 @@ public class SecretsControllerTests { data = SetupSecretUpdateRequest(data, true); + sutProvider.GetDependency().IdentityClientType.Returns(IdentityClientType.ServiceAccount); sutProvider.GetDependency() .AuthorizeAsync(Arg.Any(), Arg.Any(), Arg.Any>()).Returns(AuthorizationResult.Success()); diff --git a/test/Api.Test/SecretsManager/Controllers/ServiceAccountsControllerTests.cs b/test/Api.Test/SecretsManager/Controllers/ServiceAccountsControllerTests.cs index 78224a8bd8..5d3b7f2fa5 100644 --- a/test/Api.Test/SecretsManager/Controllers/ServiceAccountsControllerTests.cs +++ b/test/Api.Test/SecretsManager/Controllers/ServiceAccountsControllerTests.cs @@ -16,7 +16,7 @@ using Bit.Core.SecretsManager.Models.Data; using Bit.Core.SecretsManager.Queries.ServiceAccounts.Interfaces; using Bit.Core.SecretsManager.Repositories; using Bit.Core.Services; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.Helpers; @@ -121,7 +121,7 @@ public class ServiceAccountsControllerTests { ArrangeCreateServiceAccountAutoScalingTest(newSlotsRequired, sutProvider, data, organization); - sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(StaticStore.GetPlan(organization.PlanType)); + sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType)); await sutProvider.Sut.CreateAsync(organization.Id, data); diff --git a/test/Api.Test/Tools/Controllers/ImportCiphersControllerTests.cs b/test/Api.Test/Tools/Controllers/ImportCiphersControllerTests.cs index 4908bb6847..9ca641a28e 100644 --- a/test/Api.Test/Tools/Controllers/ImportCiphersControllerTests.cs +++ b/test/Api.Test/Tools/Controllers/ImportCiphersControllerTests.cs @@ -75,6 +75,7 @@ public class ImportCiphersControllerTests .With(x => x.Ciphers, fixture.Build() .With(c => c.OrganizationId, Guid.NewGuid().ToString()) .With(c => c.FolderId, Guid.NewGuid().ToString()) + .With(c => c.ArchivedDate, (DateTime?)null) .CreateMany(1).ToArray()) .Create(); @@ -92,6 +93,37 @@ public class ImportCiphersControllerTests ); } + [Theory, BitAutoData] + public async Task PostImportIndividual_WithArchivedDate_SavesArchivedDate(User user, + IFixture fixture, SutProvider sutProvider) + { + var archivedDate = DateTime.UtcNow; + sutProvider.GetDependency() + .SelfHosted = false; + + sutProvider.GetDependency() + .GetProperUserId(Arg.Any()) + .Returns(user.Id); + + var request = fixture.Build() + .With(x => x.Ciphers, fixture.Build() + .With(c => c.ArchivedDate, archivedDate) + .With(c => c.FolderId, (string)null) + .CreateMany(1).ToArray()) + .Create(); + + await sutProvider.Sut.PostImport(request); + + await sutProvider.GetDependency() + .Received() + .ImportIntoIndividualVaultAsync( + Arg.Any>(), + Arg.Is>(ciphers => ciphers.First().ArchivedDate == archivedDate), + Arg.Any>>(), + user.Id + ); + } + /**************************** * PostImport - Organization ****************************/ @@ -156,6 +188,7 @@ public class ImportCiphersControllerTests .With(x => x.Ciphers, fixture.Build() .With(c => c.OrganizationId, Guid.NewGuid().ToString()) .With(c => c.FolderId, Guid.NewGuid().ToString()) + .With(c => c.ArchivedDate, (DateTime?)null) .CreateMany(1).ToArray()) .With(y => y.Collections, fixture.Build() .With(c => c.Id, orgIdGuid) @@ -227,6 +260,7 @@ public class ImportCiphersControllerTests .With(x => x.Ciphers, fixture.Build() .With(c => c.OrganizationId, Guid.NewGuid().ToString()) .With(c => c.FolderId, Guid.NewGuid().ToString()) + .With(c => c.ArchivedDate, (DateTime?)null) .CreateMany(1).ToArray()) .With(y => y.Collections, fixture.Build() .With(c => c.Id, orgIdGuid) @@ -291,6 +325,7 @@ public class ImportCiphersControllerTests .With(x => x.Ciphers, fixture.Build() .With(c => c.OrganizationId, Guid.NewGuid().ToString()) .With(c => c.FolderId, Guid.NewGuid().ToString()) + .With(c => c.ArchivedDate, (DateTime?)null) .CreateMany(1).ToArray()) .With(y => y.Collections, fixture.Build() .With(c => c.Id, orgIdGuid) @@ -354,6 +389,7 @@ public class ImportCiphersControllerTests .With(x => x.Ciphers, fixture.Build() .With(c => c.OrganizationId, Guid.NewGuid().ToString()) .With(c => c.FolderId, Guid.NewGuid().ToString()) + .With(c => c.ArchivedDate, (DateTime?)null) .CreateMany(1).ToArray()) .With(y => y.Collections, fixture.Build() .With(c => c.Id, orgIdGuid) @@ -423,6 +459,7 @@ public class ImportCiphersControllerTests Ciphers = fixture.Build() .With(_ => _.OrganizationId, orgId.ToString()) .With(_ => _.FolderId, Guid.NewGuid().ToString()) + .With(_ => _.ArchivedDate, (DateTime?)null) .CreateMany(2).ToArray(), CollectionRelationships = new List>().ToArray(), }; @@ -499,6 +536,7 @@ public class ImportCiphersControllerTests Ciphers = fixture.Build() .With(_ => _.OrganizationId, orgId.ToString()) .With(_ => _.FolderId, Guid.NewGuid().ToString()) + .With(_ => _.ArchivedDate, (DateTime?)null) .CreateMany(2).ToArray(), CollectionRelationships = new List>().ToArray(), }; @@ -578,6 +616,7 @@ public class ImportCiphersControllerTests Ciphers = fixture.Build() .With(_ => _.OrganizationId, orgId.ToString()) .With(_ => _.FolderId, Guid.NewGuid().ToString()) + .With(_ => _.ArchivedDate, (DateTime?)null) .CreateMany(2).ToArray(), CollectionRelationships = new List>().ToArray(), }; @@ -651,6 +690,7 @@ public class ImportCiphersControllerTests Ciphers = fixture.Build() .With(_ => _.OrganizationId, orgId.ToString()) .With(_ => _.FolderId, Guid.NewGuid().ToString()) + .With(_ => _.ArchivedDate, (DateTime?)null) .CreateMany(2).ToArray(), CollectionRelationships = new List>().ToArray(), }; @@ -720,6 +760,7 @@ public class ImportCiphersControllerTests Ciphers = fixture.Build() .With(_ => _.OrganizationId, orgId.ToString()) .With(_ => _.FolderId, Guid.NewGuid().ToString()) + .With(_ => _.ArchivedDate, (DateTime?)null) .CreateMany(2).ToArray(), CollectionRelationships = new List>().ToArray(), }; @@ -765,6 +806,63 @@ public class ImportCiphersControllerTests Arg.Any()); } + [Theory, BitAutoData] + public async Task PostImportOrganization_ThrowsException_WhenAnyCipherIsArchived( + SutProvider sutProvider, + IFixture fixture, + User user + ) + { + var orgId = Guid.NewGuid(); + + sutProvider.GetDependency() + .SelfHosted = false; + sutProvider.GetDependency() + .ImportCiphersLimitation = _organizationCiphersLimitations; + + SetupUserService(sutProvider, user); + + var ciphers = fixture.Build() + .With(_ => _.ArchivedDate, DateTime.UtcNow) + .CreateMany(2).ToArray(); + + var request = new ImportOrganizationCiphersRequestModel + { + Collections = new List().ToArray(), + Ciphers = ciphers, + CollectionRelationships = new List>().ToArray(), + }; + + sutProvider.GetDependency() + .AccessImportExport(Arg.Any()) + .Returns(false); + + sutProvider.GetDependency() + .AuthorizeAsync(Arg.Any(), + Arg.Any>(), + Arg.Is>(reqs => + reqs.Contains(BulkCollectionOperations.ImportCiphers))) + .Returns(AuthorizationResult.Failed()); + + sutProvider.GetDependency() + .AuthorizeAsync(Arg.Any(), + Arg.Any>(), + Arg.Is>(reqs => + reqs.Contains(BulkCollectionOperations.Create))) + .Returns(AuthorizationResult.Success()); + + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(orgId) + .Returns(new List()); + + var exception = await Assert.ThrowsAsync(async () => + { + await sutProvider.Sut.PostImportOrganization(orgId.ToString(), request); + }); + + Assert.Equal("You cannot import archived items into an organization.", exception.Message); + } + private static void SetupUserService(SutProvider sutProvider, User user) { // This is a workaround for the NSubstitute issue with ambiguous arguments diff --git a/test/Api.Test/Utilities/DiagnosticTools/EventDiagnosticLoggerTests.cs b/test/Api.Test/Utilities/DiagnosticTools/EventDiagnosticLoggerTests.cs new file mode 100644 index 0000000000..95fa949bc7 --- /dev/null +++ b/test/Api.Test/Utilities/DiagnosticTools/EventDiagnosticLoggerTests.cs @@ -0,0 +1,221 @@ +using Bit.Api.Dirt.Public.Models; +using Bit.Api.Models.Public.Response; +using Bit.Api.Utilities.DiagnosticTools; +using Bit.Core; +using Bit.Core.Models.Data; +using Bit.Core.Services; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.Extensions.Logging; +using NSubstitute; +using Xunit; + +namespace Bit.Api.Test.Utilities.DiagnosticTools; + +public class EventDiagnosticLoggerTests +{ + [Theory, BitAutoData] + public void LogAggregateData_WithPublicResponse_FeatureFlagEnabled_LogsInformation( + Guid organizationId) + { + // Arrange + var logger = Substitute.For(); + var featureService = Substitute.For(); + featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging).Returns(true); + + var request = new EventFilterRequestModel() + { + Start = DateTime.UtcNow.AddMinutes(-3), + End = DateTime.UtcNow, + ActingUserId = Guid.NewGuid(), + ItemId = Guid.NewGuid(), + }; + + var newestEvent = Substitute.For(); + newestEvent.Date.Returns(DateTime.UtcNow); + var middleEvent = Substitute.For(); + middleEvent.Date.Returns(DateTime.UtcNow.AddDays(-1)); + var oldestEvent = Substitute.For(); + oldestEvent.Date.Returns(DateTime.UtcNow.AddDays(-3)); + + var eventResponses = new List + { + new (newestEvent), + new (middleEvent), + new (oldestEvent) + }; + var response = new PagedListResponseModel(eventResponses, "continuation-token"); + + // Act + logger.LogAggregateData(featureService, organizationId, response, request); + + // Assert + logger.Received(1).Log( + LogLevel.Information, + Arg.Any(), + Arg.Is(o => + o.ToString().Contains(organizationId.ToString()) && + o.ToString().Contains($"Event count:{eventResponses.Count}") && + o.ToString().Contains($"newest record:{newestEvent.Date:O}") && + o.ToString().Contains($"oldest record:{oldestEvent.Date:O}") && + o.ToString().Contains("HasMore:True") && + o.ToString().Contains($"Start:{request.Start:o}") && + o.ToString().Contains($"End:{request.End:o}") && + o.ToString().Contains($"ActingUserId:{request.ActingUserId}") && + o.ToString().Contains($"ItemId:{request.ItemId}")) + , + null, + Arg.Any>()); + } + + [Theory, BitAutoData] + public void LogAggregateData_WithPublicResponse_FeatureFlagDisabled_DoesNotLog( + Guid organizationId, + EventFilterRequestModel request) + { + // Arrange + var logger = Substitute.For(); + var featureService = Substitute.For(); + featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging).Returns(false); + + PagedListResponseModel dummy = null; + + // Act + logger.LogAggregateData(featureService, organizationId, dummy, request); + + // Assert + logger.DidNotReceive().Log( + LogLevel.Information, + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any>()); + } + + [Theory, BitAutoData] + public void LogAggregateData_WithPublicResponse_EmptyData_LogsZeroCount( + Guid organizationId) + { + // Arrange + var logger = Substitute.For(); + var featureService = Substitute.For(); + featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging).Returns(true); + + var request = new EventFilterRequestModel() + { + Start = null, + End = null, + ActingUserId = null, + ItemId = null, + ContinuationToken = null, + }; + var response = new PagedListResponseModel(new List(), null); + + // Act + logger.LogAggregateData(featureService, organizationId, response, request); + + // Assert + logger.Received(1).Log( + LogLevel.Information, + Arg.Any(), + Arg.Is(o => + o.ToString().Contains(organizationId.ToString()) && + o.ToString().Contains("Event count:0") && + o.ToString().Contains("HasMore:False")), + null, + Arg.Any>()); + } + + [Theory, BitAutoData] + public void LogAggregateData_WithInternalResponse_FeatureFlagDisabled_DoesNotLog(Guid organizationId) + { + // Arrange + var logger = Substitute.For(); + var featureService = Substitute.For(); + featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging).Returns(false); + + + // Act + logger.LogAggregateData(featureService, organizationId, null, null, null, null); + + // Assert + logger.DidNotReceive().Log( + LogLevel.Information, + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any>()); + } + + [Theory, BitAutoData] + public void LogAggregateData_WithInternalResponse_EmptyData_LogsZeroCount( + Guid organizationId) + { + // Arrange + var logger = Substitute.For(); + var featureService = Substitute.For(); + featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging).Returns(true); + + Api.Dirt.Models.Response.EventResponseModel[] emptyEvents = []; + + // Act + logger.LogAggregateData(featureService, organizationId, emptyEvents, null, null, null); + + // Assert + logger.Received(1).Log( + LogLevel.Information, + Arg.Any(), + Arg.Is(o => + o.ToString().Contains(organizationId.ToString()) && + o.ToString().Contains("Event count:0") && + o.ToString().Contains("HasMore:False")), + null, + Arg.Any>()); + } + + [Theory, BitAutoData] + public void LogAggregateData_WithInternalResponse_FeatureFlagEnabled_LogsInformation( + Guid organizationId) + { + // Arrange + var logger = Substitute.For(); + var featureService = Substitute.For(); + featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging).Returns(true); + + var newestEvent = Substitute.For(); + newestEvent.Date.Returns(DateTime.UtcNow); + var middleEvent = Substitute.For(); + middleEvent.Date.Returns(DateTime.UtcNow.AddDays(-1)); + var oldestEvent = Substitute.For(); + oldestEvent.Date.Returns(DateTime.UtcNow.AddDays(-2)); + + var events = new List + { + new (newestEvent), + new (middleEvent), + new (oldestEvent) + }; + + var queryStart = DateTime.UtcNow.AddMinutes(-3); + var queryEnd = DateTime.UtcNow; + const string continuationToken = "continuation-token"; + + // Act + logger.LogAggregateData(featureService, organizationId, events, continuationToken, queryStart, queryEnd); + + // Assert + logger.Received(1).Log( + LogLevel.Information, + Arg.Any(), + Arg.Is(o => + o.ToString().Contains(organizationId.ToString()) && + o.ToString().Contains($"Event count:{events.Count}") && + o.ToString().Contains($"newest record:{newestEvent.Date:O}") && + o.ToString().Contains($"oldest record:{oldestEvent.Date:O}") && + o.ToString().Contains("HasMore:True") && + o.ToString().Contains($"Start:{queryStart:o}") && + o.ToString().Contains($"End:{queryEnd:o}")) + , + null, + Arg.Any>()); + } +} diff --git a/test/Api.Test/Vault/Controllers/CiphersControllerTests.cs b/test/Api.Test/Vault/Controllers/CiphersControllerTests.cs index 416b92f841..238053464c 100644 --- a/test/Api.Test/Vault/Controllers/CiphersControllerTests.cs +++ b/test/Api.Test/Vault/Controllers/CiphersControllerTests.cs @@ -79,7 +79,7 @@ public class CiphersControllerTests sutProvider.GetDependency().GetByIdAsync(id, userId).ReturnsForAnyArgs(cipherDetails); sutProvider.GetDependency().GetManyByUserIdCipherIdAsync(userId, id).Returns((ICollection)new List()); - sutProvider.GetDependency().GetOrganizationAbilitiesAsync().Returns(new Dictionary { { cipherDetails.OrganizationId.Value, new OrganizationAbility() } }); + sutProvider.GetDependency().GetOrganizationAbilitiesAsync().Returns(new Dictionary { { cipherDetails.OrganizationId.Value, new OrganizationAbility { Id = cipherDetails.OrganizationId.Value } } }); var cipherService = sutProvider.GetDependency(); await sutProvider.Sut.PutCollections_vNext(id, model); @@ -95,7 +95,7 @@ public class CiphersControllerTests sutProvider.GetDependency().GetByIdAsync(id, userId).ReturnsForAnyArgs(cipherDetails); sutProvider.GetDependency().GetManyByUserIdCipherIdAsync(userId, id).Returns((ICollection)new List()); - sutProvider.GetDependency().GetOrganizationAbilitiesAsync().Returns(new Dictionary { { cipherDetails.OrganizationId.Value, new OrganizationAbility() } }); + sutProvider.GetDependency().GetOrganizationAbilitiesAsync().Returns(new Dictionary { { cipherDetails.OrganizationId.Value, new OrganizationAbility { Id = cipherDetails.OrganizationId.Value } } }); var result = await sutProvider.Sut.PutCollections_vNext(id, model); @@ -1909,4 +1909,237 @@ public class CiphersControllerTests await Assert.ThrowsAsync(() => sutProvider.Sut.PostPurge(model, organizationId)); } + + [Theory, BitAutoData] + public async Task PutShare_WithNullFolderAndFalseFavorite_UpdatesFieldsCorrectly( + Guid cipherId, + Guid userId, + Guid organizationId, + Guid folderId, + SutProvider sutProvider) + { + var user = new User { Id = userId }; + var userIdKey = userId.ToString().ToUpperInvariant(); + + var existingCipher = new Cipher + { + Id = cipherId, + UserId = userId, + Type = CipherType.Login, + Data = JsonSerializer.Serialize(new { Username = "test", Password = "test" }), + Folders = JsonSerializer.Serialize(new Dictionary { { userIdKey, folderId.ToString().ToUpperInvariant() } }), + Favorites = JsonSerializer.Serialize(new Dictionary { { userIdKey, true } }) + }; + + // Clears folder and favorite when sharing + var model = new CipherShareRequestModel + { + Cipher = new CipherRequestModel + { + Type = CipherType.Login, + OrganizationId = organizationId.ToString(), + Name = "SharedCipher", + Data = JsonSerializer.Serialize(new { Username = "test", Password = "test" }), + FolderId = null, + Favorite = false, + EncryptedFor = userId + }, + CollectionIds = [Guid.NewGuid().ToString()] + }; + + sutProvider.GetDependency() + .GetUserByPrincipalAsync(Arg.Any()) + .Returns(user); + + sutProvider.GetDependency() + .GetByIdAsync(cipherId) + .Returns(existingCipher); + + sutProvider.GetDependency() + .OrganizationUser(organizationId) + .Returns(true); + + var sharedCipher = new CipherDetails + { + Id = cipherId, + OrganizationId = organizationId, + Type = CipherType.Login, + Data = JsonSerializer.Serialize(new { Username = "test", Password = "test" }), + FolderId = null, + Favorite = false + }; + + sutProvider.GetDependency() + .GetByIdAsync(cipherId, userId) + .Returns(sharedCipher); + + sutProvider.GetDependency() + .GetOrganizationAbilitiesAsync() + .Returns(new Dictionary + { + { organizationId, new OrganizationAbility { Id = organizationId } } + }); + + var result = await sutProvider.Sut.PutShare(cipherId, model); + + Assert.Null(result.FolderId); + Assert.False(result.Favorite); + } + + [Theory, BitAutoData] + public async Task PutShare_WithFolderAndFavoriteSet_AddsUserSpecificFields( + Guid cipherId, + Guid userId, + Guid organizationId, + Guid folderId, + SutProvider sutProvider) + { + var user = new User { Id = userId }; + var userIdKey = userId.ToString().ToUpperInvariant(); + + var existingCipher = new Cipher + { + Id = cipherId, + UserId = userId, + Type = CipherType.Login, + Data = JsonSerializer.Serialize(new { Username = "test", Password = "test" }), + Folders = null, + Favorites = null + }; + + // Sets folder and favorite when sharing + var model = new CipherShareRequestModel + { + Cipher = new CipherRequestModel + { + Type = CipherType.Login, + OrganizationId = organizationId.ToString(), + Name = "SharedCipher", + Data = JsonSerializer.Serialize(new { Username = "test", Password = "test" }), + FolderId = folderId.ToString(), + Favorite = true, + EncryptedFor = userId + }, + CollectionIds = [Guid.NewGuid().ToString()] + }; + + sutProvider.GetDependency() + .GetUserByPrincipalAsync(Arg.Any()) + .Returns(user); + + sutProvider.GetDependency() + .GetByIdAsync(cipherId) + .Returns(existingCipher); + + sutProvider.GetDependency() + .OrganizationUser(organizationId) + .Returns(true); + + var sharedCipher = new CipherDetails + { + Id = cipherId, + OrganizationId = organizationId, + Type = CipherType.Login, + Data = JsonSerializer.Serialize(new { Username = "test", Password = "test" }), + Folders = JsonSerializer.Serialize(new Dictionary { { userIdKey, folderId.ToString().ToUpperInvariant() } }), + Favorites = JsonSerializer.Serialize(new Dictionary { { userIdKey, true } }), + FolderId = folderId, + Favorite = true + }; + + sutProvider.GetDependency() + .GetByIdAsync(cipherId, userId) + .Returns(sharedCipher); + + sutProvider.GetDependency() + .GetOrganizationAbilitiesAsync() + .Returns(new Dictionary + { + { organizationId, new OrganizationAbility { Id = organizationId } } + }); + + var result = await sutProvider.Sut.PutShare(cipherId, model); + + Assert.Equal(folderId, result.FolderId); + Assert.True(result.Favorite); + } + + [Theory, BitAutoData] + public async Task PutShare_UpdateExistingFolderAndFavorite_UpdatesUserSpecificFields( + Guid cipherId, + Guid userId, + Guid organizationId, + Guid oldFolderId, + Guid newFolderId, + SutProvider sutProvider) + { + var user = new User { Id = userId }; + var userIdKey = userId.ToString().ToUpperInvariant(); + + // Existing cipher with old folder and not favorited + var existingCipher = new Cipher + { + Id = cipherId, + UserId = userId, + Type = CipherType.Login, + Data = JsonSerializer.Serialize(new { Username = "test", Password = "test" }), + Folders = JsonSerializer.Serialize(new Dictionary { { userIdKey, oldFolderId.ToString().ToUpperInvariant() } }), + Favorites = null + }; + + var model = new CipherShareRequestModel + { + Cipher = new CipherRequestModel + { + Type = CipherType.Login, + OrganizationId = organizationId.ToString(), + Name = "SharedCipher", + Data = JsonSerializer.Serialize(new { Username = "test", Password = "test" }), + FolderId = newFolderId.ToString(), // Update to new folder + Favorite = true, // Add favorite + EncryptedFor = userId + }, + CollectionIds = [Guid.NewGuid().ToString()] + }; + + sutProvider.GetDependency() + .GetUserByPrincipalAsync(Arg.Any()) + .Returns(user); + + sutProvider.GetDependency() + .GetByIdAsync(cipherId) + .Returns(existingCipher); + + sutProvider.GetDependency() + .OrganizationUser(organizationId) + .Returns(true); + + var sharedCipher = new CipherDetails + { + Id = cipherId, + OrganizationId = organizationId, + Type = CipherType.Login, + Data = JsonSerializer.Serialize(new { Username = "test", Password = "test" }), + Folders = JsonSerializer.Serialize(new Dictionary { { userIdKey, newFolderId.ToString().ToUpperInvariant() } }), + Favorites = JsonSerializer.Serialize(new Dictionary { { userIdKey, true } }), + FolderId = newFolderId, + Favorite = true + }; + + sutProvider.GetDependency() + .GetByIdAsync(cipherId, userId) + .Returns(sharedCipher); + + sutProvider.GetDependency() + .GetOrganizationAbilitiesAsync() + .Returns(new Dictionary + { + { organizationId, new OrganizationAbility { Id = organizationId } } + }); + + var result = await sutProvider.Sut.PutShare(cipherId, model); + + Assert.Equal(newFolderId, result.FolderId); + Assert.True(result.Favorite); + } } diff --git a/test/Api.Test/Vault/Controllers/SyncControllerTests.cs b/test/Api.Test/Vault/Controllers/SyncControllerTests.cs index 54db1e4053..e6d34592c7 100644 --- a/test/Api.Test/Vault/Controllers/SyncControllerTests.cs +++ b/test/Api.Test/Vault/Controllers/SyncControllerTests.cs @@ -12,13 +12,15 @@ using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Queries.Interfaces; using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; using Bit.Core.Services; +using Bit.Core.Test.Billing.Mocks; using Bit.Core.Tools.Entities; using Bit.Core.Tools.Repositories; -using Bit.Core.Utilities; using Bit.Core.Vault.Entities; using Bit.Core.Vault.Models.Data; using Bit.Core.Vault.Repositories; @@ -74,6 +76,7 @@ public class SyncControllerTests var policyRepository = sutProvider.GetDependency(); var collectionRepository = sutProvider.GetDependency(); var collectionCipherRepository = sutProvider.GetDependency(); + var userAccountKeysQuery = sutProvider.GetDependency(); // Adjust random data to match required formats / test intentions user.EquivalentDomains = JsonSerializer.Serialize(userEquivalentDomains); @@ -98,6 +101,11 @@ public class SyncControllerTests // Setup returns userService.GetUserByPrincipalAsync(Arg.Any()).ReturnsForAnyArgs(user); + userAccountKeysQuery.Run(user).Returns(new UserAccountKeysData + { + PublicKeyEncryptionKeyPairData = user.GetPublicKeyEncryptionKeyPair(), + SignatureKeyPairData = null, + }); organizationUserRepository .GetManyDetailsByUserAsync(user.Id, OrganizationUserStatusType.Confirmed).Returns(organizationUserDetails); @@ -127,7 +135,6 @@ public class SyncControllerTests // Execute GET var result = await sutProvider.Sut.Get(); - // Asserts // Assert that methods are called var hasEnabledOrgs = organizationUserDetails.Any(o => o.Enabled); @@ -166,6 +173,7 @@ public class SyncControllerTests var policyRepository = sutProvider.GetDependency(); var collectionRepository = sutProvider.GetDependency(); var collectionCipherRepository = sutProvider.GetDependency(); + var userAccountKeysQuery = sutProvider.GetDependency(); // Adjust random data to match required formats / test intentions user.EquivalentDomains = JsonSerializer.Serialize(userEquivalentDomains); @@ -189,6 +197,11 @@ public class SyncControllerTests // Setup returns userService.GetUserByPrincipalAsync(Arg.Any()).ReturnsForAnyArgs(user); + userAccountKeysQuery.Run(user).Returns(new UserAccountKeysData + { + PublicKeyEncryptionKeyPairData = user.GetPublicKeyEncryptionKeyPair(), + SignatureKeyPairData = null, + }); organizationUserRepository .GetManyDetailsByUserAsync(user.Id, OrganizationUserStatusType.Confirmed).Returns(organizationUserDetails); @@ -256,6 +269,7 @@ public class SyncControllerTests var policyRepository = sutProvider.GetDependency(); var collectionRepository = sutProvider.GetDependency(); var collectionCipherRepository = sutProvider.GetDependency(); + var userAccountKeysQuery = sutProvider.GetDependency(); // Adjust random data to match required formats / test intentions user.EquivalentDomains = JsonSerializer.Serialize(userEquivalentDomains); @@ -271,6 +285,10 @@ public class SyncControllerTests providerUserRepository .GetManyDetailsByUserAsync(user.Id, ProviderUserStatusType.Confirmed).Returns(providerUserDetails); + foreach (var p in providerUserOrganizationDetails) + { + p.SsoConfig = null; + } providerUserRepository .GetManyOrganizationDetailsByUserAsync(user.Id, ProviderUserStatusType.Confirmed) .Returns(providerUserOrganizationDetails); @@ -290,6 +308,12 @@ public class SyncControllerTests twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(user).Returns(false); userService.HasPremiumFromOrganization(user).Returns(false); + userAccountKeysQuery.Run(user).Returns(new UserAccountKeysData + { + PublicKeyEncryptionKeyPairData = user.GetPublicKeyEncryptionKeyPair(), + SignatureKeyPairData = null, + }); + // Execute GET var result = await sutProvider.Sut.Get(); @@ -311,7 +335,7 @@ public class SyncControllerTests if (matchedProviderUserOrgDetails != null) { - var providerOrgProductType = StaticStore.GetPlan(matchedProviderUserOrgDetails.PlanType).ProductTier; + var providerOrgProductType = MockPlans.Get(matchedProviderUserOrgDetails.PlanType).ProductTier; Assert.Equal(providerOrgProductType, profProviderOrg.ProductTierType); } } @@ -327,6 +351,13 @@ public class SyncControllerTests user.MasterPassword = null; + var userAccountKeysQuery = sutProvider.GetDependency(); + userAccountKeysQuery.Run(user).Returns(new UserAccountKeysData + { + PublicKeyEncryptionKeyPairData = user.GetPublicKeyEncryptionKeyPair(), + SignatureKeyPairData = null, + }); + var userService = sutProvider.GetDependency(); userService.GetUserByPrincipalAsync(Arg.Any()).ReturnsForAnyArgs(user); @@ -352,6 +383,13 @@ public class SyncControllerTests user.KdfMemory = kdfMemory; user.KdfParallelism = kdfParallelism; + var userAccountKeysQuery = sutProvider.GetDependency(); + userAccountKeysQuery.Run(user).Returns(new UserAccountKeysData + { + PublicKeyEncryptionKeyPairData = user.GetPublicKeyEncryptionKeyPair(), + SignatureKeyPairData = null, + }); + var userService = sutProvider.GetDependency(); userService.GetUserByPrincipalAsync(Arg.Any()).ReturnsForAnyArgs(user); diff --git a/test/Billing.Test/Billing.Test.csproj b/test/Billing.Test/Billing.Test.csproj index b4ea2938f6..87a1c28ca1 100644 --- a/test/Billing.Test/Billing.Test.csproj +++ b/test/Billing.Test/Billing.Test.csproj @@ -5,8 +5,8 @@ - + @@ -24,27 +24,10 @@ + - - PreserveNewest - - - PreserveNewest - - - PreserveNewest - - - PreserveNewest - - - PreserveNewest - - - PreserveNewest - PreserveNewest @@ -73,9 +56,6 @@ PreserveNewest - - PreserveNewest - diff --git a/test/Billing.Test/Controllers/BitPayControllerTests.cs b/test/Billing.Test/Controllers/BitPayControllerTests.cs new file mode 100644 index 0000000000..0118009cb7 --- /dev/null +++ b/test/Billing.Test/Controllers/BitPayControllerTests.cs @@ -0,0 +1,391 @@ +using Bit.Billing.Controllers; +using Bit.Billing.Models; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Payment.Clients; +using Bit.Core.Billing.Services; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Settings; +using BitPayLight.Models.Invoice; +using Microsoft.AspNetCore.Mvc; +using Microsoft.Extensions.Logging; +using NSubstitute; +using Xunit; +using Transaction = Bit.Core.Entities.Transaction; + +namespace Bit.Billing.Test.Controllers; + +using static BitPayConstants; + +public class BitPayControllerTests +{ + private readonly GlobalSettings _globalSettings = new(); + private readonly IBitPayClient _bitPayClient = Substitute.For(); + private readonly ITransactionRepository _transactionRepository = Substitute.For(); + private readonly IOrganizationRepository _organizationRepository = Substitute.For(); + private readonly IUserRepository _userRepository = Substitute.For(); + private readonly IProviderRepository _providerRepository = Substitute.For(); + private readonly IMailService _mailService = Substitute.For(); + private readonly IStripePaymentService _paymentService = Substitute.For(); + + private readonly IPremiumUserBillingService _premiumUserBillingService = + Substitute.For(); + + private const string _validWebhookKey = "valid-webhook-key"; + private const string _invalidWebhookKey = "invalid-webhook-key"; + + public BitPayControllerTests() + { + var bitPaySettings = new GlobalSettings.BitPaySettings { WebhookKey = _validWebhookKey }; + _globalSettings.BitPay = bitPaySettings; + } + + private BitPayController CreateController() => new( + _globalSettings, + _bitPayClient, + _transactionRepository, + _organizationRepository, + _userRepository, + _providerRepository, + _mailService, + _paymentService, + Substitute.For>(), + _premiumUserBillingService); + + [Fact] + public async Task PostIpn_InvalidKey_BadRequest() + { + var controller = CreateController(); + var eventModel = CreateValidEventModel(); + + var result = await controller.PostIpn(eventModel, _invalidWebhookKey); + + var badRequestResult = Assert.IsType(result); + Assert.Equal("Invalid key", badRequestResult.Value); + } + + [Fact] + public async Task PostIpn_NullKey_ThrowsException() + { + var controller = CreateController(); + var eventModel = CreateValidEventModel(); + + await Assert.ThrowsAsync(() => controller.PostIpn(eventModel, null!)); + } + + [Fact] + public async Task PostIpn_EmptyKey_BadRequest() + { + var controller = CreateController(); + var eventModel = CreateValidEventModel(); + + var result = await controller.PostIpn(eventModel, string.Empty); + + var badRequestResult = Assert.IsType(result); + Assert.Equal("Invalid key", badRequestResult.Value); + } + + [Fact] + public async Task PostIpn_NonUsdCurrency_BadRequest() + { + var controller = CreateController(); + var eventModel = CreateValidEventModel(); + var invoice = CreateValidInvoice(currency: "EUR"); + + _bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice); + + var result = await controller.PostIpn(eventModel, _validWebhookKey); + + var badRequestResult = Assert.IsType(result); + Assert.Equal("Cannot process non-USD payments", badRequestResult.Value); + } + + [Fact] + public async Task PostIpn_NullPosData_BadRequest() + { + var controller = CreateController(); + var eventModel = CreateValidEventModel(); + var invoice = CreateValidInvoice(posData: null!); + + _bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice); + + var result = await controller.PostIpn(eventModel, _validWebhookKey); + + var badRequestResult = Assert.IsType(result); + Assert.Equal("Invalid POS data", badRequestResult.Value); + } + + [Fact] + public async Task PostIpn_EmptyPosData_BadRequest() + { + var controller = CreateController(); + var eventModel = CreateValidEventModel(); + var invoice = CreateValidInvoice(posData: ""); + + _bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice); + + var result = await controller.PostIpn(eventModel, _validWebhookKey); + + var badRequestResult = Assert.IsType(result); + Assert.Equal("Invalid POS data", badRequestResult.Value); + } + + [Fact] + public async Task PostIpn_PosDataWithoutAccountCredit_BadRequest() + { + var controller = CreateController(); + var eventModel = CreateValidEventModel(); + var invoice = CreateValidInvoice(posData: "organizationId:550e8400-e29b-41d4-a716-446655440000"); + + _bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice); + + var result = await controller.PostIpn(eventModel, _validWebhookKey); + + var badRequestResult = Assert.IsType(result); + Assert.Equal("Invalid POS data", badRequestResult.Value); + } + + [Fact] + public async Task PostIpn_PosDataWithoutValidId_BadRequest() + { + var controller = CreateController(); + var eventModel = CreateValidEventModel(); + var invoice = CreateValidInvoice(posData: PosDataKeys.AccountCredit); + + _bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice); + + var result = await controller.PostIpn(eventModel, _validWebhookKey); + + var badRequestResult = Assert.IsType(result); + Assert.Equal("Invalid POS data", badRequestResult.Value); + } + + [Fact] + public async Task PostIpn_IncompleteInvoice_Ok() + { + var controller = CreateController(); + var eventModel = CreateValidEventModel(); + var invoice = CreateValidInvoice(status: "paid"); + + _bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice); + + var result = await controller.PostIpn(eventModel, _validWebhookKey); + + var okResult = Assert.IsType(result); + Assert.Equal("Waiting for invoice to be completed", okResult.Value); + } + + [Fact] + public async Task PostIpn_ExistingTransaction_Ok() + { + var controller = CreateController(); + var eventModel = CreateValidEventModel(); + var invoice = CreateValidInvoice(); + var existingTransaction = new Transaction { GatewayId = invoice.Id }; + + _bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice); + _transactionRepository.GetByGatewayIdAsync(GatewayType.BitPay, invoice.Id).Returns(existingTransaction); + + var result = await controller.PostIpn(eventModel, _validWebhookKey); + + var okResult = Assert.IsType(result); + Assert.Equal("Invoice already processed", okResult.Value); + } + + [Fact] + public async Task PostIpn_ValidOrganizationTransaction_Success() + { + var controller = CreateController(); + var eventModel = CreateValidEventModel(); + var organizationId = Guid.NewGuid(); + var invoice = CreateValidInvoice(posData: $"organizationId:{organizationId},{PosDataKeys.AccountCredit}"); + var organization = new Organization { Id = organizationId, BillingEmail = "billing@example.com" }; + + _bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice); + _transactionRepository.GetByGatewayIdAsync(GatewayType.BitPay, invoice.Id).Returns((Transaction)null); + _organizationRepository.GetByIdAsync(organizationId).Returns(organization); + _paymentService.CreditAccountAsync(organization, Arg.Any()).Returns(true); + + var result = await controller.PostIpn(eventModel, _validWebhookKey); + + Assert.IsType(result); + await _transactionRepository.Received(1).CreateAsync(Arg.Is(t => + t.OrganizationId == organizationId && + t.Type == TransactionType.Credit && + t.Gateway == GatewayType.BitPay && + t.PaymentMethodType == PaymentMethodType.BitPay)); + await _organizationRepository.Received(1).ReplaceAsync(organization); + await _mailService.Received(1).SendAddedCreditAsync("billing@example.com", 100.00m); + } + + [Fact] + public async Task PostIpn_ValidUserTransaction_Success() + { + var controller = CreateController(); + var eventModel = CreateValidEventModel(); + var userId = Guid.NewGuid(); + var invoice = CreateValidInvoice(posData: $"userId:{userId},{PosDataKeys.AccountCredit}"); + var user = new User { Id = userId, Email = "user@example.com" }; + + _bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice); + _transactionRepository.GetByGatewayIdAsync(GatewayType.BitPay, invoice.Id).Returns((Transaction)null); + _userRepository.GetByIdAsync(userId).Returns(user); + + var result = await controller.PostIpn(eventModel, _validWebhookKey); + + Assert.IsType(result); + await _transactionRepository.Received(1).CreateAsync(Arg.Is(t => + t.UserId == userId && + t.Type == TransactionType.Credit && + t.Gateway == GatewayType.BitPay && + t.PaymentMethodType == PaymentMethodType.BitPay)); + await _premiumUserBillingService.Received(1).Credit(user, 100.00m); + await _mailService.Received(1).SendAddedCreditAsync("user@example.com", 100.00m); + } + + [Fact] + public async Task PostIpn_ValidProviderTransaction_Success() + { + var controller = CreateController(); + var eventModel = CreateValidEventModel(); + var providerId = Guid.NewGuid(); + var invoice = CreateValidInvoice(posData: $"providerId:{providerId},{PosDataKeys.AccountCredit}"); + var provider = new Provider { Id = providerId, BillingEmail = "provider@example.com" }; + + _bitPayClient.GetInvoice(eventModel.Data.Id).Returns(invoice); + _transactionRepository.GetByGatewayIdAsync(GatewayType.BitPay, invoice.Id).Returns((Transaction)null); + _providerRepository.GetByIdAsync(providerId).Returns(Task.FromResult(provider)); + _paymentService.CreditAccountAsync(provider, Arg.Any()).Returns(true); + + var result = await controller.PostIpn(eventModel, _validWebhookKey); + + Assert.IsType(result); + await _transactionRepository.Received(1).CreateAsync(Arg.Is(t => + t.ProviderId == providerId && + t.Type == TransactionType.Credit && + t.Gateway == GatewayType.BitPay && + t.PaymentMethodType == PaymentMethodType.BitPay)); + await _providerRepository.Received(1).ReplaceAsync(provider); + await _mailService.Received(1).SendAddedCreditAsync("provider@example.com", 100.00m); + } + + [Fact] + public void GetIdsFromPosData_ValidOrganizationId_ReturnsCorrectId() + { + var controller = CreateController(); + var organizationId = Guid.NewGuid(); + var invoice = CreateValidInvoice(posData: $"organizationId:{organizationId},{PosDataKeys.AccountCredit}"); + + var result = controller.GetIdsFromPosData(invoice); + + Assert.Equal(organizationId, result.OrganizationId); + Assert.Null(result.UserId); + Assert.Null(result.ProviderId); + } + + [Fact] + public void GetIdsFromPosData_ValidUserId_ReturnsCorrectId() + { + var controller = CreateController(); + var userId = Guid.NewGuid(); + var invoice = CreateValidInvoice(posData: $"userId:{userId},{PosDataKeys.AccountCredit}"); + + var result = controller.GetIdsFromPosData(invoice); + + Assert.Null(result.OrganizationId); + Assert.Equal(userId, result.UserId); + Assert.Null(result.ProviderId); + } + + [Fact] + public void GetIdsFromPosData_ValidProviderId_ReturnsCorrectId() + { + var controller = CreateController(); + var providerId = Guid.NewGuid(); + var invoice = CreateValidInvoice(posData: $"providerId:{providerId},{PosDataKeys.AccountCredit}"); + + var result = controller.GetIdsFromPosData(invoice); + + Assert.Null(result.OrganizationId); + Assert.Null(result.UserId); + Assert.Equal(providerId, result.ProviderId); + } + + [Fact] + public void GetIdsFromPosData_InvalidGuid_ReturnsNull() + { + var controller = CreateController(); + var invoice = CreateValidInvoice(posData: "organizationId:invalid-guid,{PosDataKeys.AccountCredit}"); + + var result = controller.GetIdsFromPosData(invoice); + + Assert.Null(result.OrganizationId); + Assert.Null(result.UserId); + Assert.Null(result.ProviderId); + } + + [Fact] + public void GetIdsFromPosData_NullPosData_ReturnsNull() + { + var controller = CreateController(); + var invoice = CreateValidInvoice(posData: null!); + + var result = controller.GetIdsFromPosData(invoice); + + Assert.Null(result.OrganizationId); + Assert.Null(result.UserId); + Assert.Null(result.ProviderId); + } + + [Fact] + public void GetIdsFromPosData_EmptyPosData_ReturnsNull() + { + var controller = CreateController(); + var invoice = CreateValidInvoice(posData: ""); + + var result = controller.GetIdsFromPosData(invoice); + + Assert.Null(result.OrganizationId); + Assert.Null(result.UserId); + Assert.Null(result.ProviderId); + } + + private static BitPayEventModel CreateValidEventModel(string invoiceId = "test-invoice-id") + { + return new BitPayEventModel + { + Event = new BitPayEventModel.EventModel { Code = 1005, Name = "invoice_confirmed" }, + Data = new BitPayEventModel.InvoiceDataModel { Id = invoiceId } + }; + } + + private static Invoice CreateValidInvoice(string invoiceId = "test-invoice-id", string status = "complete", + string currency = "USD", decimal price = 100.00m, + string posData = "organizationId:550e8400-e29b-41d4-a716-446655440000,accountCredit:1") + { + return new Invoice + { + Id = invoiceId, + Status = status, + Currency = currency, + Price = (double)price, + PosData = posData, + CurrentTime = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(), + Transactions = + [ + new InvoiceTransaction + { + Type = null, + Confirmations = "1", + ReceivedTime = DateTime.UtcNow.ToString("O") + } + ] + }; + } + +} diff --git a/test/Billing.Test/Controllers/FreshdeskControllerTests.cs b/test/Billing.Test/Controllers/FreshdeskControllerTests.cs deleted file mode 100644 index 8fd0769a02..0000000000 --- a/test/Billing.Test/Controllers/FreshdeskControllerTests.cs +++ /dev/null @@ -1,251 +0,0 @@ -using System.Text.Json; -using Bit.Billing.Controllers; -using Bit.Billing.Models; -using Bit.Core.AdminConsole.Entities; -using Bit.Core.Entities; -using Bit.Core.Repositories; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Mvc; -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Options; -using NSubstitute; -using NSubstitute.ReceivedExtensions; -using Xunit; - -namespace Bit.Billing.Test.Controllers; - -[ControllerCustomize(typeof(FreshdeskController))] -[SutProviderCustomize] -public class FreshdeskControllerTests -{ - private const string ApiKey = "TESTFRESHDESKAPIKEY"; - private const string WebhookKey = "TESTKEY"; - - private const string UserFieldName = "cf_user"; - private const string OrgFieldName = "cf_org"; - - [Theory] - [BitAutoData((string)null, null)] - [BitAutoData((string)null)] - [BitAutoData(WebhookKey, null)] - public async Task PostWebhook_NullRequiredParameters_BadRequest(string freshdeskWebhookKey, FreshdeskWebhookModel model, - BillingSettings billingSettings, SutProvider sutProvider) - { - sutProvider.GetDependency>().Value.FreshDesk.WebhookKey.Returns(billingSettings.FreshDesk.WebhookKey); - - var response = await sutProvider.Sut.PostWebhook(freshdeskWebhookKey, model); - - var statusCodeResult = Assert.IsAssignableFrom(response); - Assert.Equal(StatusCodes.Status400BadRequest, statusCodeResult.StatusCode); - } - - [Theory] - [BitAutoData] - public async Task PostWebhook_Success(User user, FreshdeskWebhookModel model, - List organizations, SutProvider sutProvider) - { - model.TicketContactEmail = user.Email; - - sutProvider.GetDependency().GetByEmailAsync(user.Email).Returns(user); - sutProvider.GetDependency().GetManyByUserIdAsync(user.Id).Returns(organizations); - - var mockHttpMessageHandler = Substitute.ForPartsOf(); - var mockResponse = new HttpResponseMessage(System.Net.HttpStatusCode.OK); - mockHttpMessageHandler.Send(Arg.Any(), Arg.Any()) - .Returns(mockResponse); - var httpClient = new HttpClient(mockHttpMessageHandler); - - sutProvider.GetDependency().CreateClient("FreshdeskApi").Returns(httpClient); - - sutProvider.GetDependency>().Value.FreshDesk.WebhookKey.Returns(WebhookKey); - sutProvider.GetDependency>().Value.FreshDesk.ApiKey.Returns(ApiKey); - sutProvider.GetDependency>().Value.FreshDesk.UserFieldName.Returns(UserFieldName); - sutProvider.GetDependency>().Value.FreshDesk.OrgFieldName.Returns(OrgFieldName); - - var response = await sutProvider.Sut.PostWebhook(WebhookKey, model); - - var statusCodeResult = Assert.IsAssignableFrom(response); - Assert.Equal(StatusCodes.Status200OK, statusCodeResult.StatusCode); - - _ = mockHttpMessageHandler.Received(1).Send(Arg.Is(m => m.Method == HttpMethod.Put && m.RequestUri.ToString().EndsWith(model.TicketId)), Arg.Any()); - _ = mockHttpMessageHandler.Received(1).Send(Arg.Is(m => m.Method == HttpMethod.Post && m.RequestUri.ToString().EndsWith($"{model.TicketId}/notes")), Arg.Any()); - } - - [Theory] - [BitAutoData(WebhookKey)] - public async Task PostWebhook_add_note_when_user_is_invalid( - string freshdeskWebhookKey, FreshdeskWebhookModel model, - SutProvider sutProvider) - { - // Arrange - for an invalid user - model.TicketContactEmail = "invalid@user"; - sutProvider.GetDependency().GetByEmailAsync(model.TicketContactEmail).Returns((User)null); - sutProvider.GetDependency>().Value.FreshDesk.WebhookKey.Returns(WebhookKey); - - var mockHttpMessageHandler = Substitute.ForPartsOf(); - var mockResponse = new HttpResponseMessage(System.Net.HttpStatusCode.OK); - mockHttpMessageHandler.Send(Arg.Any(), Arg.Any()) - .Returns(mockResponse); - var httpClient = new HttpClient(mockHttpMessageHandler); - sutProvider.GetDependency().CreateClient("FreshdeskApi").Returns(httpClient); - - // Act - var response = await sutProvider.Sut.PostWebhook(freshdeskWebhookKey, model); - - // Assert - var statusCodeResult = Assert.IsAssignableFrom(response); - Assert.Equal(StatusCodes.Status200OK, statusCodeResult.StatusCode); - - await mockHttpMessageHandler - .Received(1).Send( - Arg.Is( - m => m.Method == HttpMethod.Post - && m.RequestUri.ToString().EndsWith($"{model.TicketId}/notes") - && m.Content.ReadAsStringAsync().Result.Contains("No user found")), - Arg.Any()); - } - - - [Theory] - [BitAutoData((string)null, null)] - [BitAutoData((string)null)] - [BitAutoData(WebhookKey, null)] - public async Task PostWebhookOnyxAi_InvalidWebhookKey_results_in_BadRequest( - string freshdeskWebhookKey, FreshdeskOnyxAiWebhookModel model, - BillingSettings billingSettings, SutProvider sutProvider) - { - sutProvider.GetDependency>() - .Value.FreshDesk.WebhookKey.Returns(billingSettings.FreshDesk.WebhookKey); - - var response = await sutProvider.Sut.PostWebhookOnyxAi(freshdeskWebhookKey, model); - - var statusCodeResult = Assert.IsAssignableFrom(response); - Assert.Equal(StatusCodes.Status400BadRequest, statusCodeResult.StatusCode); - } - - [Theory] - [BitAutoData(WebhookKey)] - public async Task PostWebhookOnyxAi_invalid_onyx_response_results_is_logged( - string freshdeskWebhookKey, FreshdeskOnyxAiWebhookModel model, - SutProvider sutProvider) - { - var billingSettings = sutProvider.GetDependency>().Value; - billingSettings.FreshDesk.WebhookKey.Returns(freshdeskWebhookKey); - billingSettings.Onyx.BaseUrl.Returns("http://simulate-onyx-api.com/api"); - - // mocking freshdesk Api request for ticket info - var mockFreshdeskHttpMessageHandler = Substitute.ForPartsOf(); - var freshdeskHttpClient = new HttpClient(mockFreshdeskHttpMessageHandler); - - // mocking Onyx api response given a ticket description - var mockOnyxHttpMessageHandler = Substitute.ForPartsOf(); - var mockOnyxResponse = new HttpResponseMessage(System.Net.HttpStatusCode.BadRequest); - mockOnyxHttpMessageHandler.Send(Arg.Any(), Arg.Any()) - .Returns(mockOnyxResponse); - var onyxHttpClient = new HttpClient(mockOnyxHttpMessageHandler); - - sutProvider.GetDependency().CreateClient("FreshdeskApi").Returns(freshdeskHttpClient); - sutProvider.GetDependency().CreateClient("OnyxApi").Returns(onyxHttpClient); - - var response = await sutProvider.Sut.PostWebhookOnyxAi(freshdeskWebhookKey, model); - - var statusCodeResult = Assert.IsAssignableFrom(response); - Assert.Equal(StatusCodes.Status200OK, statusCodeResult.StatusCode); - - var _logger = sutProvider.GetDependency>(); - - // workaround because _logger.Received(1).LogWarning(...) does not work - _logger.ReceivedCalls().Any(c => c.GetMethodInfo().Name == "Log" && c.GetArguments()[1].ToString().Contains("Error getting answer from Onyx AI")); - - // sent call to Onyx API - but we got an error response - _ = mockOnyxHttpMessageHandler.Received(1).Send(Arg.Any(), Arg.Any()); - // did not call freshdesk to add a note since onyx failed - _ = mockFreshdeskHttpMessageHandler.DidNotReceive().Send(Arg.Any(), Arg.Any()); - } - - [Theory] - [BitAutoData(WebhookKey)] - public async Task PostWebhookOnyxAi_success( - string freshdeskWebhookKey, FreshdeskOnyxAiWebhookModel model, - OnyxAnswerWithCitationResponseModel onyxResponse, - SutProvider sutProvider) - { - var billingSettings = sutProvider.GetDependency>().Value; - billingSettings.FreshDesk.WebhookKey.Returns(freshdeskWebhookKey); - billingSettings.Onyx.BaseUrl.Returns("http://simulate-onyx-api.com/api"); - - // mocking freshdesk api add note request (POST) - var mockFreshdeskHttpMessageHandler = Substitute.ForPartsOf(); - var mockFreshdeskAddNoteResponse = new HttpResponseMessage(System.Net.HttpStatusCode.BadRequest); - mockFreshdeskHttpMessageHandler.Send( - Arg.Is(_ => _.Method == HttpMethod.Post), - Arg.Any()) - .Returns(mockFreshdeskAddNoteResponse); - var freshdeskHttpClient = new HttpClient(mockFreshdeskHttpMessageHandler); - - // mocking Onyx api response given a ticket description - var mockOnyxHttpMessageHandler = Substitute.ForPartsOf(); - onyxResponse.ErrorMsg = "string.Empty"; - var mockOnyxResponse = new HttpResponseMessage(System.Net.HttpStatusCode.OK) - { - Content = new StringContent(JsonSerializer.Serialize(onyxResponse)) - }; - mockOnyxHttpMessageHandler.Send(Arg.Any(), Arg.Any()) - .Returns(mockOnyxResponse); - var onyxHttpClient = new HttpClient(mockOnyxHttpMessageHandler); - - sutProvider.GetDependency().CreateClient("FreshdeskApi").Returns(freshdeskHttpClient); - sutProvider.GetDependency().CreateClient("OnyxApi").Returns(onyxHttpClient); - - var response = await sutProvider.Sut.PostWebhookOnyxAi(freshdeskWebhookKey, model); - - var result = Assert.IsAssignableFrom(response); - Assert.Equal(StatusCodes.Status200OK, result.StatusCode); - } - - [Theory] - [BitAutoData(WebhookKey)] - public async Task PostWebhookOnyxAi_ticket_description_is_empty_return_success( - string freshdeskWebhookKey, FreshdeskOnyxAiWebhookModel model, - SutProvider sutProvider) - { - var billingSettings = sutProvider.GetDependency>().Value; - billingSettings.FreshDesk.WebhookKey.Returns(freshdeskWebhookKey); - billingSettings.Onyx.BaseUrl.Returns("http://simulate-onyx-api.com/api"); - - model.TicketDescriptionText = " "; // empty description - - // mocking freshdesk api add note request (POST) - var mockFreshdeskHttpMessageHandler = Substitute.ForPartsOf(); - var freshdeskHttpClient = new HttpClient(mockFreshdeskHttpMessageHandler); - - // mocking Onyx api response given a ticket description - var mockOnyxHttpMessageHandler = Substitute.ForPartsOf(); - var onyxHttpClient = new HttpClient(mockOnyxHttpMessageHandler); - - sutProvider.GetDependency().CreateClient("FreshdeskApi").Returns(freshdeskHttpClient); - sutProvider.GetDependency().CreateClient("OnyxApi").Returns(onyxHttpClient); - - var response = await sutProvider.Sut.PostWebhookOnyxAi(freshdeskWebhookKey, model); - - var result = Assert.IsAssignableFrom(response); - Assert.Equal(StatusCodes.Status200OK, result.StatusCode); - _ = mockFreshdeskHttpMessageHandler.DidNotReceive().Send(Arg.Any(), Arg.Any()); - _ = mockOnyxHttpMessageHandler.DidNotReceive().Send(Arg.Any(), Arg.Any()); - } - - public class MockHttpMessageHandler : HttpMessageHandler - { - protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) - { - return Send(request, cancellationToken); - } - - public new virtual Task Send(HttpRequestMessage request, CancellationToken cancellationToken) - { - throw new NotImplementedException(); - } - } -} diff --git a/test/Billing.Test/Controllers/FreshsalesControllerTests.cs b/test/Billing.Test/Controllers/FreshsalesControllerTests.cs deleted file mode 100644 index c9ae6efb1a..0000000000 --- a/test/Billing.Test/Controllers/FreshsalesControllerTests.cs +++ /dev/null @@ -1,82 +0,0 @@ -using Bit.Billing.Controllers; -using Bit.Core.AdminConsole.Entities; -using Bit.Core.Entities; -using Bit.Core.Repositories; -using Bit.Core.Settings; -using Bit.Test.Common.AutoFixture.Attributes; -using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Mvc; -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Options; -using NSubstitute; -using Xunit; - -namespace Bit.Billing.Test.Controllers; - -public class FreshsalesControllerTests -{ - private const string ApiKey = "TEST_FRESHSALES_APIKEY"; - private const string TestLead = "TEST_FRESHSALES_TESTLEAD"; - - private static (FreshsalesController, IUserRepository, IOrganizationRepository) CreateSut( - string freshsalesApiKey) - { - var userRepository = Substitute.For(); - var organizationRepository = Substitute.For(); - - var billingSettings = Options.Create(new BillingSettings - { - FreshsalesApiKey = freshsalesApiKey, - }); - var globalSettings = new GlobalSettings(); - globalSettings.BaseServiceUri.Admin = "https://test.com"; - - var sut = new FreshsalesController( - userRepository, - organizationRepository, - billingSettings, - Substitute.For>(), - globalSettings - ); - - return (sut, userRepository, organizationRepository); - } - - [RequiredEnvironmentTheory(ApiKey, TestLead), EnvironmentData(ApiKey, TestLead)] - public async Task PostWebhook_Success(string freshsalesApiKey, long leadId) - { - // This test is only for development to use: - // `export TEST_FRESHSALES_APIKEY=[apikey]` - // `export TEST_FRESHSALES_TESTLEAD=[lead id]` - // `dotnet test --filter "FullyQualifiedName~FreshsalesControllerTests.PostWebhook_Success"` - var (sut, userRepository, organizationRepository) = CreateSut(freshsalesApiKey); - - var user = new User - { - Id = Guid.NewGuid(), - Email = "test@email.com", - Premium = true, - }; - - userRepository.GetByEmailAsync(user.Email) - .Returns(user); - - organizationRepository.GetManyByUserIdAsync(user.Id) - .Returns(new List - { - new Organization - { - Id = Guid.NewGuid(), - Name = "Test Org", - } - }); - - var response = await sut.PostWebhook(freshsalesApiKey, new CustomWebhookRequestModel - { - LeadId = leadId, - }, new CancellationToken(false)); - - var statusCodeResult = Assert.IsAssignableFrom(response); - Assert.Equal(StatusCodes.Status204NoContent, statusCodeResult.StatusCode); - } -} diff --git a/test/Billing.Test/Controllers/PayPalControllerTests.cs b/test/Billing.Test/Controllers/PayPalControllerTests.cs index 7ec17bd85a..da995b6188 100644 --- a/test/Billing.Test/Controllers/PayPalControllerTests.cs +++ b/test/Billing.Test/Controllers/PayPalControllerTests.cs @@ -8,13 +8,13 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Repositories; using Bit.Core.Services; -using Divergic.Logging.Xunit; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc.Infrastructure; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Microsoft.Extensions.Primitives; +using Neovolve.Logging.Xunit; using NSubstitute; using NSubstitute.ReturnsExtensions; using Xunit; @@ -23,14 +23,12 @@ using Transaction = Bit.Core.Entities.Transaction; namespace Bit.Billing.Test.Controllers; -public class PayPalControllerTests +public class PayPalControllerTests(ITestOutputHelper testOutputHelper) { - private readonly ITestOutputHelper _testOutputHelper; - private readonly IOptions _billingSettings = Substitute.For>(); private readonly IMailService _mailService = Substitute.For(); private readonly IOrganizationRepository _organizationRepository = Substitute.For(); - private readonly IPaymentService _paymentService = Substitute.For(); + private readonly IStripePaymentService _paymentService = Substitute.For(); private readonly ITransactionRepository _transactionRepository = Substitute.For(); private readonly IUserRepository _userRepository = Substitute.For(); private readonly IProviderRepository _providerRepository = Substitute.For(); @@ -38,15 +36,10 @@ public class PayPalControllerTests private const string _defaultWebhookKey = "webhook-key"; - public PayPalControllerTests(ITestOutputHelper testOutputHelper) - { - _testOutputHelper = testOutputHelper; - } - [Fact] public async Task PostIpn_NullKey_BadRequest() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); var controller = ConfigureControllerContextWith(logger, null, null); @@ -60,7 +53,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_IncorrectKey_BadRequest() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -79,7 +72,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_EmptyIPNBody_BadRequest() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -98,7 +91,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_IPNHasNoEntityId_BadRequest() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -119,15 +112,13 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_OtherTransactionType_Unprocessed_Ok() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { PayPal = { WebhookKey = _defaultWebhookKey } }); - var organizationId = new Guid("ca8c6f2b-2d7b-4639-809f-b0e5013a304e"); - var ipnBody = await PayPalTestIPN.GetAsync(IPNBody.UnsupportedTransactionType); var controller = ConfigureControllerContextWith(logger, _defaultWebhookKey, ipnBody); @@ -142,7 +133,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_MismatchedReceiverID_Unprocessed_Ok() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -153,8 +144,6 @@ public class PayPalControllerTests } }); - var organizationId = new Guid("ca8c6f2b-2d7b-4639-809f-b0e5013a304e"); - var ipnBody = await PayPalTestIPN.GetAsync(IPNBody.SuccessfulPayment); var controller = ConfigureControllerContextWith(logger, _defaultWebhookKey, ipnBody); @@ -169,7 +158,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_RefundMissingParent_Unprocessed_Ok() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -180,8 +169,6 @@ public class PayPalControllerTests } }); - var organizationId = new Guid("ca8c6f2b-2d7b-4639-809f-b0e5013a304e"); - var ipnBody = await PayPalTestIPN.GetAsync(IPNBody.RefundMissingParentTransaction); var controller = ConfigureControllerContextWith(logger, _defaultWebhookKey, ipnBody); @@ -196,7 +183,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_eCheckPayment_Unprocessed_Ok() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -207,8 +194,6 @@ public class PayPalControllerTests } }); - var organizationId = new Guid("ca8c6f2b-2d7b-4639-809f-b0e5013a304e"); - var ipnBody = await PayPalTestIPN.GetAsync(IPNBody.ECheckPayment); var controller = ConfigureControllerContextWith(logger, _defaultWebhookKey, ipnBody); @@ -223,7 +208,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_NonUSD_Unprocessed_Ok() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -234,8 +219,6 @@ public class PayPalControllerTests } }); - var organizationId = new Guid("ca8c6f2b-2d7b-4639-809f-b0e5013a304e"); - var ipnBody = await PayPalTestIPN.GetAsync(IPNBody.NonUSDPayment); var controller = ConfigureControllerContextWith(logger, _defaultWebhookKey, ipnBody); @@ -250,7 +233,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_Completed_ExistingTransaction_Unprocessed_Ok() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -261,8 +244,6 @@ public class PayPalControllerTests } }); - var organizationId = new Guid("ca8c6f2b-2d7b-4639-809f-b0e5013a304e"); - var ipnBody = await PayPalTestIPN.GetAsync(IPNBody.SuccessfulPayment); _transactionRepository.GetByGatewayIdAsync( @@ -281,7 +262,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_Completed_CreatesTransaction_Ok() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -292,8 +273,6 @@ public class PayPalControllerTests } }); - var organizationId = new Guid("ca8c6f2b-2d7b-4639-809f-b0e5013a304e"); - var ipnBody = await PayPalTestIPN.GetAsync(IPNBody.SuccessfulPayment); _transactionRepository.GetByGatewayIdAsync( @@ -314,7 +293,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_Completed_CreatesTransaction_CreditsOrganizationAccount_Ok() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -362,7 +341,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_Completed_CreatesTransaction_CreditsUserAccount_Ok() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -406,7 +385,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_Refunded_ExistingTransaction_Unprocessed_Ok() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -417,8 +396,6 @@ public class PayPalControllerTests } }); - var organizationId = new Guid("ca8c6f2b-2d7b-4639-809f-b0e5013a304e"); - var ipnBody = await PayPalTestIPN.GetAsync(IPNBody.SuccessfulRefund); _transactionRepository.GetByGatewayIdAsync( @@ -441,7 +418,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_Refunded_MissingParentTransaction_Ok() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -452,8 +429,6 @@ public class PayPalControllerTests } }); - var organizationId = new Guid("ca8c6f2b-2d7b-4639-809f-b0e5013a304e"); - var ipnBody = await PayPalTestIPN.GetAsync(IPNBody.SuccessfulRefund); _transactionRepository.GetByGatewayIdAsync( @@ -480,7 +455,7 @@ public class PayPalControllerTests [Fact] public async Task PostIpn_Refunded_ReplacesParent_CreatesTransaction_Ok() { - var logger = _testOutputHelper.BuildLoggerFor(); + var logger = testOutputHelper.BuildLoggerFor(); _billingSettings.Value.Returns(new BillingSettings { @@ -531,8 +506,8 @@ public class PayPalControllerTests private PayPalController ConfigureControllerContextWith( ILogger logger, - string webhookKey, - string ipnBody) + string? webhookKey, + string? ipnBody) { var controller = new PayPalController( _billingSettings, @@ -578,16 +553,16 @@ public class PayPalControllerTests Assert.Equal(statusCode, statusCodeActionResult.StatusCode); } - private static void Logged(ICacheLogger logger, LogLevel logLevel, string message) + private static void Logged(ICacheLogger logger, LogLevel logLevel, string message) { Assert.NotNull(logger.Last); Assert.Equal(logLevel, logger.Last!.LogLevel); Assert.Equal(message, logger.Last!.Message); } - private static void LoggedError(ICacheLogger logger, string message) + private static void LoggedError(ICacheLogger logger, string message) => Logged(logger, LogLevel.Error, message); - private static void LoggedWarning(ICacheLogger logger, string message) + private static void LoggedWarning(ICacheLogger logger, string message) => Logged(logger, LogLevel.Warning, message); } diff --git a/test/Billing.Test/Jobs/ProviderOrganizationDisableJobTests.cs b/test/Billing.Test/Jobs/ProviderOrganizationDisableJobTests.cs new file mode 100644 index 0000000000..91b38341e5 --- /dev/null +++ b/test/Billing.Test/Jobs/ProviderOrganizationDisableJobTests.cs @@ -0,0 +1,234 @@ +using Bit.Billing.Jobs; +using Bit.Core.AdminConsole.Models.Data.Provider; +using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; +using Bit.Core.AdminConsole.Repositories; +using Microsoft.Extensions.Logging; +using NSubstitute; +using NSubstitute.ExceptionExtensions; +using Quartz; +using Xunit; + +namespace Bit.Billing.Test.Jobs; + +public class ProviderOrganizationDisableJobTests +{ + private readonly IProviderOrganizationRepository _providerOrganizationRepository; + private readonly IOrganizationDisableCommand _organizationDisableCommand; + private readonly ILogger _logger; + private readonly ProviderOrganizationDisableJob _sut; + + public ProviderOrganizationDisableJobTests() + { + _providerOrganizationRepository = Substitute.For(); + _organizationDisableCommand = Substitute.For(); + _logger = Substitute.For>(); + _sut = new ProviderOrganizationDisableJob( + _providerOrganizationRepository, + _organizationDisableCommand, + _logger); + } + + [Fact] + public async Task Execute_NoOrganizations_LogsAndReturns() + { + // Arrange + var providerId = Guid.NewGuid(); + var context = CreateJobExecutionContext(providerId, DateTime.UtcNow); + _providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId) + .Returns((ICollection)null); + + // Act + await _sut.Execute(context); + + // Assert + await _organizationDisableCommand.DidNotReceiveWithAnyArgs().DisableAsync(default, default); + } + + [Fact] + public async Task Execute_WithOrganizations_DisablesAllOrganizations() + { + // Arrange + var providerId = Guid.NewGuid(); + var expirationDate = DateTime.UtcNow.AddDays(30); + var org1Id = Guid.NewGuid(); + var org2Id = Guid.NewGuid(); + var org3Id = Guid.NewGuid(); + + var organizations = new List + { + new() { OrganizationId = org1Id }, + new() { OrganizationId = org2Id }, + new() { OrganizationId = org3Id } + }; + + var context = CreateJobExecutionContext(providerId, expirationDate); + _providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId) + .Returns(organizations); + + // Act + await _sut.Execute(context); + + // Assert + await _organizationDisableCommand.Received(1).DisableAsync(org1Id, Arg.Any()); + await _organizationDisableCommand.Received(1).DisableAsync(org2Id, Arg.Any()); + await _organizationDisableCommand.Received(1).DisableAsync(org3Id, Arg.Any()); + } + + [Fact] + public async Task Execute_WithExpirationDate_PassesDateToDisableCommand() + { + // Arrange + var providerId = Guid.NewGuid(); + var expirationDate = new DateTime(2025, 12, 31, 23, 59, 59); + var orgId = Guid.NewGuid(); + + var organizations = new List + { + new() { OrganizationId = orgId } + }; + + var context = CreateJobExecutionContext(providerId, expirationDate); + _providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId) + .Returns(organizations); + + // Act + await _sut.Execute(context); + + // Assert + await _organizationDisableCommand.Received(1).DisableAsync(orgId, expirationDate); + } + + [Fact] + public async Task Execute_WithNullExpirationDate_PassesNullToDisableCommand() + { + // Arrange + var providerId = Guid.NewGuid(); + var orgId = Guid.NewGuid(); + + var organizations = new List + { + new() { OrganizationId = orgId } + }; + + var context = CreateJobExecutionContext(providerId, null); + _providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId) + .Returns(organizations); + + // Act + await _sut.Execute(context); + + // Assert + await _organizationDisableCommand.Received(1).DisableAsync(orgId, null); + } + + [Fact] + public async Task Execute_OneOrganizationFails_ContinuesProcessingOthers() + { + // Arrange + var providerId = Guid.NewGuid(); + var expirationDate = DateTime.UtcNow.AddDays(30); + var org1Id = Guid.NewGuid(); + var org2Id = Guid.NewGuid(); + var org3Id = Guid.NewGuid(); + + var organizations = new List + { + new() { OrganizationId = org1Id }, + new() { OrganizationId = org2Id }, + new() { OrganizationId = org3Id } + }; + + var context = CreateJobExecutionContext(providerId, expirationDate); + _providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId) + .Returns(organizations); + + // Make org2 fail + _organizationDisableCommand.DisableAsync(org2Id, Arg.Any()) + .Throws(new Exception("Database error")); + + // Act + await _sut.Execute(context); + + // Assert - all three should be attempted + await _organizationDisableCommand.Received(1).DisableAsync(org1Id, Arg.Any()); + await _organizationDisableCommand.Received(1).DisableAsync(org2Id, Arg.Any()); + await _organizationDisableCommand.Received(1).DisableAsync(org3Id, Arg.Any()); + } + + [Fact] + public async Task Execute_ManyOrganizations_ProcessesWithLimitedConcurrency() + { + // Arrange + var providerId = Guid.NewGuid(); + var expirationDate = DateTime.UtcNow.AddDays(30); + + // Create 20 organizations + var organizations = Enumerable.Range(1, 20) + .Select(_ => new ProviderOrganizationOrganizationDetails { OrganizationId = Guid.NewGuid() }) + .ToList(); + + var context = CreateJobExecutionContext(providerId, expirationDate); + _providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId) + .Returns(organizations); + + var concurrentCalls = 0; + var maxConcurrentCalls = 0; + var lockObj = new object(); + + _organizationDisableCommand.DisableAsync(Arg.Any(), Arg.Any()) + .Returns(callInfo => + { + lock (lockObj) + { + concurrentCalls++; + if (concurrentCalls > maxConcurrentCalls) + { + maxConcurrentCalls = concurrentCalls; + } + } + + return Task.Delay(50).ContinueWith(_ => + { + lock (lockObj) + { + concurrentCalls--; + } + }); + }); + + // Act + await _sut.Execute(context); + + // Assert + Assert.True(maxConcurrentCalls <= 5, $"Expected max concurrency of 5, but got {maxConcurrentCalls}"); + await _organizationDisableCommand.Received(20).DisableAsync(Arg.Any(), Arg.Any()); + } + + [Fact] + public async Task Execute_EmptyOrganizationsList_DoesNotCallDisableCommand() + { + // Arrange + var providerId = Guid.NewGuid(); + var context = CreateJobExecutionContext(providerId, DateTime.UtcNow); + _providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId) + .Returns(new List()); + + // Act + await _sut.Execute(context); + + // Assert + await _organizationDisableCommand.DidNotReceiveWithAnyArgs().DisableAsync(default, default); + } + + private static IJobExecutionContext CreateJobExecutionContext(Guid providerId, DateTime? expirationDate) + { + var context = Substitute.For(); + var jobDataMap = new JobDataMap + { + { "providerId", providerId.ToString() }, + { "expirationDate", expirationDate?.ToString("O") } + }; + context.MergedJobDataMap.Returns(jobDataMap); + return context; + } +} diff --git a/test/Billing.Test/Jobs/ReconcileAdditionalStorageJobTests.cs b/test/Billing.Test/Jobs/ReconcileAdditionalStorageJobTests.cs new file mode 100644 index 0000000000..b3540246b0 --- /dev/null +++ b/test/Billing.Test/Jobs/ReconcileAdditionalStorageJobTests.cs @@ -0,0 +1,789 @@ +using Bit.Billing.Jobs; +using Bit.Billing.Services; +using Bit.Core; +using Bit.Core.Billing.Constants; +using Bit.Core.Services; +using Microsoft.Extensions.Logging; +using NSubstitute; +using NSubstitute.ExceptionExtensions; +using Quartz; +using Stripe; +using Xunit; + +namespace Bit.Billing.Test.Jobs; + +public class ReconcileAdditionalStorageJobTests +{ + private readonly IStripeFacade _stripeFacade; + private readonly ILogger _logger; + private readonly IFeatureService _featureService; + private readonly ReconcileAdditionalStorageJob _sut; + + public ReconcileAdditionalStorageJobTests() + { + _stripeFacade = Substitute.For(); + _logger = Substitute.For>(); + _featureService = Substitute.For(); + _sut = new ReconcileAdditionalStorageJob(_stripeFacade, _logger, _featureService); + } + + #region Feature Flag Tests + + [Fact] + public async Task Execute_FeatureFlagDisabled_SkipsProcessing() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob) + .Returns(false); + + // Act + await _sut.Execute(context); + + // Assert + _stripeFacade.DidNotReceiveWithAnyArgs().ListSubscriptionsAutoPagingAsync(); + } + + [Fact] + public async Task Execute_FeatureFlagEnabled_ProcessesSubscriptions() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob) + .Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode) + .Returns(false); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Empty()); + + // Act + await _sut.Execute(context); + + // Assert + _stripeFacade.Received(3).ListSubscriptionsAutoPagingAsync( + Arg.Is(o => o.Limit == 100)); + } + + #endregion + + #region Dry Run Mode Tests + + [Fact] + public async Task Execute_DryRunMode_DoesNotUpdateSubscriptions() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(false); // Dry run ON + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10); + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.DidNotReceiveWithAnyArgs().UpdateSubscription(null!); + } + + [Fact] + public async Task Execute_DryRunModeDisabled_UpdatesSubscriptions() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); // Dry run OFF + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10); + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription( + "sub_123", + Arg.Is(o => o.Items.Count == 1)); + } + + #endregion + + #region Price ID Processing Tests + + [Fact] + public async Task Execute_ProcessesAllThreePriceIds() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(false); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Empty()); + + // Act + await _sut.Execute(context); + + // Assert + _stripeFacade.Received(1).ListSubscriptionsAutoPagingAsync( + Arg.Is(o => o.Price == "storage-gb-monthly")); + _stripeFacade.Received(1).ListSubscriptionsAutoPagingAsync( + Arg.Is(o => o.Price == "storage-gb-annually")); + _stripeFacade.Received(1).ListSubscriptionsAutoPagingAsync( + Arg.Is(o => o.Price == "personal-storage-gb-annually")); + } + + #endregion + + #region Already Processed Tests + + [Fact] + public async Task Execute_SubscriptionAlreadyProcessed_SkipsUpdate() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var metadata = new Dictionary + { + [StripeConstants.MetadataKeys.StorageReconciled2025] = DateTime.UtcNow.ToString("o") + }; + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, metadata: metadata); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.DidNotReceiveWithAnyArgs().UpdateSubscription(null!); + } + + [Fact] + public async Task Execute_SubscriptionWithInvalidProcessedDate_ProcessesSubscription() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var metadata = new Dictionary + { + [StripeConstants.MetadataKeys.StorageReconciled2025] = "invalid-date" + }; + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, metadata: metadata); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription("sub_123", Arg.Any()); + } + + [Fact] + public async Task Execute_SubscriptionWithoutMetadata_ProcessesSubscription() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, metadata: null); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription("sub_123", Arg.Any()); + } + + #endregion + + #region Quantity Reduction Logic Tests + + [Fact] + public async Task Execute_QuantityGreaterThan4_ReducesBy4() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription( + "sub_123", + Arg.Is(o => + o.Items.Count == 1 && + o.Items[0].Quantity == 6 && + o.Items[0].Deleted != true)); + } + + [Fact] + public async Task Execute_QuantityEquals4_DeletesItem() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 4); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription( + "sub_123", + Arg.Is(o => + o.Items.Count == 1 && + o.Items[0].Deleted == true)); + } + + [Fact] + public async Task Execute_QuantityLessThan4_DeletesItem() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 2); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription( + "sub_123", + Arg.Is(o => + o.Items.Count == 1 && + o.Items[0].Deleted == true)); + } + + #endregion + + #region Update Options Tests + + [Fact] + public async Task Execute_UpdateOptions_SetsProrationBehaviorToCreateProrations() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription( + "sub_123", + Arg.Is(o => o.ProrationBehavior == StripeConstants.ProrationBehavior.CreateProrations)); + } + + [Fact] + public async Task Execute_UpdateOptions_SetsReconciledMetadata() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription( + "sub_123", + Arg.Is(o => + o.Metadata.ContainsKey(StripeConstants.MetadataKeys.StorageReconciled2025) && + !string.IsNullOrEmpty(o.Metadata[StripeConstants.MetadataKeys.StorageReconciled2025]))); + } + + #endregion + + #region Subscription Filtering Tests + + [Fact] + public async Task Execute_SubscriptionWithNoItems_SkipsUpdate() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = new Subscription + { + Id = "sub_123", + Items = null + }; + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.DidNotReceiveWithAnyArgs().UpdateSubscription(null!); + } + + [Fact] + public async Task Execute_SubscriptionWithDifferentPriceId_SkipsUpdate() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "different-price-id", quantity: 10); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.DidNotReceiveWithAnyArgs().UpdateSubscription(null!); + } + + [Fact] + public async Task Execute_NullSubscription_SkipsProcessing() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(null!)); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.DidNotReceiveWithAnyArgs().UpdateSubscription(null!); + } + + #endregion + + #region Multiple Subscriptions Tests + + [Fact] + public async Task Execute_MultipleSubscriptions_ProcessesAll() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription1 = CreateSubscription("sub_1", "storage-gb-monthly", quantity: 10); + var subscription2 = CreateSubscription("sub_2", "storage-gb-monthly", quantity: 5); + var subscription3 = CreateSubscription("sub_3", "storage-gb-monthly", quantity: 3); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription1, subscription2, subscription3)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(callInfo => callInfo.Arg() switch + { + "sub_1" => subscription1, + "sub_2" => subscription2, + "sub_3" => subscription3, + _ => null + }); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription("sub_1", Arg.Any()); + await _stripeFacade.Received(1).UpdateSubscription("sub_2", Arg.Any()); + await _stripeFacade.Received(1).UpdateSubscription("sub_3", Arg.Any()); + } + + [Fact] + public async Task Execute_MixedSubscriptionsWithProcessed_OnlyProcessesUnprocessed() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var processedMetadata = new Dictionary + { + [StripeConstants.MetadataKeys.StorageReconciled2025] = DateTime.UtcNow.ToString("o") + }; + + var subscription1 = CreateSubscription("sub_1", "storage-gb-monthly", quantity: 10); + var subscription2 = CreateSubscription("sub_2", "storage-gb-monthly", quantity: 5, metadata: processedMetadata); + var subscription3 = CreateSubscription("sub_3", "storage-gb-monthly", quantity: 3); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription1, subscription2, subscription3)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(callInfo => callInfo.Arg() switch + { + "sub_1" => subscription1, + "sub_3" => subscription3, + _ => null + }); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription("sub_1", Arg.Any()); + await _stripeFacade.DidNotReceive().UpdateSubscription("sub_2", Arg.Any()); + await _stripeFacade.Received(1).UpdateSubscription("sub_3", Arg.Any()); + } + + #endregion + + #region Error Handling Tests + + [Fact] + public async Task Execute_UpdateFails_ContinuesProcessingOthers() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription1 = CreateSubscription("sub_1", "storage-gb-monthly", quantity: 10); + var subscription2 = CreateSubscription("sub_2", "storage-gb-monthly", quantity: 5); + var subscription3 = CreateSubscription("sub_3", "storage-gb-monthly", quantity: 3); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription1, subscription2, subscription3)); + + _stripeFacade.UpdateSubscription("sub_1", Arg.Any()) + .Returns(subscription1); + _stripeFacade.UpdateSubscription("sub_2", Arg.Any()) + .Throws(new Exception("Stripe API error")); + _stripeFacade.UpdateSubscription("sub_3", Arg.Any()) + .Returns(subscription3); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription("sub_1", Arg.Any()); + await _stripeFacade.Received(1).UpdateSubscription("sub_2", Arg.Any()); + await _stripeFacade.Received(1).UpdateSubscription("sub_3", Arg.Any()); + } + + [Fact] + public async Task Execute_UpdateFails_LogsError() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Throws(new Exception("Stripe API error")); + + // Act + await _sut.Execute(context); + + // Assert + _logger.Received().Log( + LogLevel.Error, + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any>()); + } + + #endregion + + #region Subscription Status Filtering Tests + + [Fact] + public async Task Execute_ActiveStatusSubscription_ProcessesSubscription() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.Active); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription("sub_123", Arg.Any()); + } + + [Fact] + public async Task Execute_TrialingStatusSubscription_ProcessesSubscription() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.Trialing); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription("sub_123", Arg.Any()); + } + + [Fact] + public async Task Execute_PastDueStatusSubscription_ProcessesSubscription() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.PastDue); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription("sub_123", Arg.Any()); + } + + [Fact] + public async Task Execute_CanceledStatusSubscription_SkipsSubscription() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.Canceled); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.DidNotReceiveWithAnyArgs().UpdateSubscription(null!); + } + + [Fact] + public async Task Execute_IncompleteStatusSubscription_SkipsSubscription() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription = CreateSubscription("sub_123", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.Incomplete); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription)); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.DidNotReceiveWithAnyArgs().UpdateSubscription(null!); + } + + [Fact] + public async Task Execute_MixedSubscriptionStatuses_OnlyProcessesValidStatuses() + { + // Arrange + var context = CreateJobExecutionContext(); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var activeSubscription = CreateSubscription("sub_active", "storage-gb-monthly", quantity: 10, status: StripeConstants.SubscriptionStatus.Active); + var trialingSubscription = CreateSubscription("sub_trialing", "storage-gb-monthly", quantity: 8, status: StripeConstants.SubscriptionStatus.Trialing); + var pastDueSubscription = CreateSubscription("sub_pastdue", "storage-gb-monthly", quantity: 6, status: StripeConstants.SubscriptionStatus.PastDue); + var canceledSubscription = CreateSubscription("sub_canceled", "storage-gb-monthly", quantity: 5, status: StripeConstants.SubscriptionStatus.Canceled); + var incompleteSubscription = CreateSubscription("sub_incomplete", "storage-gb-monthly", quantity: 4, status: StripeConstants.SubscriptionStatus.Incomplete); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(activeSubscription, trialingSubscription, pastDueSubscription, canceledSubscription, incompleteSubscription)); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(callInfo => callInfo.Arg() switch + { + "sub_active" => activeSubscription, + "sub_trialing" => trialingSubscription, + "sub_pastdue" => pastDueSubscription, + _ => null + }); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription("sub_active", Arg.Any()); + await _stripeFacade.Received(1).UpdateSubscription("sub_trialing", Arg.Any()); + await _stripeFacade.Received(1).UpdateSubscription("sub_pastdue", Arg.Any()); + await _stripeFacade.DidNotReceive().UpdateSubscription("sub_canceled", Arg.Any()); + await _stripeFacade.DidNotReceive().UpdateSubscription("sub_incomplete", Arg.Any()); + } + + #endregion + + #region Cancellation Tests + + [Fact] + public async Task Execute_CancellationRequested_LogsWarningAndExits() + { + // Arrange + var cts = new CancellationTokenSource(); + cts.Cancel(); // Cancel immediately + var context = CreateJobExecutionContext(cts.Token); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_EnableReconcileAdditionalStorageJob).Returns(true); + _featureService.IsEnabled(FeatureFlagKeys.PM28265_ReconcileAdditionalStorageJobEnableLiveMode).Returns(true); + + var subscription1 = CreateSubscription("sub_1", "storage-gb-monthly", quantity: 10); + + _stripeFacade.ListSubscriptionsAutoPagingAsync(Arg.Any()) + .Returns(AsyncEnumerable.Create(subscription1)); + + // Act + await _sut.Execute(context); + + // Assert - Should not process any subscriptions due to immediate cancellation + await _stripeFacade.DidNotReceiveWithAnyArgs().UpdateSubscription(null); + _logger.Received().Log( + LogLevel.Warning, + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any>()); + } + + #endregion + + #region Helper Methods + + private static IJobExecutionContext CreateJobExecutionContext(CancellationToken cancellationToken = default) + { + var context = Substitute.For(); + context.CancellationToken.Returns(cancellationToken); + return context; + } + + private static Subscription CreateSubscription( + string id, + string priceId, + long? quantity = null, + Dictionary? metadata = null, + string status = StripeConstants.SubscriptionStatus.Active) + { + var price = new Price { Id = priceId }; + var item = new SubscriptionItem + { + Id = $"si_{id}", + Price = price, + Quantity = quantity ?? 0 + }; + + return new Subscription + { + Id = id, + Status = status, + Metadata = metadata, + Items = new StripeList + { + Data = new List { item } + } + }; + } + + #endregion +} + +internal static class AsyncEnumerable +{ + public static async IAsyncEnumerable Create(params T[] items) + { + foreach (var item in items) + { + yield return item; + } + await Task.CompletedTask; + } + + public static async IAsyncEnumerable Empty() + { + await Task.CompletedTask; + yield break; + } +} diff --git a/test/Billing.Test/Jobs/SubscriptionCancellationJobTests.cs b/test/Billing.Test/Jobs/SubscriptionCancellationJobTests.cs new file mode 100644 index 0000000000..03bf24f7ff --- /dev/null +++ b/test/Billing.Test/Jobs/SubscriptionCancellationJobTests.cs @@ -0,0 +1,388 @@ +using Bit.Billing.Constants; +using Bit.Billing.Jobs; +using Bit.Billing.Services; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Repositories; +using Microsoft.Extensions.Logging; +using NSubstitute; +using Quartz; +using Stripe; +using Xunit; + +namespace Bit.Billing.Test.Jobs; + +public class SubscriptionCancellationJobTests +{ + private readonly IStripeFacade _stripeFacade; + private readonly IOrganizationRepository _organizationRepository; + private readonly SubscriptionCancellationJob _sut; + + public SubscriptionCancellationJobTests() + { + _stripeFacade = Substitute.For(); + _organizationRepository = Substitute.For(); + _sut = new SubscriptionCancellationJob(_stripeFacade, _organizationRepository, Substitute.For>()); + } + + [Fact] + public async Task Execute_OrganizationIsNull_SkipsCancellation() + { + // Arrange + const string subscriptionId = "sub_123"; + var organizationId = Guid.NewGuid(); + var context = CreateJobExecutionContext(subscriptionId, organizationId); + + _organizationRepository.GetByIdAsync(organizationId).Returns((Organization)null); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.DidNotReceiveWithAnyArgs().GetSubscription(Arg.Any(), Arg.Any()); + await _stripeFacade.DidNotReceiveWithAnyArgs().CancelSubscription(Arg.Any(), Arg.Any()); + } + + [Fact] + public async Task Execute_OrganizationIsEnabled_SkipsCancellation() + { + // Arrange + const string subscriptionId = "sub_123"; + var organizationId = Guid.NewGuid(); + var context = CreateJobExecutionContext(subscriptionId, organizationId); + + var organization = new Organization + { + Id = organizationId, + Enabled = true + }; + _organizationRepository.GetByIdAsync(organizationId).Returns(organization); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.DidNotReceiveWithAnyArgs().GetSubscription(Arg.Any(), Arg.Any()); + await _stripeFacade.DidNotReceiveWithAnyArgs().CancelSubscription(Arg.Any(), Arg.Any()); + } + + [Fact] + public async Task Execute_SubscriptionStatusIsNotUnpaid_SkipsCancellation() + { + // Arrange + const string subscriptionId = "sub_123"; + var organizationId = Guid.NewGuid(); + var context = CreateJobExecutionContext(subscriptionId, organizationId); + + var organization = new Organization + { + Id = organizationId, + Enabled = false + }; + _organizationRepository.GetByIdAsync(organizationId).Returns(organization); + + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Active, + LatestInvoice = new Invoice + { + BillingReason = "subscription_cycle" + } + }; + _stripeFacade.GetSubscription(subscriptionId, Arg.Is(o => o.Expand.Contains("latest_invoice"))) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.DidNotReceive().CancelSubscription(subscriptionId, Arg.Any()); + } + + [Fact] + public async Task Execute_BillingReasonIsInvalid_SkipsCancellation() + { + // Arrange + const string subscriptionId = "sub_123"; + var organizationId = Guid.NewGuid(); + var context = CreateJobExecutionContext(subscriptionId, organizationId); + + var organization = new Organization + { + Id = organizationId, + Enabled = false + }; + _organizationRepository.GetByIdAsync(organizationId).Returns(organization); + + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Unpaid, + LatestInvoice = new Invoice + { + BillingReason = "manual" + } + }; + _stripeFacade.GetSubscription(subscriptionId, Arg.Is(o => o.Expand.Contains("latest_invoice"))) + .Returns(subscription); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.DidNotReceive().CancelSubscription(subscriptionId, Arg.Any()); + } + + [Fact] + public async Task Execute_ValidConditions_CancelsSubscriptionAndVoidsInvoices() + { + // Arrange + const string subscriptionId = "sub_123"; + var organizationId = Guid.NewGuid(); + var context = CreateJobExecutionContext(subscriptionId, organizationId); + + var organization = new Organization + { + Id = organizationId, + Enabled = false + }; + _organizationRepository.GetByIdAsync(organizationId).Returns(organization); + + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Unpaid, + LatestInvoice = new Invoice + { + BillingReason = "subscription_cycle" + } + }; + _stripeFacade.GetSubscription(subscriptionId, Arg.Is(o => o.Expand.Contains("latest_invoice"))) + .Returns(subscription); + + var invoices = new StripeList + { + Data = + [ + new Invoice { Id = "inv_1" }, + new Invoice { Id = "inv_2" } + ], + HasMore = false + }; + _stripeFacade.ListInvoices(Arg.Any()).Returns(invoices); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).CancelSubscription(subscriptionId, Arg.Any()); + await _stripeFacade.Received(1).VoidInvoice("inv_1"); + await _stripeFacade.Received(1).VoidInvoice("inv_2"); + } + + [Fact] + public async Task Execute_WithSubscriptionCreateBillingReason_CancelsSubscription() + { + // Arrange + const string subscriptionId = "sub_123"; + var organizationId = Guid.NewGuid(); + var context = CreateJobExecutionContext(subscriptionId, organizationId); + + var organization = new Organization + { + Id = organizationId, + Enabled = false + }; + _organizationRepository.GetByIdAsync(organizationId).Returns(organization); + + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Unpaid, + LatestInvoice = new Invoice + { + BillingReason = "subscription_create" + } + }; + _stripeFacade.GetSubscription(subscriptionId, Arg.Is(o => o.Expand.Contains("latest_invoice"))) + .Returns(subscription); + + var invoices = new StripeList + { + Data = [], + HasMore = false + }; + _stripeFacade.ListInvoices(Arg.Any()).Returns(invoices); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).CancelSubscription(subscriptionId, Arg.Any()); + } + + [Fact] + public async Task Execute_NoOpenInvoices_CancelsSubscriptionOnly() + { + // Arrange + const string subscriptionId = "sub_123"; + var organizationId = Guid.NewGuid(); + var context = CreateJobExecutionContext(subscriptionId, organizationId); + + var organization = new Organization + { + Id = organizationId, + Enabled = false + }; + _organizationRepository.GetByIdAsync(organizationId).Returns(organization); + + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Unpaid, + LatestInvoice = new Invoice + { + BillingReason = "subscription_cycle" + } + }; + _stripeFacade.GetSubscription(subscriptionId, Arg.Is(o => o.Expand.Contains("latest_invoice"))) + .Returns(subscription); + + var invoices = new StripeList + { + Data = [], + HasMore = false + }; + _stripeFacade.ListInvoices(Arg.Any()).Returns(invoices); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).CancelSubscription(subscriptionId, Arg.Any()); + await _stripeFacade.DidNotReceiveWithAnyArgs().VoidInvoice(Arg.Any()); + } + + [Fact] + public async Task Execute_WithPagination_VoidsAllInvoices() + { + // Arrange + const string subscriptionId = "sub_123"; + var organizationId = Guid.NewGuid(); + var context = CreateJobExecutionContext(subscriptionId, organizationId); + + var organization = new Organization + { + Id = organizationId, + Enabled = false + }; + _organizationRepository.GetByIdAsync(organizationId).Returns(organization); + + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Unpaid, + LatestInvoice = new Invoice + { + BillingReason = "subscription_cycle" + } + }; + _stripeFacade.GetSubscription(subscriptionId, Arg.Is(o => o.Expand.Contains("latest_invoice"))) + .Returns(subscription); + + // First page of invoices + var firstPage = new StripeList + { + Data = + [ + new Invoice { Id = "inv_1" }, + new Invoice { Id = "inv_2" } + ], + HasMore = true + }; + + // Second page of invoices + var secondPage = new StripeList + { + Data = + [ + new Invoice { Id = "inv_3" }, + new Invoice { Id = "inv_4" } + ], + HasMore = false + }; + + _stripeFacade.ListInvoices(Arg.Is(o => o.StartingAfter == null)) + .Returns(firstPage); + _stripeFacade.ListInvoices(Arg.Is(o => o.StartingAfter == "inv_2")) + .Returns(secondPage); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).CancelSubscription(subscriptionId, Arg.Any()); + await _stripeFacade.Received(1).VoidInvoice("inv_1"); + await _stripeFacade.Received(1).VoidInvoice("inv_2"); + await _stripeFacade.Received(1).VoidInvoice("inv_3"); + await _stripeFacade.Received(1).VoidInvoice("inv_4"); + await _stripeFacade.Received(2).ListInvoices(Arg.Any()); + } + + [Fact] + public async Task Execute_ListInvoicesCalledWithCorrectOptions() + { + // Arrange + const string subscriptionId = "sub_123"; + var organizationId = Guid.NewGuid(); + var context = CreateJobExecutionContext(subscriptionId, organizationId); + + var organization = new Organization + { + Id = organizationId, + Enabled = false + }; + _organizationRepository.GetByIdAsync(organizationId).Returns(organization); + + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Unpaid, + LatestInvoice = new Invoice + { + BillingReason = "subscription_cycle" + } + }; + _stripeFacade.GetSubscription(subscriptionId, Arg.Is(o => o.Expand.Contains("latest_invoice"))) + .Returns(subscription); + + var invoices = new StripeList + { + Data = [], + HasMore = false + }; + _stripeFacade.ListInvoices(Arg.Any()).Returns(invoices); + + // Act + await _sut.Execute(context); + + // Assert + await _stripeFacade.Received(1).GetSubscription(subscriptionId, Arg.Is(o => o.Expand.Contains("latest_invoice"))); + await _stripeFacade.Received(1).ListInvoices(Arg.Is(o => + o.Status == "open" && + o.Subscription == subscriptionId && + o.Limit == 100)); + } + + private static IJobExecutionContext CreateJobExecutionContext(string subscriptionId, Guid organizationId) + { + var context = Substitute.For(); + var jobDataMap = new JobDataMap + { + { "subscriptionId", subscriptionId }, + { "organizationId", organizationId.ToString() } + }; + context.MergedJobDataMap.Returns(jobDataMap); + return context; + } +} diff --git a/test/Billing.Test/Resources/Events/charge.succeeded.json b/test/Billing.Test/Resources/Events/charge.succeeded.json deleted file mode 100644 index 3cf919f123..0000000000 --- a/test/Billing.Test/Resources/Events/charge.succeeded.json +++ /dev/null @@ -1,130 +0,0 @@ -{ - "id": "evt_3NvKgBIGBnsLynRr0pJJqudS", - "object": "event", - "api_version": "2024-06-20", - "created": 1695909300, - "data": { - "object": { - "id": "ch_3NvKgBIGBnsLynRr0ZyvP9AN", - "object": "charge", - "amount": 7200, - "amount_captured": 7200, - "amount_refunded": 0, - "application": null, - "application_fee": null, - "application_fee_amount": null, - "balance_transaction": "txn_3NvKgBIGBnsLynRr0KbYEz76", - "billing_details": { - "address": { - "city": null, - "country": null, - "line1": null, - "line2": null, - "postal_code": null, - "state": null - }, - "email": null, - "name": null, - "phone": null - }, - "calculated_statement_descriptor": "BITWARDEN", - "captured": true, - "created": 1695909299, - "currency": "usd", - "customer": "cus_OimAwOzQmThNXx", - "description": "Subscription update", - "destination": null, - "dispute": null, - "disputed": false, - "failure_balance_transaction": null, - "failure_code": null, - "failure_message": null, - "fraud_details": { - }, - "invoice": "in_1NvKgBIGBnsLynRrmRFHAcoV", - "livemode": false, - "metadata": { - }, - "on_behalf_of": null, - "order": null, - "outcome": { - "network_status": "approved_by_network", - "reason": null, - "risk_level": "normal", - "risk_score": 37, - "seller_message": "Payment complete.", - "type": "authorized" - }, - "paid": true, - "payment_intent": "pi_3NvKgBIGBnsLynRr09Ny3Heu", - "payment_method": "pm_1NvKbpIGBnsLynRrcOwez4A1", - "payment_method_details": { - "card": { - "amount_authorized": 7200, - "brand": "visa", - "checks": { - "address_line1_check": null, - "address_postal_code_check": null, - "cvc_check": "pass" - }, - "country": "US", - "exp_month": 6, - "exp_year": 2033, - "extended_authorization": { - "status": "disabled" - }, - "fingerprint": "0VgUBpvqcUUnuSmK", - "funding": "credit", - "incremental_authorization": { - "status": "unavailable" - }, - "installments": null, - "last4": "4242", - "mandate": null, - "multicapture": { - "status": "unavailable" - }, - "network": "visa", - "network_token": { - "used": false - }, - "overcapture": { - "maximum_amount_capturable": 7200, - "status": "unavailable" - }, - "three_d_secure": null, - "wallet": null - }, - "type": "card" - }, - "receipt_email": "cturnbull@bitwarden.com", - "receipt_number": null, - "receipt_url": "https://pay.stripe.com/receipts/invoices/CAcaFwoVYWNjdF8xOXNtSVhJR0Juc0x5blJyKLSL1qgGMgYTnk_JOUA6LBY_SDEZNtuae1guQ6Dlcuev1TUHwn712t-UNnZdIc383zS15bXv_1dby8e4?s=ap", - "refunded": false, - "refunds": { - "object": "list", - "data": [ - ], - "has_more": false, - "total_count": 0, - "url": "/v1/charges/ch_3NvKgBIGBnsLynRr0ZyvP9AN/refunds" - }, - "review": null, - "shipping": null, - "source": null, - "source_transfer": null, - "statement_descriptor": null, - "statement_descriptor_suffix": null, - "status": "succeeded", - "transfer_data": null, - "transfer_group": null - } - }, - "livemode": false, - "pending_webhooks": 9, - "request": { - "id": "req_rig8N5Ca8EXYRy", - "idempotency_key": "db75068d-5d90-4c65-a410-4e2ed8347509" - }, - "type": "charge.succeeded" -} diff --git a/test/Billing.Test/Resources/Events/customer.subscription.updated.json b/test/Billing.Test/Resources/Events/customer.subscription.updated.json deleted file mode 100644 index 62a8590fa8..0000000000 --- a/test/Billing.Test/Resources/Events/customer.subscription.updated.json +++ /dev/null @@ -1,177 +0,0 @@ -{ - "id": "evt_1NvLMDIGBnsLynRr6oBxebrE", - "object": "event", - "api_version": "2024-06-20", - "created": 1695911902, - "data": { - "object": { - "id": "sub_1NvKoKIGBnsLynRrcLIAUWGf", - "object": "subscription", - "application": null, - "application_fee_percent": null, - "automatic_tax": { - "enabled": false - }, - "billing_cycle_anchor": 1695911900, - "billing_thresholds": null, - "cancel_at": null, - "cancel_at_period_end": false, - "canceled_at": null, - "cancellation_details": { - "comment": null, - "feedback": null, - "reason": null - }, - "collection_method": "charge_automatically", - "created": 1695909804, - "currency": "usd", - "current_period_end": 1727534300, - "current_period_start": 1695911900, - "customer": "cus_OimNNCC3RiI2HQ", - "days_until_due": null, - "default_payment_method": null, - "default_source": null, - "default_tax_rates": [ - ], - "description": null, - "discount": null, - "ended_at": null, - "items": { - "object": "list", - "data": [ - { - "id": "si_OimNgVtrESpqus", - "object": "subscription_item", - "billing_thresholds": null, - "created": 1695909805, - "metadata": { - }, - "plan": { - "id": "enterprise-org-seat-annually", - "object": "plan", - "active": true, - "aggregate_usage": null, - "amount": 3600, - "amount_decimal": "3600", - "billing_scheme": "per_unit", - "created": 1494268677, - "currency": "usd", - "interval": "year", - "interval_count": 1, - "livemode": false, - "metadata": { - }, - "nickname": "2019 Enterprise Seat (Annually)", - "product": "prod_BUtogGemxnTi9z", - "tiers_mode": null, - "transform_usage": null, - "trial_period_days": null, - "usage_type": "licensed" - }, - "price": { - "id": "enterprise-org-seat-annually", - "object": "price", - "active": true, - "billing_scheme": "per_unit", - "created": 1494268677, - "currency": "usd", - "custom_unit_amount": null, - "livemode": false, - "lookup_key": null, - "metadata": { - }, - "nickname": "2019 Enterprise Seat (Annually)", - "product": "prod_BUtogGemxnTi9z", - "recurring": { - "aggregate_usage": null, - "interval": "year", - "interval_count": 1, - "trial_period_days": null, - "usage_type": "licensed" - }, - "tax_behavior": "unspecified", - "tiers_mode": null, - "transform_quantity": null, - "type": "recurring", - "unit_amount": 3600, - "unit_amount_decimal": "3600" - }, - "quantity": 1, - "subscription": "sub_1NvKoKIGBnsLynRrcLIAUWGf", - "tax_rates": [ - ] - } - ], - "has_more": false, - "total_count": 1, - "url": "/v1/subscription_items?subscription=sub_1NvKoKIGBnsLynRrcLIAUWGf" - }, - "latest_invoice": "in_1NvLM9IGBnsLynRrOysII07d", - "livemode": false, - "metadata": { - "organizationId": "84a569ea-4643-474a-83a9-b08b00e7a20d" - }, - "next_pending_invoice_item_invoice": null, - "on_behalf_of": null, - "pause_collection": null, - "payment_settings": { - "payment_method_options": null, - "payment_method_types": null, - "save_default_payment_method": "off" - }, - "pending_invoice_item_interval": null, - "pending_setup_intent": null, - "pending_update": null, - "plan": { - "id": "enterprise-org-seat-annually", - "object": "plan", - "active": true, - "aggregate_usage": null, - "amount": 3600, - "amount_decimal": "3600", - "billing_scheme": "per_unit", - "created": 1494268677, - "currency": "usd", - "interval": "year", - "interval_count": 1, - "livemode": false, - "metadata": { - }, - "nickname": "2019 Enterprise Seat (Annually)", - "product": "prod_BUtogGemxnTi9z", - "tiers_mode": null, - "transform_usage": null, - "trial_period_days": null, - "usage_type": "licensed" - }, - "quantity": 1, - "schedule": null, - "start_date": 1695909804, - "status": "active", - "test_clock": null, - "transfer_data": null, - "trial_end": 1695911899, - "trial_settings": { - "end_behavior": { - "missing_payment_method": "create_invoice" - } - }, - "trial_start": 1695909804 - }, - "previous_attributes": { - "billing_cycle_anchor": 1696514604, - "current_period_end": 1696514604, - "current_period_start": 1695909804, - "latest_invoice": "in_1NvKoKIGBnsLynRrSNRC6oYI", - "status": "trialing", - "trial_end": 1696514604 - } - }, - "livemode": false, - "pending_webhooks": 8, - "request": { - "id": "req_DMZPUU3BI66zAx", - "idempotency_key": "3fd8b4a5-6a20-46ab-9f45-b37b02a8017f" - }, - "type": "customer.subscription.updated" -} diff --git a/test/Billing.Test/Resources/Events/customer.updated.json b/test/Billing.Test/Resources/Events/customer.updated.json deleted file mode 100644 index 9aa0928515..0000000000 --- a/test/Billing.Test/Resources/Events/customer.updated.json +++ /dev/null @@ -1,311 +0,0 @@ -{ - "id": "evt_1NvKjSIGBnsLynRrS3MTK4DZ", - "object": "event", - "account": "acct_19smIXIGBnsLynRr", - "api_version": "2024-06-20", - "created": 1695909502, - "data": { - "object": { - "id": "cus_Of54kUr3gV88lM", - "object": "customer", - "address": { - "city": null, - "country": "US", - "line1": "", - "line2": null, - "postal_code": "33701", - "state": null - }, - "balance": 0, - "created": 1695056798, - "currency": "usd", - "default_source": "src_1NtAfeIGBnsLynRrYDrceax7", - "delinquent": false, - "description": "Premium User", - "discount": null, - "email": "premium@bitwarden.com", - "invoice_prefix": "C506E8CE", - "invoice_settings": { - "custom_fields": [ - { - "name": "Subscriber", - "value": "Premium User" - } - ], - "default_payment_method": "pm_1Nrku9IGBnsLynRrcsQ3hy6C", - "footer": null, - "rendering_options": null - }, - "livemode": false, - "metadata": { - "region": "US" - }, - "name": null, - "next_invoice_sequence": 2, - "phone": null, - "preferred_locales": [ - ], - "shipping": null, - "tax_exempt": "none", - "test_clock": null, - "account_balance": 0, - "cards": { - "object": "list", - "data": [ - ], - "has_more": false, - "total_count": 0, - "url": "/v1/customers/cus_Of54kUr3gV88lM/cards" - }, - "default_card": null, - "default_currency": "usd", - "sources": { - "object": "list", - "data": [ - { - "id": "src_1NtAfeIGBnsLynRrYDrceax7", - "object": "source", - "ach_credit_transfer": { - "account_number": "test_b2d1c6415f6f", - "routing_number": "110000000", - "fingerprint": "ePO4hBQanSft3gvU", - "swift_code": "TSTEZ122", - "bank_name": "TEST BANK", - "refund_routing_number": null, - "refund_account_holder_type": null, - "refund_account_holder_name": null - }, - "amount": null, - "client_secret": "src_client_secret_bUAP2uDRw6Pwj0xYk32LmJ3K", - "created": 1695394170, - "currency": "usd", - "customer": "cus_Of54kUr3gV88lM", - "flow": "receiver", - "livemode": false, - "metadata": { - }, - "owner": { - "address": null, - "email": "amount_0@stripe.com", - "name": null, - "phone": null, - "verified_address": null, - "verified_email": null, - "verified_name": null, - "verified_phone": null - }, - "receiver": { - "address": "110000000-test_b2d1c6415f6f", - "amount_charged": 0, - "amount_received": 0, - "amount_returned": 0, - "refund_attributes_method": "email", - "refund_attributes_status": "missing" - }, - "statement_descriptor": null, - "status": "pending", - "type": "ach_credit_transfer", - "usage": "reusable" - } - ], - "has_more": false, - "total_count": 1, - "url": "/v1/customers/cus_Of54kUr3gV88lM/sources" - }, - "subscriptions": { - "object": "list", - "data": [ - { - "id": "sub_1NrkuBIGBnsLynRrzjFGIjEw", - "object": "subscription", - "application": null, - "application_fee_percent": null, - "automatic_tax": { - "enabled": false - }, - "billing": "charge_automatically", - "billing_cycle_anchor": 1695056799, - "billing_thresholds": null, - "cancel_at": null, - "cancel_at_period_end": false, - "canceled_at": null, - "cancellation_details": { - "comment": null, - "feedback": null, - "reason": null - }, - "collection_method": "charge_automatically", - "created": 1695056799, - "currency": "usd", - "current_period_end": 1726679199, - "current_period_start": 1695056799, - "customer": "cus_Of54kUr3gV88lM", - "days_until_due": null, - "default_payment_method": null, - "default_source": null, - "default_tax_rates": [ - ], - "description": null, - "discount": null, - "ended_at": null, - "invoice_customer_balance_settings": { - "consume_applied_balance_on_void": true - }, - "items": { - "object": "list", - "data": [ - { - "id": "si_Of54i3aK9I5Wro", - "object": "subscription_item", - "billing_thresholds": null, - "created": 1695056800, - "metadata": { - }, - "plan": { - "id": "premium-annually", - "object": "plan", - "active": true, - "aggregate_usage": null, - "amount": 1000, - "amount_decimal": "1000", - "billing_scheme": "per_unit", - "created": 1499289328, - "currency": "usd", - "interval": "year", - "interval_count": 1, - "livemode": false, - "metadata": { - }, - "name": "Premium (Annually)", - "nickname": "Premium (Annually)", - "product": "prod_BUqgYr48VzDuCg", - "statement_description": null, - "statement_descriptor": null, - "tiers": null, - "tiers_mode": null, - "transform_usage": null, - "trial_period_days": null, - "usage_type": "licensed" - }, - "price": { - "id": "premium-annually", - "object": "price", - "active": true, - "billing_scheme": "per_unit", - "created": 1499289328, - "currency": "usd", - "custom_unit_amount": null, - "livemode": false, - "lookup_key": null, - "metadata": { - }, - "nickname": "Premium (Annually)", - "product": "prod_BUqgYr48VzDuCg", - "recurring": { - "aggregate_usage": null, - "interval": "year", - "interval_count": 1, - "trial_period_days": null, - "usage_type": "licensed" - }, - "tax_behavior": "unspecified", - "tiers_mode": null, - "transform_quantity": null, - "type": "recurring", - "unit_amount": 1000, - "unit_amount_decimal": "1000" - }, - "quantity": 1, - "subscription": "sub_1NrkuBIGBnsLynRrzjFGIjEw", - "tax_rates": [ - ] - } - ], - "has_more": false, - "total_count": 1, - "url": "/v1/subscription_items?subscription=sub_1NrkuBIGBnsLynRrzjFGIjEw" - }, - "latest_invoice": "in_1NrkuBIGBnsLynRr40gyJTVU", - "livemode": false, - "metadata": { - "userId": "91f40b6d-ac3b-4348-804b-b0810119ac6a" - }, - "next_pending_invoice_item_invoice": null, - "on_behalf_of": null, - "pause_collection": null, - "payment_settings": { - "payment_method_options": null, - "payment_method_types": null, - "save_default_payment_method": "off" - }, - "pending_invoice_item_interval": null, - "pending_setup_intent": null, - "pending_update": null, - "plan": { - "id": "premium-annually", - "object": "plan", - "active": true, - "aggregate_usage": null, - "amount": 1000, - "amount_decimal": "1000", - "billing_scheme": "per_unit", - "created": 1499289328, - "currency": "usd", - "interval": "year", - "interval_count": 1, - "livemode": false, - "metadata": { - }, - "name": "Premium (Annually)", - "nickname": "Premium (Annually)", - "product": "prod_BUqgYr48VzDuCg", - "statement_description": null, - "statement_descriptor": null, - "tiers": null, - "tiers_mode": null, - "transform_usage": null, - "trial_period_days": null, - "usage_type": "licensed" - }, - "quantity": 1, - "schedule": null, - "start": 1695056799, - "start_date": 1695056799, - "status": "active", - "tax_percent": null, - "test_clock": null, - "transfer_data": null, - "trial_end": null, - "trial_settings": { - "end_behavior": { - "missing_payment_method": "create_invoice" - } - }, - "trial_start": null - } - ], - "has_more": false, - "total_count": 1, - "url": "/v1/customers/cus_Of54kUr3gV88lM/subscriptions" - }, - "tax_ids": { - "object": "list", - "data": [ - ], - "has_more": false, - "total_count": 0, - "url": "/v1/customers/cus_Of54kUr3gV88lM/tax_ids" - }, - "tax_info": null, - "tax_info_verification": null - }, - "previous_attributes": { - "email": "premium-new@bitwarden.com" - } - }, - "livemode": false, - "pending_webhooks": 5, - "request": "req_2RtGdXCfiicFLx", - "type": "customer.updated", - "user_id": "acct_19smIXIGBnsLynRr" -} diff --git a/test/Billing.Test/Resources/Events/invoice.created.json b/test/Billing.Test/Resources/Events/invoice.created.json deleted file mode 100644 index bf53372b51..0000000000 --- a/test/Billing.Test/Resources/Events/invoice.created.json +++ /dev/null @@ -1,222 +0,0 @@ -{ - "id": "evt_1NvKzfIGBnsLynRr0SkwrlkE", - "object": "event", - "api_version": "2024-06-20", - "created": 1695910506, - "data": { - "object": { - "id": "in_1NvKzdIGBnsLynRr8fE8cpbg", - "object": "invoice", - "account_country": "US", - "account_name": "Bitwarden Inc.", - "account_tax_ids": null, - "amount_due": 0, - "amount_paid": 0, - "amount_remaining": 0, - "amount_shipping": 0, - "application": null, - "application_fee_amount": null, - "attempt_count": 0, - "attempted": true, - "auto_advance": false, - "automatic_tax": { - "enabled": false, - "status": null - }, - "billing_reason": "subscription_create", - "charge": null, - "collection_method": "charge_automatically", - "created": 1695910505, - "currency": "usd", - "custom_fields": [ - { - "name": "Organization", - "value": "teams 2023 monthly - 2" - } - ], - "customer": "cus_OimYrxnMTMMK1E", - "customer_address": { - "city": null, - "country": "US", - "line1": "", - "line2": null, - "postal_code": "12345", - "state": null - }, - "customer_email": "cturnbull@bitwarden.com", - "customer_name": null, - "customer_phone": null, - "customer_shipping": null, - "customer_tax_exempt": "none", - "customer_tax_ids": [ - ], - "default_payment_method": null, - "default_source": null, - "default_tax_rates": [ - ], - "description": null, - "discount": null, - "discounts": [ - ], - "due_date": null, - "effective_at": 1695910505, - "ending_balance": 0, - "footer": null, - "from_invoice": null, - "hosted_invoice_url": "https://invoice.stripe.com/i/acct_19smIXIGBnsLynRr/test_YWNjdF8xOXNtSVhJR0Juc0x5blJyLF9PaW1ZVlo4dFRtbkNQQVY5aHNpckQxN1QzRHBPcVBOLDg2NDUxMzA30200etYRHca2?s=ap", - "invoice_pdf": "https://pay.stripe.com/invoice/acct_19smIXIGBnsLynRr/test_YWNjdF8xOXNtSVhJR0Juc0x5blJyLF9PaW1ZVlo4dFRtbkNQQVY5aHNpckQxN1QzRHBPcVBOLDg2NDUxMzA30200etYRHca2/pdf?s=ap", - "last_finalization_error": null, - "latest_revision": null, - "lines": { - "object": "list", - "data": [ - { - "id": "il_1NvKzdIGBnsLynRr2pS4ZA8e", - "object": "line_item", - "amount": 0, - "amount_excluding_tax": 0, - "currency": "usd", - "description": "Trial period for Teams Organization Seat", - "discount_amounts": [ - ], - "discountable": true, - "discounts": [ - ], - "livemode": false, - "metadata": { - "organizationId": "3fbc84ce-102d-4919-b89b-b08b00ead71a" - }, - "period": { - "end": 1696515305, - "start": 1695910505 - }, - "plan": { - "id": "2020-teams-org-seat-monthly", - "object": "plan", - "active": true, - "aggregate_usage": null, - "amount": 400, - "amount_decimal": "400", - "billing_scheme": "per_unit", - "created": 1595263113, - "currency": "usd", - "interval": "month", - "interval_count": 1, - "livemode": false, - "metadata": { - }, - "nickname": "Teams Organization Seat (Monthly) 2023", - "product": "prod_HgOooYXDr2DDAA", - "tiers_mode": null, - "transform_usage": null, - "trial_period_days": null, - "usage_type": "licensed" - }, - "price": { - "id": "2020-teams-org-seat-monthly", - "object": "price", - "active": true, - "billing_scheme": "per_unit", - "created": 1595263113, - "currency": "usd", - "custom_unit_amount": null, - "livemode": false, - "lookup_key": null, - "metadata": { - }, - "nickname": "Teams Organization Seat (Monthly) 2023", - "product": "prod_HgOooYXDr2DDAA", - "recurring": { - "aggregate_usage": null, - "interval": "month", - "interval_count": 1, - "trial_period_days": null, - "usage_type": "licensed" - }, - "tax_behavior": "unspecified", - "tiers_mode": null, - "transform_quantity": null, - "type": "recurring", - "unit_amount": 400, - "unit_amount_decimal": "400" - }, - "proration": false, - "proration_details": { - "credited_items": null - }, - "quantity": 1, - "subscription": "sub_1NvKzdIGBnsLynRrKIHQamZc", - "subscription_item": "si_OimYNSbvuqdtTr", - "tax_amounts": [ - ], - "tax_rates": [ - ], - "type": "subscription", - "unit_amount_excluding_tax": "0" - } - ], - "has_more": false, - "total_count": 1, - "url": "/v1/invoices/in_1NvKzdIGBnsLynRr8fE8cpbg/lines" - }, - "livemode": false, - "metadata": { - }, - "next_payment_attempt": null, - "number": "3E96D078-0001", - "on_behalf_of": null, - "paid": true, - "paid_out_of_band": false, - "payment_intent": null, - "payment_settings": { - "default_mandate": null, - "payment_method_options": null, - "payment_method_types": null - }, - "period_end": 1695910505, - "period_start": 1695910505, - "post_payment_credit_notes_amount": 0, - "pre_payment_credit_notes_amount": 0, - "quote": null, - "receipt_number": null, - "rendering": null, - "rendering_options": null, - "shipping_cost": null, - "shipping_details": null, - "starting_balance": 0, - "statement_descriptor": null, - "status": "paid", - "status_transitions": { - "finalized_at": 1695910505, - "marked_uncollectible_at": null, - "paid_at": 1695910505, - "voided_at": null - }, - "subscription": "sub_1NvKzdIGBnsLynRrKIHQamZc", - "subscription_details": { - "metadata": { - "organizationId": "3fbc84ce-102d-4919-b89b-b08b00ead71a" - } - }, - "subtotal": 0, - "subtotal_excluding_tax": 0, - "tax": null, - "test_clock": null, - "total": 0, - "total_discount_amounts": [ - ], - "total_excluding_tax": 0, - "total_tax_amounts": [ - ], - "transfer_data": null, - "webhooks_delivered_at": null - } - }, - "livemode": false, - "pending_webhooks": 8, - "request": { - "id": "req_roIwONfgyfZdr4", - "idempotency_key": "dd2a171b-b9c7-4d2d-89d5-1ceae3c0595d" - }, - "type": "invoice.created" -} diff --git a/test/Billing.Test/Resources/Events/invoice.finalized.json b/test/Billing.Test/Resources/Events/invoice.finalized.json deleted file mode 100644 index 207fab497e..0000000000 --- a/test/Billing.Test/Resources/Events/invoice.finalized.json +++ /dev/null @@ -1,400 +0,0 @@ -{ - "id": "evt_1PQaABIGBnsLynRrhoJjGnyz", - "object": "event", - "account": "acct_19smIXIGBnsLynRr", - "api_version": "2024-06-20", - "created": 1718133319, - "data": { - "object": { - "id": "in_1PQa9fIGBnsLynRraYIqTdBs", - "object": "invoice", - "account_country": "US", - "account_name": "Bitwarden Inc.", - "account_tax_ids": null, - "amount_due": 84240, - "amount_paid": 0, - "amount_remaining": 84240, - "amount_shipping": 0, - "application": null, - "attempt_count": 0, - "attempted": false, - "auto_advance": true, - "automatic_tax": { - "enabled": true, - "liability": { - "type": "self" - }, - "status": "complete" - }, - "billing_reason": "subscription_update", - "charge": null, - "collection_method": "send_invoice", - "created": 1718133291, - "currency": "usd", - "custom_fields": [ - { - "name": "Provider", - "value": "MSP" - } - ], - "customer": "cus_QH8QVKyTh2lfcG", - "customer_address": { - "city": null, - "country": "US", - "line1": null, - "line2": null, - "postal_code": "12345", - "state": null - }, - "customer_email": "billing@msp.com", - "customer_name": null, - "customer_phone": null, - "customer_shipping": null, - "customer_tax_exempt": "none", - "customer_tax_ids": [ - ], - "default_payment_method": null, - "default_source": null, - "default_tax_rates": [ - ], - "description": null, - "discount": { - "id": "di_1PQa9eIGBnsLynRrwwYr2bGD", - "object": "discount", - "checkout_session": null, - "coupon": { - "id": "msp-discount-35", - "object": "coupon", - "amount_off": null, - "created": 1678805729, - "currency": null, - "duration": "forever", - "duration_in_months": null, - "livemode": false, - "max_redemptions": null, - "metadata": { - }, - "name": "MSP Discount - 35%", - "percent_off": 35, - "redeem_by": null, - "times_redeemed": 515, - "valid": true, - "percent_off_precise": 35 - }, - "customer": "cus_QH8QVKyTh2lfcG", - "end": null, - "invoice": null, - "invoice_item": null, - "promotion_code": null, - "start": 1718133290, - "subscription": null, - "subscription_item": null - }, - "discounts": [ - "di_1PQa9eIGBnsLynRrwwYr2bGD" - ], - "due_date": 1720725291, - "effective_at": 1718136893, - "ending_balance": 0, - "footer": null, - "from_invoice": null, - "hosted_invoice_url": "https://invoice.stripe.com/i/acct_19smIXIGBnsLynRr/test_YWNjdF8xOXNtSVhJR0Juc0x5blJyLF9RSDhRYVNIejNDMXBMVXAzM0M3S2RwaUt1Z3NuVHVzLDEwODY3NDEyMg0200RT8cC2nw?s=ap", - "invoice_pdf": "https://pay.stripe.com/invoice/acct_19smIXIGBnsLynRr/test_YWNjdF8xOXNtSVhJR0Juc0x5blJyLF9RSDhRYVNIejNDMXBMVXAzM0M3S2RwaUt1Z3NuVHVzLDEwODY3NDEyMg0200RT8cC2nw/pdf?s=ap", - "issuer": { - "type": "self" - }, - "last_finalization_error": null, - "latest_revision": null, - "lines": { - "object": "list", - "data": [ - { - "id": "sub_1PQa9fIGBnsLynRr83lNrFHa", - "object": "line_item", - "amount": 50000, - "amount_excluding_tax": 50000, - "currency": "usd", - "description": null, - "discount_amounts": [ - { - "amount": 17500, - "discount": "di_1PQa9eIGBnsLynRrwwYr2bGD" - } - ], - "discountable": true, - "discounts": [ - ], - "invoice": "in_1PQa9fIGBnsLynRraYIqTdBs", - "livemode": false, - "metadata": { - }, - "period": { - "end": 1720725291, - "start": 1718133291 - }, - "plan": { - "id": "2023-teams-org-seat-monthly", - "object": "plan", - "active": true, - "aggregate_usage": null, - "amount": 500, - "amount_decimal": "500", - "billing_scheme": "per_unit", - "created": 1695839010, - "currency": "usd", - "interval": "month", - "interval_count": 1, - "livemode": false, - "metadata": { - }, - "meter": null, - "nickname": "Teams Organization Seat (Monthly)", - "product": "prod_HgOooYXDr2DDAA", - "tiers_mode": null, - "transform_usage": null, - "trial_period_days": null, - "usage_type": "licensed", - "name": "Password Manager - Teams Plan", - "statement_description": null, - "statement_descriptor": null, - "tiers": null - }, - "price": { - "id": "2023-teams-org-seat-monthly", - "object": "price", - "active": true, - "billing_scheme": "per_unit", - "created": 1695839010, - "currency": "usd", - "custom_unit_amount": null, - "livemode": false, - "lookup_key": null, - "metadata": { - }, - "nickname": "Teams Organization Seat (Monthly)", - "product": "prod_HgOooYXDr2DDAA", - "recurring": { - "aggregate_usage": null, - "interval": "month", - "interval_count": 1, - "meter": null, - "trial_period_days": null, - "usage_type": "licensed" - }, - "tax_behavior": "exclusive", - "tiers_mode": null, - "transform_quantity": null, - "type": "recurring", - "unit_amount": 500, - "unit_amount_decimal": "500" - }, - "proration": false, - "proration_details": { - "credited_items": null - }, - "quantity": 100, - "subscription": null, - "subscription_item": "si_QH8Qo4WEJxOVwx", - "tax_amounts": [ - { - "amount": 2600, - "inclusive": false, - "tax_rate": "txr_1OZyBuIGBnsLynRrX0PJLuMC", - "taxability_reason": "standard_rated", - "taxable_amount": 32500 - } - ], - "tax_rates": [ - ], - "type": "subscription", - "unit_amount_excluding_tax": "500", - "unique_id": "il_1PQa9fIGBnsLynRrSJ3cxrdU", - "unique_line_item_id": "sli_1acb3eIGBnsLynRr4b9c2f48" - }, - { - "id": "sub_1PQa9fIGBnsLynRr83lNrFHa", - "object": "line_item", - "amount": 70000, - "amount_excluding_tax": 70000, - "currency": "usd", - "description": null, - "discount_amounts": [ - { - "amount": 24500, - "discount": "di_1PQa9eIGBnsLynRrwwYr2bGD" - } - ], - "discountable": true, - "discounts": [ - ], - "invoice": "in_1PQa9fIGBnsLynRraYIqTdBs", - "livemode": false, - "metadata": { - }, - "period": { - "end": 1720725291, - "start": 1718133291 - }, - "plan": { - "id": "2023-enterprise-seat-monthly", - "object": "plan", - "active": true, - "aggregate_usage": null, - "amount": 700, - "amount_decimal": "700", - "billing_scheme": "per_unit", - "created": 1695152194, - "currency": "usd", - "interval": "month", - "interval_count": 1, - "livemode": false, - "metadata": { - }, - "meter": null, - "nickname": "Enterprise Organization (Monthly)", - "product": "prod_HgSOgzUlYDFOzf", - "tiers_mode": null, - "transform_usage": null, - "trial_period_days": null, - "usage_type": "licensed", - "name": "Password Manager - Enterprise Plan", - "statement_description": null, - "statement_descriptor": null, - "tiers": null - }, - "price": { - "id": "2023-enterprise-seat-monthly", - "object": "price", - "active": true, - "billing_scheme": "per_unit", - "created": 1695152194, - "currency": "usd", - "custom_unit_amount": null, - "livemode": false, - "lookup_key": null, - "metadata": { - }, - "nickname": "Enterprise Organization (Monthly)", - "product": "prod_HgSOgzUlYDFOzf", - "recurring": { - "aggregate_usage": null, - "interval": "month", - "interval_count": 1, - "meter": null, - "trial_period_days": null, - "usage_type": "licensed" - }, - "tax_behavior": "exclusive", - "tiers_mode": null, - "transform_quantity": null, - "type": "recurring", - "unit_amount": 700, - "unit_amount_decimal": "700" - }, - "proration": false, - "proration_details": { - "credited_items": null - }, - "quantity": 100, - "subscription": null, - "subscription_item": "si_QH8QUjtceXvcis", - "tax_amounts": [ - { - "amount": 3640, - "inclusive": false, - "tax_rate": "txr_1OZyBuIGBnsLynRrX0PJLuMC", - "taxability_reason": "standard_rated", - "taxable_amount": 45500 - } - ], - "tax_rates": [ - ], - "type": "subscription", - "unit_amount_excluding_tax": "700", - "unique_id": "il_1PQa9fIGBnsLynRrVviet37m", - "unique_line_item_id": "sli_11b229IGBnsLynRr837b79d0" - } - ], - "has_more": false, - "total_count": 2, - "url": "/v1/invoices/in_1PQa9fIGBnsLynRraYIqTdBs/lines" - }, - "livemode": false, - "metadata": { - }, - "next_payment_attempt": null, - "number": "525EB050-0001", - "on_behalf_of": null, - "paid": false, - "paid_out_of_band": false, - "payment_intent": "pi_3PQaA7IGBnsLynRr1swr9XJE", - "payment_settings": { - "default_mandate": null, - "payment_method_options": null, - "payment_method_types": null - }, - "period_end": 1718133291, - "period_start": 1718133291, - "post_payment_credit_notes_amount": 0, - "pre_payment_credit_notes_amount": 0, - "quote": null, - "receipt_number": null, - "rendering": null, - "rendering_options": null, - "shipping_cost": null, - "shipping_details": null, - "starting_balance": 0, - "statement_descriptor": null, - "status": "open", - "status_transitions": { - "finalized_at": 1718136893, - "marked_uncollectible_at": null, - "paid_at": null, - "voided_at": null - }, - "subscription": "sub_1PQa9fIGBnsLynRr83lNrFHa", - "subscription_details": { - "metadata": { - "providerId": "655bc5a3-2332-4201-a9a6-b18c013d0572" - } - }, - "subtotal": 120000, - "subtotal_excluding_tax": 120000, - "tax": 6240, - "test_clock": "clock_1PQaA4IGBnsLynRrptkZjgxc", - "total": 84240, - "total_discount_amounts": [ - { - "amount": 42000, - "discount": "di_1PQa9eIGBnsLynRrwwYr2bGD" - } - ], - "total_excluding_tax": 78000, - "total_tax_amounts": [ - { - "amount": 6240, - "inclusive": false, - "tax_rate": "txr_1OZyBuIGBnsLynRrX0PJLuMC", - "taxability_reason": "standard_rated", - "taxable_amount": 78000 - } - ], - "transfer_data": null, - "webhooks_delivered_at": 1718133293, - "application_fee": null, - "billing": "send_invoice", - "closed": false, - "date": 1718133291, - "finalized_at": 1718136893, - "forgiven": false, - "payment": null, - "statement_description": null, - "tax_percent": 8 - } - }, - "livemode": false, - "pending_webhooks": 5, - "request": null, - "type": "invoice.finalized", - "user_id": "acct_19smIXIGBnsLynRr" -} diff --git a/test/Billing.Test/Resources/Events/invoice.upcoming.json b/test/Billing.Test/Resources/Events/invoice.upcoming.json deleted file mode 100644 index 1ecf2c616d..0000000000 --- a/test/Billing.Test/Resources/Events/invoice.upcoming.json +++ /dev/null @@ -1,225 +0,0 @@ -{ - "id": "evt_1Nv0w8IGBnsLynRrZoDVI44u", - "object": "event", - "api_version": "2024-06-20", - "created": 1695833408, - "data": { - "object": { - "object": "invoice", - "account_country": "US", - "account_name": "Bitwarden Inc.", - "account_tax_ids": null, - "amount_due": 0, - "amount_paid": 0, - "amount_remaining": 0, - "amount_shipping": 0, - "application": null, - "application_fee_amount": null, - "attempt_count": 0, - "attempted": false, - "automatic_tax": { - "enabled": true, - "status": "complete" - }, - "billing_reason": "upcoming", - "charge": null, - "collection_method": "charge_automatically", - "created": 1697128681, - "currency": "usd", - "custom_fields": null, - "customer": "cus_M8DV9wiyNa2JxQ", - "customer_address": { - "city": null, - "country": "US", - "line1": "", - "line2": null, - "postal_code": "90019", - "state": null - }, - "customer_email": "vphan@bitwarden.com", - "customer_name": null, - "customer_phone": null, - "customer_shipping": null, - "customer_tax_exempt": "none", - "customer_tax_ids": [ - ], - "default_payment_method": null, - "default_source": null, - "default_tax_rates": [ - ], - "description": null, - "discount": null, - "discounts": [ - ], - "due_date": null, - "effective_at": null, - "ending_balance": -6779, - "footer": null, - "from_invoice": null, - "last_finalization_error": null, - "latest_revision": null, - "lines": { - "object": "list", - "data": [ - { - "id": "il_tmp_12b5e8IGBnsLynRr1996ac3a", - "object": "line_item", - "amount": 2000, - "amount_excluding_tax": 2000, - "currency": "usd", - "description": "5 × 2019 Enterprise Seat (Monthly) (at $4.00 / month)", - "discount_amounts": [ - ], - "discountable": true, - "discounts": [ - ], - "livemode": false, - "metadata": { - }, - "period": { - "end": 1699807081, - "start": 1697128681 - }, - "plan": { - "id": "enterprise-org-seat-monthly", - "object": "plan", - "active": true, - "aggregate_usage": null, - "amount": 400, - "amount_decimal": "400", - "billing_scheme": "per_unit", - "created": 1494268635, - "currency": "usd", - "interval": "month", - "interval_count": 1, - "livemode": false, - "metadata": { - }, - "nickname": "2019 Enterprise Seat (Monthly)", - "product": "prod_BVButYytPSlgs6", - "tiers_mode": null, - "transform_usage": null, - "trial_period_days": null, - "usage_type": "licensed" - }, - "price": { - "id": "enterprise-org-seat-monthly", - "object": "price", - "active": true, - "billing_scheme": "per_unit", - "created": 1494268635, - "currency": "usd", - "custom_unit_amount": null, - "livemode": false, - "lookup_key": null, - "metadata": { - }, - "nickname": "2019 Enterprise Seat (Monthly)", - "product": "prod_BVButYytPSlgs6", - "recurring": { - "aggregate_usage": null, - "interval": "month", - "interval_count": 1, - "trial_period_days": null, - "usage_type": "licensed" - }, - "tax_behavior": "unspecified", - "tiers_mode": null, - "transform_quantity": null, - "type": "recurring", - "unit_amount": 400, - "unit_amount_decimal": "400" - }, - "proration": false, - "proration_details": { - "credited_items": null - }, - "quantity": 5, - "subscription": "sub_1NQxz4IGBnsLynRr1KbitG7v", - "subscription_item": "si_ODOmLnPDHBuMxX", - "tax_amounts": [ - { - "amount": 0, - "inclusive": false, - "tax_rate": "txr_1N6XCyIGBnsLynRr0LHs4AUD", - "taxability_reason": "product_exempt", - "taxable_amount": 0 - } - ], - "tax_rates": [ - ], - "type": "subscription", - "unit_amount_excluding_tax": "400" - } - ], - "has_more": false, - "total_count": 1, - "url": "/v1/invoices/upcoming/lines?customer=cus_M8DV9wiyNa2JxQ&subscription=sub_1NQxz4IGBnsLynRr1KbitG7v" - }, - "livemode": false, - "metadata": { - }, - "next_payment_attempt": 1697132281, - "number": null, - "on_behalf_of": null, - "paid": false, - "paid_out_of_band": false, - "payment_intent": null, - "payment_settings": { - "default_mandate": null, - "payment_method_options": null, - "payment_method_types": null - }, - "period_end": 1697128681, - "period_start": 1694536681, - "post_payment_credit_notes_amount": 0, - "pre_payment_credit_notes_amount": 0, - "quote": null, - "receipt_number": null, - "rendering": null, - "rendering_options": null, - "shipping_cost": null, - "shipping_details": null, - "starting_balance": -8779, - "statement_descriptor": null, - "status": "draft", - "status_transitions": { - "finalized_at": null, - "marked_uncollectible_at": null, - "paid_at": null, - "voided_at": null - }, - "subscription": "sub_1NQxz4IGBnsLynRr1KbitG7v", - "subscription_details": { - "metadata": { - } - }, - "subtotal": 2000, - "subtotal_excluding_tax": 2000, - "tax": 0, - "test_clock": null, - "total": 2000, - "total_discount_amounts": [ - ], - "total_excluding_tax": 2000, - "total_tax_amounts": [ - { - "amount": 0, - "inclusive": false, - "tax_rate": "txr_1N6XCyIGBnsLynRr0LHs4AUD", - "taxability_reason": "product_exempt", - "taxable_amount": 0 - } - ], - "transfer_data": null, - "webhooks_delivered_at": null - } - }, - "livemode": false, - "pending_webhooks": 5, - "request": { - "id": null, - "idempotency_key": null - }, - "type": "invoice.upcoming" -} diff --git a/test/Billing.Test/Resources/Events/payment_method.attached.json b/test/Billing.Test/Resources/Events/payment_method.attached.json deleted file mode 100644 index 2d22a929d4..0000000000 --- a/test/Billing.Test/Resources/Events/payment_method.attached.json +++ /dev/null @@ -1,63 +0,0 @@ -{ - "id": "evt_1NvKzcIGBnsLynRrPJ3hybkd", - "object": "event", - "api_version": "2024-06-20", - "created": 1695910504, - "data": { - "object": { - "id": "pm_1NvKzbIGBnsLynRry6x7Buvc", - "object": "payment_method", - "billing_details": { - "address": { - "city": null, - "country": null, - "line1": null, - "line2": null, - "postal_code": null, - "state": null - }, - "email": null, - "name": null, - "phone": null - }, - "card": { - "brand": "visa", - "checks": { - "address_line1_check": null, - "address_postal_code_check": null, - "cvc_check": "pass" - }, - "country": "US", - "exp_month": 6, - "exp_year": 2033, - "fingerprint": "0VgUBpvqcUUnuSmK", - "funding": "credit", - "generated_from": null, - "last4": "4242", - "networks": { - "available": [ - "visa" - ], - "preferred": null - }, - "three_d_secure_usage": { - "supported": true - }, - "wallet": null - }, - "created": 1695910503, - "customer": "cus_OimYrxnMTMMK1E", - "livemode": false, - "metadata": { - }, - "type": "card" - } - }, - "livemode": false, - "pending_webhooks": 7, - "request": { - "id": "req_2WslNSBD9wAV5v", - "idempotency_key": "db1a648a-3445-47b3-a403-9f3d1303a880" - }, - "type": "payment_method.attached" -} diff --git a/test/Billing.Test/Services/ProviderEventServiceTests.cs b/test/Billing.Test/Services/ProviderEventServiceTests.cs index 7d95157bd2..34c69b95c2 100644 --- a/test/Billing.Test/Services/ProviderEventServiceTests.cs +++ b/test/Billing.Test/Services/ProviderEventServiceTests.cs @@ -1,6 +1,5 @@ using Bit.Billing.Services; using Bit.Billing.Services.Implementations; -using Bit.Billing.Test.Utilities; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Models.Data.Provider; using Bit.Core.AdminConsole.Repositories; @@ -10,7 +9,7 @@ using Bit.Core.Billing.Providers.Entities; using Bit.Core.Billing.Providers.Repositories; using Bit.Core.Enums; using Bit.Core.Repositories; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using NSubstitute; using Stripe; using Xunit; @@ -59,29 +58,69 @@ public class ProviderEventServiceTests public async Task TryRecordInvoiceLineItems_EventTypeNotInvoiceCreatedOrInvoiceFinalized_NoOp() { // Arrange - var stripeEvent = await StripeTestEvents.GetAsync(StripeEventType.PaymentMethodAttached); + var stripeEvent = new Event { Type = "payment_method.attached" }; // Act await _providerEventService.TryRecordInvoiceLineItems(stripeEvent); // Assert - await _stripeEventService.DidNotReceiveWithAnyArgs().GetInvoice(Arg.Any()); + await _stripeEventService.DidNotReceiveWithAnyArgs().GetInvoice(Arg.Any(), Arg.Any(), Arg.Any?>()); + } + + [Fact] + public async Task TryRecordInvoiceLineItems_InvoiceParentTypeNotSubscriptionDetails_NoOp() + { + // Arrange + var stripeEvent = new Event + { + Type = "invoice.created" + }; + + var invoice = new Invoice + { + Parent = new InvoiceParent + { + Type = "credit_note", + SubscriptionDetails = new InvoiceParentSubscriptionDetails + { + SubscriptionId = "sub_1" + } + } + }; + + _stripeEventService.GetInvoice(stripeEvent, true, Arg.Any?>()).Returns(invoice); + + // Act + await _providerEventService.TryRecordInvoiceLineItems(stripeEvent); + + // Assert + await _stripeFacade.DidNotReceiveWithAnyArgs().GetSubscription(Arg.Any()); } [Fact] public async Task TryRecordInvoiceLineItems_EventNotProviderRelated_NoOp() { // Arrange - var stripeEvent = await StripeTestEvents.GetAsync(StripeEventType.InvoiceCreated); + var stripeEvent = new Event + { + Type = "invoice.created" + }; const string subscriptionId = "sub_1"; var invoice = new Invoice { - SubscriptionId = subscriptionId + Parent = new InvoiceParent + { + Type = "subscription_details", + SubscriptionDetails = new InvoiceParentSubscriptionDetails + { + SubscriptionId = subscriptionId + } + } }; - _stripeEventService.GetInvoice(stripeEvent).Returns(invoice); + _stripeEventService.GetInvoice(stripeEvent, true, Arg.Any?>()).Returns(invoice); var subscription = new Subscription { @@ -101,7 +140,10 @@ public class ProviderEventServiceTests public async Task TryRecordInvoiceLineItems_InvoiceCreated_Succeeds() { // Arrange - var stripeEvent = await StripeTestEvents.GetAsync(StripeEventType.InvoiceCreated); + var stripeEvent = new Event + { + Type = "invoice.created" + }; const string subscriptionId = "sub_1"; var providerId = Guid.NewGuid(); @@ -110,17 +152,26 @@ public class ProviderEventServiceTests { Id = "invoice_1", Number = "A", - SubscriptionId = subscriptionId, - Discount = new Discount + Parent = new InvoiceParent { - Coupon = new Coupon + Type = "subscription_details", + SubscriptionDetails = new InvoiceParentSubscriptionDetails { - PercentOff = 35 + SubscriptionId = subscriptionId } - } + }, + Discounts = [ + new Discount + { + Coupon = new Coupon + { + PercentOff = 35 + } + } + ] }; - _stripeEventService.GetInvoice(stripeEvent).Returns(invoice); + _stripeEventService.GetInvoice(stripeEvent, true, Arg.Any?>()).Returns(invoice); var subscription = new Subscription { @@ -186,7 +237,7 @@ public class ProviderEventServiceTests foreach (var providerPlan in providerPlans) { - _pricingClient.GetPlanOrThrow(providerPlan.PlanType).Returns(StaticStore.GetPlan(providerPlan.PlanType)); + _pricingClient.GetPlanOrThrow(providerPlan.PlanType).Returns(MockPlans.Get(providerPlan.PlanType)); } _providerPlanRepository.GetByProviderId(providerId).Returns(providerPlans); @@ -195,8 +246,8 @@ public class ProviderEventServiceTests await _providerEventService.TryRecordInvoiceLineItems(stripeEvent); // Assert - var teamsPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); - var enterprisePlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); + var teamsPlan = MockPlans.Get(PlanType.TeamsMonthly); + var enterprisePlan = MockPlans.Get(PlanType.EnterpriseMonthly); await _providerInvoiceItemRepository.Received(1).CreateAsync(Arg.Is( options => @@ -249,7 +300,10 @@ public class ProviderEventServiceTests public async Task TryRecordInvoiceLineItems_InvoiceFinalized_Succeeds() { // Arrange - var stripeEvent = await StripeTestEvents.GetAsync(StripeEventType.InvoiceFinalized); + var stripeEvent = new Event + { + Type = "invoice.finalized" + }; const string subscriptionId = "sub_1"; var providerId = Guid.NewGuid(); @@ -258,10 +312,17 @@ public class ProviderEventServiceTests { Id = "invoice_1", Number = "A", - SubscriptionId = subscriptionId + Parent = new InvoiceParent + { + Type = "subscription_details", + SubscriptionDetails = new InvoiceParentSubscriptionDetails + { + SubscriptionId = subscriptionId + } + }, }; - _stripeEventService.GetInvoice(stripeEvent).Returns(invoice); + _stripeEventService.GetInvoice(stripeEvent, true, Arg.Any?>()).Returns(invoice); var subscription = new Subscription { diff --git a/test/Billing.Test/Services/SetupIntentSucceededHandlerTests.cs b/test/Billing.Test/Services/SetupIntentSucceededHandlerTests.cs index e9f0d9d0ed..a7aefe3163 100644 --- a/test/Billing.Test/Services/SetupIntentSucceededHandlerTests.cs +++ b/test/Billing.Test/Services/SetupIntentSucceededHandlerTests.cs @@ -4,8 +4,8 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Caches; +using Bit.Core.Billing.Services; using Bit.Core.Repositories; -using Bit.Core.Services; using NSubstitute; using Stripe; using Xunit; @@ -61,7 +61,7 @@ public class SetupIntentSucceededHandlerTests // Assert await _setupIntentCache.DidNotReceiveWithAnyArgs().GetSubscriberIdForSetupIntent(Arg.Any()); - await _stripeAdapter.DidNotReceiveWithAnyArgs().PaymentMethodAttachAsync( + await _stripeAdapter.DidNotReceiveWithAnyArgs().AttachPaymentMethodAsync( Arg.Any(), Arg.Any()); await _pushNotificationAdapter.DidNotReceiveWithAnyArgs().NotifyBankAccountVerifiedAsync(Arg.Any()); await _pushNotificationAdapter.DidNotReceiveWithAnyArgs().NotifyBankAccountVerifiedAsync(Arg.Any()); @@ -86,7 +86,7 @@ public class SetupIntentSucceededHandlerTests await _handler.HandleAsync(_mockEvent); // Assert - await _stripeAdapter.DidNotReceiveWithAnyArgs().PaymentMethodAttachAsync( + await _stripeAdapter.DidNotReceiveWithAnyArgs().AttachPaymentMethodAsync( Arg.Any(), Arg.Any()); await _pushNotificationAdapter.DidNotReceiveWithAnyArgs().NotifyBankAccountVerifiedAsync(Arg.Any()); await _pushNotificationAdapter.DidNotReceiveWithAnyArgs().NotifyBankAccountVerifiedAsync(Arg.Any()); @@ -116,7 +116,7 @@ public class SetupIntentSucceededHandlerTests await _handler.HandleAsync(_mockEvent); // Assert - await _stripeAdapter.Received(1).PaymentMethodAttachAsync( + await _stripeAdapter.Received(1).AttachPaymentMethodAsync( "pm_test", Arg.Is(o => o.Customer == organization.GatewayCustomerId)); @@ -151,7 +151,7 @@ public class SetupIntentSucceededHandlerTests await _handler.HandleAsync(_mockEvent); // Assert - await _stripeAdapter.Received(1).PaymentMethodAttachAsync( + await _stripeAdapter.Received(1).AttachPaymentMethodAsync( "pm_test", Arg.Is(o => o.Customer == provider.GatewayCustomerId)); @@ -183,7 +183,7 @@ public class SetupIntentSucceededHandlerTests await _handler.HandleAsync(_mockEvent); // Assert - await _stripeAdapter.DidNotReceiveWithAnyArgs().PaymentMethodAttachAsync( + await _stripeAdapter.DidNotReceiveWithAnyArgs().AttachPaymentMethodAsync( Arg.Any(), Arg.Any()); await _pushNotificationAdapter.DidNotReceiveWithAnyArgs().NotifyBankAccountVerifiedAsync(Arg.Any()); await _pushNotificationAdapter.DidNotReceiveWithAnyArgs().NotifyBankAccountVerifiedAsync(Arg.Any()); @@ -216,7 +216,7 @@ public class SetupIntentSucceededHandlerTests await _handler.HandleAsync(_mockEvent); // Assert - await _stripeAdapter.DidNotReceiveWithAnyArgs().PaymentMethodAttachAsync( + await _stripeAdapter.DidNotReceiveWithAnyArgs().AttachPaymentMethodAsync( Arg.Any(), Arg.Any()); await _pushNotificationAdapter.DidNotReceiveWithAnyArgs().NotifyBankAccountVerifiedAsync(Arg.Any()); await _pushNotificationAdapter.DidNotReceiveWithAnyArgs().NotifyBankAccountVerifiedAsync(Arg.Any()); diff --git a/test/Billing.Test/Services/SubscriptionDeletedHandlerTests.cs b/test/Billing.Test/Services/SubscriptionDeletedHandlerTests.cs index 2797b2e589..de2d3ec0ed 100644 --- a/test/Billing.Test/Services/SubscriptionDeletedHandlerTests.cs +++ b/test/Billing.Test/Services/SubscriptionDeletedHandlerTests.cs @@ -1,9 +1,15 @@ using Bit.Billing.Constants; +using Bit.Billing.Jobs; using Bit.Billing.Services; using Bit.Billing.Services.Implementations; +using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Services; +using Bit.Core.Billing.Extensions; using Bit.Core.Services; using NSubstitute; +using Quartz; using Stripe; using Xunit; @@ -15,6 +21,10 @@ public class SubscriptionDeletedHandlerTests private readonly IUserService _userService; private readonly IStripeEventUtilityService _stripeEventUtilityService; private readonly IOrganizationDisableCommand _organizationDisableCommand; + private readonly IProviderRepository _providerRepository; + private readonly IProviderService _providerService; + private readonly ISchedulerFactory _schedulerFactory; + private readonly IScheduler _scheduler; private readonly SubscriptionDeletedHandler _sut; public SubscriptionDeletedHandlerTests() @@ -23,11 +33,19 @@ public class SubscriptionDeletedHandlerTests _userService = Substitute.For(); _stripeEventUtilityService = Substitute.For(); _organizationDisableCommand = Substitute.For(); + _providerRepository = Substitute.For(); + _providerService = Substitute.For(); + _schedulerFactory = Substitute.For(); + _scheduler = Substitute.For(); + _schedulerFactory.GetScheduler().Returns(_scheduler); _sut = new SubscriptionDeletedHandler( _stripeEventService, _userService, _stripeEventUtilityService, - _organizationDisableCommand); + _organizationDisableCommand, + _providerRepository, + _providerService, + _schedulerFactory); } [Fact] @@ -38,7 +56,13 @@ public class SubscriptionDeletedHandlerTests var subscription = new Subscription { Status = "active", - CurrentPeriodEnd = DateTime.UtcNow.AddDays(30), + Items = new StripeList + { + Data = + [ + new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) } + ] + }, Metadata = new Dictionary() }; @@ -52,6 +76,7 @@ public class SubscriptionDeletedHandlerTests // Assert await _organizationDisableCommand.DidNotReceiveWithAnyArgs().DisableAsync(default, default); await _userService.DidNotReceiveWithAnyArgs().DisablePremiumAsync(default, default); + await _providerService.DidNotReceiveWithAnyArgs().UpdateAsync(default); } [Fact] @@ -63,11 +88,14 @@ public class SubscriptionDeletedHandlerTests var subscription = new Subscription { Status = StripeSubscriptionStatus.Canceled, - CurrentPeriodEnd = DateTime.UtcNow.AddDays(30), - Metadata = new Dictionary + Items = new StripeList { - { "organizationId", organizationId.ToString() } - } + Data = + [ + new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) } + ] + }, + Metadata = new Dictionary { { "organizationId", organizationId.ToString() } } }; _stripeEventService.GetSubscription(stripeEvent, true).Returns(subscription); @@ -79,7 +107,7 @@ public class SubscriptionDeletedHandlerTests // Assert await _organizationDisableCommand.Received(1) - .DisableAsync(organizationId, subscription.CurrentPeriodEnd); + .DisableAsync(organizationId, subscription.GetCurrentPeriodEnd()); } [Fact] @@ -91,11 +119,14 @@ public class SubscriptionDeletedHandlerTests var subscription = new Subscription { Status = StripeSubscriptionStatus.Canceled, - CurrentPeriodEnd = DateTime.UtcNow.AddDays(30), - Metadata = new Dictionary + Items = new StripeList { - { "userId", userId.ToString() } - } + Data = + [ + new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) } + ] + }, + Metadata = new Dictionary { { "userId", userId.ToString() } } }; _stripeEventService.GetSubscription(stripeEvent, true).Returns(subscription); @@ -107,7 +138,7 @@ public class SubscriptionDeletedHandlerTests // Assert await _userService.Received(1) - .DisablePremiumAsync(userId, subscription.CurrentPeriodEnd); + .DisablePremiumAsync(userId, subscription.GetCurrentPeriodEnd()); } [Fact] @@ -119,11 +150,14 @@ public class SubscriptionDeletedHandlerTests var subscription = new Subscription { Status = StripeSubscriptionStatus.Canceled, - CurrentPeriodEnd = DateTime.UtcNow.AddDays(30), - Metadata = new Dictionary + Items = new StripeList { - { "organizationId", organizationId.ToString() } + Data = + [ + new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) } + ] }, + Metadata = new Dictionary { { "organizationId", organizationId.ToString() } }, CancellationDetails = new SubscriptionCancellationDetails { Comment = "Cancelled as part of provider migration to Consolidated Billing" @@ -151,11 +185,14 @@ public class SubscriptionDeletedHandlerTests var subscription = new Subscription { Status = StripeSubscriptionStatus.Canceled, - CurrentPeriodEnd = DateTime.UtcNow.AddDays(30), - Metadata = new Dictionary + Items = new StripeList { - { "organizationId", organizationId.ToString() } + Data = + [ + new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) } + ] }, + Metadata = new Dictionary { { "organizationId", organizationId.ToString() } }, CancellationDetails = new SubscriptionCancellationDetails { Comment = "Organization was added to Provider" @@ -173,4 +210,120 @@ public class SubscriptionDeletedHandlerTests await _organizationDisableCommand.DidNotReceiveWithAnyArgs() .DisableAsync(default, default); } + + [Fact] + public async Task HandleAsync_ProviderSubscriptionCanceled_DisablesProviderAndQueuesJob() + { + // Arrange + var stripeEvent = new Event(); + var providerId = Guid.NewGuid(); + var provider = new Provider + { + Id = providerId, + Enabled = true + }; + var subscription = new Subscription + { + Status = StripeSubscriptionStatus.Canceled, + Items = new StripeList + { + Data = + [ + new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) } + ] + }, + Metadata = new Dictionary { { "providerId", providerId.ToString() } } + }; + + _stripeEventService.GetSubscription(stripeEvent, true).Returns(subscription); + _stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata) + .Returns(Tuple.Create(null, null, providerId)); + _providerRepository.GetByIdAsync(providerId).Returns(provider); + + // Act + await _sut.HandleAsync(stripeEvent); + + // Assert + Assert.False(provider.Enabled); + await _providerService.Received(1).UpdateAsync(provider); + await _scheduler.Received(1).ScheduleJob( + Arg.Is(j => j.JobType == typeof(ProviderOrganizationDisableJob)), + Arg.Any()); + } + + [Fact] + public async Task HandleAsync_ProviderSubscriptionCanceled_ProviderNotFound_DoesNotThrow() + { + // Arrange + var stripeEvent = new Event(); + var providerId = Guid.NewGuid(); + var subscription = new Subscription + { + Status = StripeSubscriptionStatus.Canceled, + Items = new StripeList + { + Data = + [ + new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) } + ] + }, + Metadata = new Dictionary { { "providerId", providerId.ToString() } } + }; + + _stripeEventService.GetSubscription(stripeEvent, true).Returns(subscription); + _stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata) + .Returns(Tuple.Create(null, null, providerId)); + _providerRepository.GetByIdAsync(providerId).Returns((Provider)null); + + // Act & Assert - Should not throw + await _sut.HandleAsync(stripeEvent); + + // Assert + await _providerService.DidNotReceiveWithAnyArgs().UpdateAsync(default); + await _scheduler.DidNotReceiveWithAnyArgs().ScheduleJob(default, default); + } + + [Fact] + public async Task HandleAsync_ProviderSubscriptionCanceled_QueuesJobWithCorrectParameters() + { + // Arrange + var stripeEvent = new Event(); + var providerId = Guid.NewGuid(); + var expirationDate = DateTime.UtcNow.AddDays(30); + var provider = new Provider + { + Id = providerId, + Enabled = true + }; + var subscription = new Subscription + { + Status = StripeSubscriptionStatus.Canceled, + Items = new StripeList + { + Data = + [ + new SubscriptionItem { CurrentPeriodEnd = expirationDate } + ] + }, + Metadata = new Dictionary { { "providerId", providerId.ToString() } } + }; + + _stripeEventService.GetSubscription(stripeEvent, true).Returns(subscription); + _stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata) + .Returns(Tuple.Create(null, null, providerId)); + _providerRepository.GetByIdAsync(providerId).Returns(provider); + + // Act + await _sut.HandleAsync(stripeEvent); + + // Assert + Assert.False(provider.Enabled); + await _providerService.Received(1).UpdateAsync(provider); + await _scheduler.Received(1).ScheduleJob( + Arg.Is(j => + j.JobType == typeof(ProviderOrganizationDisableJob) && + j.JobDataMap.GetString("providerId") == providerId.ToString() && + j.JobDataMap.GetString("expirationDate") == expirationDate.ToString("O")), + Arg.Is(t => t.Key.Name == $"disable-trigger-{providerId}")); + } } diff --git a/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs b/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs index 6a7cd7704b..182f09e163 100644 --- a/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs +++ b/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs @@ -1,18 +1,17 @@ using Bit.Billing.Constants; using Bit.Billing.Services; using Bit.Billing.Services.Implementations; -using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models.StaticStore.Plans; using Bit.Core.Billing.Pricing; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; +using Bit.Core.Test.Billing.Mocks.Plans; using Microsoft.Extensions.Logging; using Newtonsoft.Json.Linq; using NSubstitute; @@ -21,6 +20,8 @@ using Quartz; using Stripe; using Xunit; using Event = Stripe.Event; +using PremiumPlan = Bit.Core.Billing.Pricing.Premium.Plan; +using PremiumPurchasable = Bit.Core.Billing.Pricing.Premium.Purchasable; namespace Bit.Billing.Test.Services; @@ -96,7 +97,13 @@ public class SubscriptionUpdatedHandlerTests { Id = subscriptionId, Status = StripeSubscriptionStatus.Unpaid, - CurrentPeriodEnd = currentPeriodEnd, + Items = new StripeList + { + Data = + [ + new SubscriptionItem { CurrentPeriodEnd = currentPeriodEnd } + ] + }, Metadata = new Dictionary { { "organizationId", organizationId.ToString() } }, LatestInvoice = new Invoice { BillingReason = "subscription_cycle" } }; @@ -120,73 +127,6 @@ public class SubscriptionUpdatedHandlerTests Arg.Is(t => t.Key.Name == $"cancel-trigger-{subscriptionId}")); } - [Fact] - public async Task - HandleAsync_UnpaidProviderSubscription_WithManualSuspensionViaMetadata_DisablesProviderAndSchedulesCancellation() - { - // Arrange - var providerId = Guid.NewGuid(); - var subscriptionId = "sub_test123"; - - var previousSubscription = new Subscription - { - Id = subscriptionId, - Status = StripeSubscriptionStatus.Active, - Metadata = new Dictionary - { - ["suspend_provider"] = null // This is the key part - metadata exists, but value is null - } - }; - - var currentSubscription = new Subscription - { - Id = subscriptionId, - Status = StripeSubscriptionStatus.Unpaid, - CurrentPeriodEnd = DateTime.UtcNow.AddDays(30), - Metadata = new Dictionary - { - ["providerId"] = providerId.ToString(), - ["suspend_provider"] = "true" // Now has a value, indicating manual suspension - }, - TestClock = null - }; - - var parsedEvent = new Event - { - Id = "evt_test123", - Type = HandledStripeWebhook.SubscriptionUpdated, - Data = new EventData - { - Object = currentSubscription, - PreviousAttributes = JObject.FromObject(previousSubscription) - } - }; - - var provider = new Provider { Id = providerId, Enabled = true }; - - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover).Returns(true); - _stripeEventService.GetSubscription(parsedEvent, true, Arg.Any>()).Returns(currentSubscription); - _stripeEventUtilityService.GetIdsFromMetadata(currentSubscription.Metadata) - .Returns(Tuple.Create(null, null, providerId)); - _providerRepository.GetByIdAsync(providerId).Returns(provider); - - // Act - await _sut.HandleAsync(parsedEvent); - - // Assert - Assert.False(provider.Enabled); - await _providerService.Received(1).UpdateAsync(provider); - - // Verify that UpdateSubscription was called with both CancelAt and the new metadata - await _stripeFacade.Received(1).UpdateSubscription( - subscriptionId, - Arg.Is(options => - options.CancelAt.HasValue && - options.CancelAt.Value <= DateTime.UtcNow.AddDays(7).AddMinutes(1) && - options.Metadata != null && - options.Metadata.ContainsKey("suspended_provider_via_webhook_at"))); - } - [Fact] public async Task HandleAsync_UnpaidProviderSubscription_WithValidTransition_DisablesProviderAndSchedulesCancellation() @@ -206,7 +146,13 @@ public class SubscriptionUpdatedHandlerTests { Id = subscriptionId, Status = StripeSubscriptionStatus.Unpaid, - CurrentPeriodEnd = DateTime.UtcNow.AddDays(30), + Items = new StripeList + { + Data = + [ + new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) } + ] + }, Metadata = new Dictionary { ["providerId"] = providerId.ToString() }, LatestInvoice = new Invoice { BillingReason = "subscription_cycle" }, TestClock = null @@ -225,7 +171,6 @@ public class SubscriptionUpdatedHandlerTests var provider = new Provider { Id = providerId, Enabled = true }; - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover).Returns(true); _stripeEventService.GetSubscription(parsedEvent, true, Arg.Any>()).Returns(currentSubscription); _stripeEventUtilityService.GetIdsFromMetadata(currentSubscription.Metadata) .Returns(Tuple.Create(null, null, providerId)); @@ -238,13 +183,12 @@ public class SubscriptionUpdatedHandlerTests Assert.False(provider.Enabled); await _providerService.Received(1).UpdateAsync(provider); - // Verify that UpdateSubscription was called with CancelAt but WITHOUT suspension metadata + // Verify that UpdateSubscription was called with CancelAt await _stripeFacade.Received(1).UpdateSubscription( subscriptionId, Arg.Is(options => options.CancelAt.HasValue && - options.CancelAt.Value <= DateTime.UtcNow.AddDays(7).AddMinutes(1) && - (options.Metadata == null || !options.Metadata.ContainsKey("suspended_provider_via_webhook_at")))); + options.CancelAt.Value <= DateTime.UtcNow.AddDays(7).AddMinutes(1))); } [Fact] @@ -257,6 +201,13 @@ public class SubscriptionUpdatedHandlerTests var subscription = new Subscription { Id = subscriptionId, + Items = new StripeList + { + Data = + [ + new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) } + ] + }, Status = StripeSubscriptionStatus.Unpaid, Metadata = new Dictionary { { "providerId", providerId.ToString() } }, LatestInvoice = new Invoice { BillingReason = "subscription_cycle" } @@ -281,9 +232,6 @@ public class SubscriptionUpdatedHandlerTests _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) .Returns(Tuple.Create(null, null, providerId)); - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(true); - _providerRepository.GetByIdAsync(providerId) .Returns(provider); @@ -306,6 +254,13 @@ public class SubscriptionUpdatedHandlerTests var subscription = new Subscription { Id = subscriptionId, + Items = new StripeList + { + Data = + [ + new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) } + ] + }, Status = StripeSubscriptionStatus.Unpaid, Metadata = new Dictionary { { "providerId", providerId.ToString() } }, LatestInvoice = new Invoice { BillingReason = "subscription_cycle" } @@ -321,9 +276,6 @@ public class SubscriptionUpdatedHandlerTests _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) .Returns(Tuple.Create(null, null, providerId)); - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(true); - _providerRepository.GetByIdAsync(providerId) .Returns(provider); @@ -348,7 +300,13 @@ public class SubscriptionUpdatedHandlerTests { Id = subscriptionId, Status = StripeSubscriptionStatus.IncompleteExpired, - CurrentPeriodEnd = currentPeriodEnd, + Items = new StripeList + { + Data = + [ + new SubscriptionItem { CurrentPeriodEnd = currentPeriodEnd } + ] + }, Metadata = new Dictionary { { "providerId", providerId.ToString() } }, LatestInvoice = new Invoice { BillingReason = "renewal" } }; @@ -363,9 +321,6 @@ public class SubscriptionUpdatedHandlerTests _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) .Returns(Tuple.Create(null, null, providerId)); - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(true); - _providerRepository.GetByIdAsync(providerId) .Returns(provider); @@ -378,42 +333,6 @@ public class SubscriptionUpdatedHandlerTests await _stripeFacade.DidNotReceive().UpdateSubscription(Arg.Any(), Arg.Any()); } - [Fact] - public async Task HandleAsync_UnpaidProviderSubscription_WhenFeatureFlagDisabled_DoesNothing() - { - // Arrange - var providerId = Guid.NewGuid(); - var subscriptionId = "sub_123"; - var currentPeriodEnd = DateTime.UtcNow.AddDays(30); - - var subscription = new Subscription - { - Id = subscriptionId, - Status = StripeSubscriptionStatus.Unpaid, - CurrentPeriodEnd = currentPeriodEnd, - Metadata = new Dictionary { { "providerId", providerId.ToString() } }, - LatestInvoice = new Invoice { BillingReason = "subscription_cycle" } - }; - - var parsedEvent = new Event { Data = new EventData() }; - - _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) - .Returns(subscription); - - _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) - .Returns(Tuple.Create(null, null, providerId)); - - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(false); - - // Act - await _sut.HandleAsync(parsedEvent); - - // Assert - await _providerRepository.DidNotReceive().GetByIdAsync(Arg.Any()); - await _providerService.DidNotReceive().UpdateAsync(Arg.Any()); - } - [Fact] public async Task HandleAsync_UnpaidProviderSubscription_WhenProviderNotFound_DoesNothing() { @@ -426,7 +345,13 @@ public class SubscriptionUpdatedHandlerTests { Id = subscriptionId, Status = StripeSubscriptionStatus.Unpaid, - CurrentPeriodEnd = currentPeriodEnd, + Items = new StripeList + { + Data = + [ + new SubscriptionItem { CurrentPeriodEnd = currentPeriodEnd } + ] + }, Metadata = new Dictionary { { "providerId", providerId.ToString() } }, LatestInvoice = new Invoice { BillingReason = "subscription_cycle" } }; @@ -439,9 +364,6 @@ public class SubscriptionUpdatedHandlerTests _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) .Returns(Tuple.Create(null, null, providerId)); - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(true); - _providerRepository.GetByIdAsync(providerId) .Returns((Provider)null); @@ -464,19 +386,91 @@ public class SubscriptionUpdatedHandlerTests { Id = subscriptionId, Status = StripeSubscriptionStatus.Unpaid, - CurrentPeriodEnd = currentPeriodEnd, Metadata = new Dictionary { { "userId", userId.ToString() } }, Items = new StripeList { Data = [ - new SubscriptionItem { Price = new Price { Id = IStripeEventUtilityService.PremiumPlanId } } + new SubscriptionItem + { + CurrentPeriodEnd = currentPeriodEnd, + Price = new Price { Id = IStripeEventUtilityService.PremiumPlanId } + } ] } }; var parsedEvent = new Event { Data = new EventData() }; + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new PremiumPurchasable { Price = 10M, StripePriceId = IStripeEventUtilityService.PremiumPlanId }, + Storage = new PremiumPurchasable { Price = 4M, StripePriceId = "storage-plan-personal" } + }; + _pricingClient.ListPremiumPlans().Returns(new List { premiumPlan }); + + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) + .Returns(subscription); + + _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) + .Returns(Tuple.Create(null, userId, null)); + + _stripeFacade.ListInvoices(Arg.Any()) + .Returns(new StripeList { Data = new List() }); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _userService.Received(1) + .DisablePremiumAsync(userId, currentPeriodEnd); + await _stripeFacade.Received(1) + .CancelSubscription(subscriptionId, Arg.Any()); + await _stripeFacade.Received(1) + .ListInvoices(Arg.Is(o => + o.Status == StripeInvoiceStatus.Open && o.Subscription == subscriptionId)); + } + + [Fact] + public async Task HandleAsync_IncompleteExpiredUserSubscription_DisablesPremiumAndCancelsSubscription() + { + // Arrange + var userId = Guid.NewGuid(); + var subscriptionId = "sub_123"; + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.IncompleteExpired, + Metadata = new Dictionary { { "userId", userId.ToString() } }, + Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = currentPeriodEnd, + Price = new Price { Id = IStripeEventUtilityService.PremiumPlanId } + } + ] + } + }; + + var parsedEvent = new Event { Data = new EventData() }; + + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new PremiumPurchasable { Price = 10M, StripePriceId = IStripeEventUtilityService.PremiumPlanId }, + Storage = new PremiumPurchasable { Price = 4M, StripePriceId = "storage-plan-personal" } + }; + _pricingClient.ListPremiumPlans().Returns(new List { premiumPlan }); + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(subscription); @@ -508,7 +502,13 @@ public class SubscriptionUpdatedHandlerTests var subscription = new Subscription { Status = StripeSubscriptionStatus.Active, - CurrentPeriodEnd = currentPeriodEnd, + Items = new StripeList + { + Data = + [ + new SubscriptionItem { CurrentPeriodEnd = currentPeriodEnd } + ] + }, Metadata = new Dictionary { { "organizationId", organizationId.ToString() } } }; @@ -552,7 +552,13 @@ public class SubscriptionUpdatedHandlerTests var subscription = new Subscription { Status = StripeSubscriptionStatus.Active, - CurrentPeriodEnd = currentPeriodEnd, + Items = new StripeList + { + Data = + [ + new SubscriptionItem { CurrentPeriodEnd = currentPeriodEnd } + ] + }, Metadata = new Dictionary { { "userId", userId.ToString() } } }; @@ -583,7 +589,13 @@ public class SubscriptionUpdatedHandlerTests var subscription = new Subscription { Status = StripeSubscriptionStatus.Active, - CurrentPeriodEnd = currentPeriodEnd, + Items = new StripeList + { + Data = + [ + new SubscriptionItem { CurrentPeriodEnd = currentPeriodEnd } + ] + }, Metadata = new Dictionary { { "organizationId", organizationId.ToString() } } }; @@ -616,18 +628,24 @@ public class SubscriptionUpdatedHandlerTests { Id = "sub_123", Status = StripeSubscriptionStatus.Active, - CurrentPeriodEnd = DateTime.UtcNow.AddDays(10), CustomerId = "cus_123", Items = new StripeList { - Data = [new SubscriptionItem { Plan = new Plan { Id = "2023-enterprise-org-seat-annually" } }] + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = DateTime.UtcNow.AddDays(10), + Plan = new Stripe.Plan { Id = "2023-enterprise-org-seat-annually" } + } + ] }, Customer = new Customer { Balance = 0, Discount = new Discount { Coupon = new Coupon { Id = "sm-standalone" } } }, - Discount = new Discount { Coupon = new Coupon { Id = "sm-standalone" } }, + Discounts = [new Discount { Coupon = new Coupon { Id = "sm-standalone" } }], Metadata = new Dictionary { { "organizationId", organizationId.ToString() } } }; @@ -652,7 +670,7 @@ public class SubscriptionUpdatedHandlerTests { Data = [ - new SubscriptionItem { Plan = new Plan { Id = "secrets-manager-enterprise-seat-annually" } } + new SubscriptionItem { Plan = new Stripe.Plan { Id = "secrets-manager-enterprise-seat-annually" } } ] } }) @@ -700,8 +718,6 @@ public class SubscriptionUpdatedHandlerTests _stripeFacade .UpdateSubscription(Arg.Any(), Arg.Any()) .Returns(newSubscription); - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(true); // Act await _sut.HandleAsync(parsedEvent); @@ -723,12 +739,8 @@ public class SubscriptionUpdatedHandlerTests .Received(1) .UpdateSubscription(newSubscription.Id, Arg.Is(options => options.CancelAtPeriodEnd == false)); - _featureService - .Received(1) - .IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover); } - [Fact] public async Task HandleAsync_ActiveProviderSubscriptionEvent_AndPreviousSubscriptionStatusWasCanceled_EnableProvider() @@ -747,8 +759,6 @@ public class SubscriptionUpdatedHandlerTests _providerRepository .GetByIdAsync(Arg.Any()) .Returns(provider); - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(true); // Act await _sut.HandleAsync(parsedEvent); @@ -767,9 +777,6 @@ public class SubscriptionUpdatedHandlerTests await _stripeFacade .DidNotReceiveWithAnyArgs() .UpdateSubscription(Arg.Any()); - _featureService - .Received(1) - .IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover); } [Fact] @@ -790,8 +797,6 @@ public class SubscriptionUpdatedHandlerTests _providerRepository .GetByIdAsync(Arg.Any()) .Returns(provider); - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(true); // Act await _sut.HandleAsync(parsedEvent); @@ -810,9 +815,6 @@ public class SubscriptionUpdatedHandlerTests await _stripeFacade .DidNotReceiveWithAnyArgs() .UpdateSubscription(Arg.Any()); - _featureService - .Received(1) - .IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover); } [Fact] @@ -833,8 +835,6 @@ public class SubscriptionUpdatedHandlerTests _providerRepository .GetByIdAsync(Arg.Any()) .Returns(provider); - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(true); // Act await _sut.HandleAsync(parsedEvent); @@ -853,9 +853,6 @@ public class SubscriptionUpdatedHandlerTests await _stripeFacade .DidNotReceiveWithAnyArgs() .UpdateSubscription(Arg.Any()); - _featureService - .Received(1) - .IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover); } [Fact] @@ -877,8 +874,6 @@ public class SubscriptionUpdatedHandlerTests _providerRepository .GetByIdAsync(Arg.Any()) .Returns(provider); - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(true); // Act await _sut.HandleAsync(parsedEvent); @@ -899,9 +894,6 @@ public class SubscriptionUpdatedHandlerTests await _stripeFacade .DidNotReceiveWithAnyArgs() .UpdateSubscription(Arg.Any()); - _featureService - .Received(1) - .IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover); } [Fact] @@ -921,8 +913,6 @@ public class SubscriptionUpdatedHandlerTests _providerRepository .GetByIdAsync(Arg.Any()) .ReturnsNull(); - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(true); // Act await _sut.HandleAsync(parsedEvent); @@ -943,9 +933,6 @@ public class SubscriptionUpdatedHandlerTests await _stripeFacade .DidNotReceive() .UpdateSubscription(Arg.Any()); - _featureService - .Received(1) - .IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover); } [Fact] @@ -964,8 +951,6 @@ public class SubscriptionUpdatedHandlerTests _providerRepository .GetByIdAsync(Arg.Any()) .Returns(provider); - _featureService.IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover) - .Returns(true); // Act await _sut.HandleAsync(parsedEvent); @@ -986,9 +971,6 @@ public class SubscriptionUpdatedHandlerTests await _stripeFacade .DidNotReceive() .UpdateSubscription(Arg.Any()); - _featureService - .Received(1) - .IsEnabled(FeatureFlagKeys.PM21821_ProviderPortalTakeover); } private static (Guid providerId, Subscription newSubscription, Provider provider, Event parsedEvent) @@ -998,6 +980,13 @@ public class SubscriptionUpdatedHandlerTests var newSubscription = new Subscription { Id = previousSubscription?.Id ?? "sub_123", + Items = new StripeList + { + Data = + [ + new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) } + ] + }, Status = StripeSubscriptionStatus.Active, Metadata = new Dictionary { { "providerId", providerId.ToString() } } }; @@ -1015,13 +1004,144 @@ public class SubscriptionUpdatedHandlerTests return (providerId, newSubscription, provider, parsedEvent); } + [Fact] + public async Task HandleAsync_IncompleteUserSubscriptionWithOpenInvoice_CancelsSubscriptionAndDisablesPremium() + { + // Arrange + var userId = Guid.NewGuid(); + var subscriptionId = "sub_123"; + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + var openInvoice = new Invoice + { + Id = "inv_123", + Status = StripeInvoiceStatus.Open + }; + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Incomplete, + Metadata = new Dictionary { { "userId", userId.ToString() } }, + LatestInvoice = openInvoice, + Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = currentPeriodEnd, + Price = new Price { Id = IStripeEventUtilityService.PremiumPlanId } + } + ] + } + }; + + var parsedEvent = new Event { Data = new EventData() }; + + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new PremiumPurchasable { Price = 10M, StripePriceId = IStripeEventUtilityService.PremiumPlanId }, + Storage = new PremiumPurchasable { Price = 4M, StripePriceId = "storage-plan-personal" } + }; + _pricingClient.ListPremiumPlans().Returns(new List { premiumPlan }); + + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) + .Returns(subscription); + + _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) + .Returns(Tuple.Create(null, userId, null)); + + _stripeFacade.ListInvoices(Arg.Any()) + .Returns(new StripeList { Data = new List { openInvoice } }); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _userService.Received(1) + .DisablePremiumAsync(userId, currentPeriodEnd); + await _stripeFacade.Received(1) + .CancelSubscription(subscriptionId, Arg.Any()); + await _stripeFacade.Received(1) + .ListInvoices(Arg.Is(o => + o.Status == StripeInvoiceStatus.Open && o.Subscription == subscriptionId)); + await _stripeFacade.Received(1) + .VoidInvoice(openInvoice.Id); + } + + [Fact] + public async Task HandleAsync_IncompleteUserSubscriptionWithoutOpenInvoice_DoesNotCancelSubscription() + { + // Arrange + var userId = Guid.NewGuid(); + var subscriptionId = "sub_123"; + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + var paidInvoice = new Invoice + { + Id = "inv_123", + Status = StripeInvoiceStatus.Paid + }; + var subscription = new Subscription + { + Id = subscriptionId, + Status = StripeSubscriptionStatus.Incomplete, + Metadata = new Dictionary { { "userId", userId.ToString() } }, + LatestInvoice = paidInvoice, + Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = currentPeriodEnd, + Price = new Price { Id = IStripeEventUtilityService.PremiumPlanId } + } + ] + } + }; + + var parsedEvent = new Event { Data = new EventData() }; + + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new PremiumPurchasable { Price = 10M, StripePriceId = IStripeEventUtilityService.PremiumPlanId }, + Storage = new PremiumPurchasable { Price = 4M, StripePriceId = "storage-plan-personal" } + }; + _pricingClient.ListPremiumPlans().Returns(new List { premiumPlan }); + + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) + .Returns(subscription); + + _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) + .Returns(Tuple.Create(null, userId, null)); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _userService.DidNotReceive() + .DisablePremiumAsync(Arg.Any(), Arg.Any()); + await _stripeFacade.DidNotReceive() + .CancelSubscription(Arg.Any(), Arg.Any()); + await _stripeFacade.DidNotReceive() + .ListInvoices(Arg.Any()); + } + public static IEnumerable GetNonActiveSubscriptions() { return new List { new object[] { new Subscription { Id = "sub_123", Status = StripeSubscriptionStatus.Unpaid } }, new object[] { new Subscription { Id = "sub_123", Status = StripeSubscriptionStatus.Incomplete } }, - new object[] { new Subscription { Id = "sub_123", Status = StripeSubscriptionStatus.IncompleteExpired } }, + new object[] + { + new Subscription { Id = "sub_123", Status = StripeSubscriptionStatus.IncompleteExpired } + }, new object[] { new Subscription { Id = "sub_123", Status = StripeSubscriptionStatus.Paused } } }; } diff --git a/test/Billing.Test/Services/UpcomingInvoiceHandlerTests.cs b/test/Billing.Test/Services/UpcomingInvoiceHandlerTests.cs new file mode 100644 index 0000000000..3b133c7d37 --- /dev/null +++ b/test/Billing.Test/Services/UpcomingInvoiceHandlerTests.cs @@ -0,0 +1,2530 @@ +using System.Globalization; +using Bit.Billing.Services; +using Bit.Billing.Services.Implementations; +using Bit.Core; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Payment.Queries; +using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Pricing.Premium; +using Bit.Core.Entities; +using Bit.Core.Models.Mail.Billing.Renewal.Families2019Renewal; +using Bit.Core.Models.Mail.Billing.Renewal.Families2020Renewal; +using Bit.Core.Models.Mail.Billing.Renewal.Premium; +using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; +using Bit.Core.Platform.Mail.Mailer; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Test.Billing.Mocks.Plans; +using Microsoft.Extensions.Logging; +using NSubstitute; +using NSubstitute.ExceptionExtensions; +using Stripe; +using Xunit; +using static Bit.Core.Billing.Constants.StripeConstants; +using Address = Stripe.Address; +using Event = Stripe.Event; +using PremiumPlan = Bit.Core.Billing.Pricing.Premium.Plan; + +namespace Bit.Billing.Test.Services; + +public class UpcomingInvoiceHandlerTests +{ + private readonly IGetPaymentMethodQuery _getPaymentMethodQuery; + private readonly ILogger _logger; + private readonly IMailService _mailService; + private readonly IOrganizationRepository _organizationRepository; + private readonly IPricingClient _pricingClient; + private readonly IProviderRepository _providerRepository; + private readonly IStripeFacade _stripeFacade; + private readonly IStripeEventService _stripeEventService; + private readonly IStripeEventUtilityService _stripeEventUtilityService; + private readonly IUserRepository _userRepository; + private readonly IValidateSponsorshipCommand _validateSponsorshipCommand; + private readonly IMailer _mailer; + private readonly IFeatureService _featureService; + + private readonly UpcomingInvoiceHandler _sut; + + private readonly Guid _userId = Guid.NewGuid(); + private readonly Guid _organizationId = Guid.NewGuid(); + private readonly Guid _providerId = Guid.NewGuid(); + + + public UpcomingInvoiceHandlerTests() + { + _getPaymentMethodQuery = Substitute.For(); + _logger = Substitute.For>(); + _mailService = Substitute.For(); + _organizationRepository = Substitute.For(); + _pricingClient = Substitute.For(); + _providerRepository = Substitute.For(); + _stripeFacade = Substitute.For(); + _stripeEventService = Substitute.For(); + _stripeEventUtilityService = Substitute.For(); + _userRepository = Substitute.For(); + _validateSponsorshipCommand = Substitute.For(); + _mailer = Substitute.For(); + _featureService = Substitute.For(); + + _sut = new UpcomingInvoiceHandler( + _getPaymentMethodQuery, + _logger, + _mailService, + _organizationRepository, + _pricingClient, + _providerRepository, + _stripeFacade, + _stripeEventService, + _stripeEventUtilityService, + _userRepository, + _validateSponsorshipCommand, + _mailer, + _featureService); + } + + [Fact] + public async Task HandleAsync_WhenNullSubscription_DoesNothing() + { + // Arrange + var parsedEvent = new Event(); + var invoice = new Invoice { CustomerId = "cus_123" }; + var customer = new Customer { Id = "cus_123", Subscriptions = new StripeList { Data = [] } }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade + .GetCustomer(invoice.CustomerId, Arg.Any()) + .Returns(customer); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _stripeFacade.DidNotReceive() + .UpdateCustomer(Arg.Any(), Arg.Any()); + } + + [Fact] + public async Task HandleAsync_WhenValidUser_SendsEmail() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var customerId = "cus_123"; + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 10000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = customerId, + Items = new StripeList + { + Data = [new() { Id = "si_123", Price = new Price { Id = Prices.PremiumAnnually } }] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }, + Customer = new Customer { Id = customerId }, + Metadata = new Dictionary() + }; + var user = new User { Id = _userId, Email = "user@example.com", Premium = true }; + var plan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new Purchasable { Price = 10M, StripePriceId = Prices.PremiumAnnually }, + Storage = new Purchasable { Price = 4M, StripePriceId = Prices.StoragePlanPersonal } + }; + var customer = new Customer + { + Id = customerId, + Tax = new CustomerTax { AutomaticTax = AutomaticTaxStatus.Supported }, + Subscriptions = new StripeList { Data = [subscription] } + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade + .GetCustomer(customerId, Arg.Any()) + .Returns(customer); + + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(null, _userId, null)); + + _userRepository.GetByIdAsync(_userId).Returns(user); + _pricingClient.GetAvailablePremiumPlan().Returns(plan); + + // If milestone 2 is disabled, the default email is sent + _featureService + .IsEnabled(FeatureFlagKeys.PM23341_Milestone_2) + .Returns(false); + + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _userRepository.Received(1).GetByIdAsync(_userId); + + await _mailService.Received(1).SendInvoiceUpcoming( + Arg.Is>(emails => emails.Contains("user@example.com")), + Arg.Is(amount => amount == invoice.AmountDue / 100M), + Arg.Is(dueDate => dueDate == invoice.NextPaymentAttempt.Value), + Arg.Is>(items => items.Count == invoice.Lines.Data.Count), + Arg.Is(b => b == true)); + } + + [Fact] + public async Task + HandleAsync_WhenUserValid_AndMilestone2Enabled_UpdatesPriceId_AndSendsUpdatedInvoiceUpcomingEmail() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var customerId = "cus_123"; + var priceSubscriptionId = "sub-1"; + var priceId = "price-id-2"; + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 10000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = customerId, + Items = new StripeList + { + Data = [new() { Id = priceSubscriptionId, Price = new Price { Id = Prices.PremiumAnnually } }] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }, + Customer = new Customer + { + Id = customerId, + Tax = new CustomerTax { AutomaticTax = AutomaticTaxStatus.Supported } + }, + Metadata = new Dictionary() + }; + var user = new User { Id = _userId, Email = "user@example.com", Premium = true }; + var plan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new Purchasable { Price = 10M, StripePriceId = priceId }, + Storage = new Purchasable { Price = 4M, StripePriceId = Prices.StoragePlanPersonal } + }; + var customer = new Customer + { + Id = customerId, + Subscriptions = new StripeList { Data = [subscription] } + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade + .GetCustomer(customerId, Arg.Any()) + .Returns(customer); + + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(null, _userId, null)); + + _userRepository.GetByIdAsync(_userId).Returns(user); + _pricingClient.GetAvailablePremiumPlan().Returns(plan); + _stripeFacade.UpdateSubscription( + subscription.Id, + Arg.Any()) + .Returns(subscription); + + // If milestone 2 is true, the updated invoice email is sent + _featureService + .IsEnabled(FeatureFlagKeys.PM23341_Milestone_2) + .Returns(true); + + var coupon = new Coupon { PercentOff = 20, Id = CouponIDs.Milestone2SubscriptionDiscount }; + + _stripeFacade.GetCoupon(CouponIDs.Milestone2SubscriptionDiscount).Returns(coupon); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _userRepository.Received(1).GetByIdAsync(_userId); + await _pricingClient.Received(1).GetAvailablePremiumPlan(); + await _stripeFacade.Received(1).GetCoupon(CouponIDs.Milestone2SubscriptionDiscount); + await _stripeFacade.Received(1).UpdateSubscription( + Arg.Is("sub_123"), + Arg.Is(o => + o.Items[0].Id == priceSubscriptionId && + o.Items[0].Price == priceId && + o.Discounts[0].Coupon == CouponIDs.Milestone2SubscriptionDiscount && + o.ProrationBehavior == "none")); + + // Verify the updated invoice email was sent with correct price + var discountedPrice = plan.Seat.Price * (100 - coupon.PercentOff.Value) / 100; + await _mailer.Received(1).SendEmail( + Arg.Is(email => + email.ToEmails.Contains("user@example.com") && + email.Subject == "Your Bitwarden Premium renewal is updating" && + email.View.BaseMonthlyRenewalPrice == (plan.Seat.Price / 12).ToString("C", new CultureInfo("en-US")) && + email.View.DiscountedMonthlyRenewalPrice == (discountedPrice / 12).ToString("C", new CultureInfo("en-US")) && + email.View.DiscountAmount == $"{coupon.PercentOff}%" + )); + } + + [Fact] + public async Task HandleAsync_WhenOrganizationHasSponsorship_SendsEmail() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var invoice = new Invoice + { + CustomerId = "cus_123", + AmountDue = 10000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = "cus_123", + Items = new StripeList(), + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }, + Customer = new Customer { Id = "cus_123" }, + Metadata = new Dictionary(), + LatestInvoiceId = "inv_latest" + }; + var customer = new Customer + { + Id = "cus_123", + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" } + }; + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.EnterpriseAnnually + }; + var plan = new FamiliesPlan(); + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade + .GetCustomer(invoice.CustomerId, Arg.Any()) + .Returns(customer); + + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + + _organizationRepository + .GetByIdAsync(_organizationId) + .Returns(organization); + + _pricingClient + .GetPlanOrThrow(organization.PlanType) + .Returns(plan); + + _stripeEventUtilityService + .IsSponsoredSubscription(subscription) + .Returns(true); + // Configure that this is a sponsored subscription + _stripeEventUtilityService + .IsSponsoredSubscription(subscription) + .Returns(true); + _validateSponsorshipCommand + .ValidateSponsorshipAsync(_organizationId) + .Returns(true); + + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _organizationRepository.Received(1).GetByIdAsync(_organizationId); + await _validateSponsorshipCommand.Received(1).ValidateSponsorshipAsync(_organizationId); + + await _mailService.Received(1).SendInvoiceUpcoming( + Arg.Is>(emails => emails.Contains("org@example.com")), + Arg.Is(amount => amount == invoice.AmountDue / 100M), + Arg.Is(dueDate => dueDate == invoice.NextPaymentAttempt.Value), + Arg.Is>(items => items.Count == invoice.Lines.Data.Count), + Arg.Is(b => b == true)); + } + + [Fact] + public async Task + HandleAsync_WhenOrganizationHasSponsorship_ButInvalidSponsorship_RetrievesUpdatedInvoice_SendsEmail() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var invoice = new Invoice + { + CustomerId = "cus_123", + AmountDue = 10000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = "cus_123", + Items = new StripeList + { + Data = + [new SubscriptionItem { Price = new Price { Id = "2021-family-for-enterprise-annually" } }] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }, + Customer = new Customer { Id = "cus_123" }, + Metadata = new Dictionary(), + LatestInvoiceId = "inv_latest" + }; + var customer = new Customer + { + Id = "cus_123", + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" } + }; + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.EnterpriseAnnually + }; + var plan = new FamiliesPlan(); + + var paymentMethod = new Card { Last4 = "4242", Brand = "visa" }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade + .GetCustomer(invoice.CustomerId, Arg.Any()) + .Returns(customer); + + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + + _organizationRepository + .GetByIdAsync(_organizationId) + .Returns(organization); + + _pricingClient + .GetPlanOrThrow(organization.PlanType) + .Returns(plan); + + // Configure that this is not a sponsored subscription + _stripeEventUtilityService + .IsSponsoredSubscription(subscription) + .Returns(true); + + // Validate sponsorship should return false + _validateSponsorshipCommand + .ValidateSponsorshipAsync(_organizationId) + .Returns(false); + _stripeFacade + .GetInvoice(subscription.LatestInvoiceId) + .Returns(invoice); + + _getPaymentMethodQuery.Run(organization).Returns(MaskedPaymentMethod.From(paymentMethod)); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _organizationRepository.Received(1).GetByIdAsync(_organizationId); + _stripeEventUtilityService.Received(1).IsSponsoredSubscription(subscription); + await _validateSponsorshipCommand.Received(1).ValidateSponsorshipAsync(_organizationId); + await _stripeFacade.Received(1).GetInvoice(Arg.Is("inv_latest")); + + await _mailService.Received(1).SendInvoiceUpcoming( + Arg.Is>(emails => emails.Contains("org@example.com")), + Arg.Is(amount => amount == invoice.AmountDue / 100M), + Arg.Is(dueDate => dueDate == invoice.NextPaymentAttempt.Value), + Arg.Is>(items => items.Count == invoice.Lines.Data.Count), + Arg.Is(b => b == true)); + } + + [Fact] + public async Task HandleAsync_WhenValidOrganization_SendsEmail() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var invoice = new Invoice + { + CustomerId = "cus_123", + AmountDue = 10000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = "cus_123", + Items = new StripeList + { + Data = + [new SubscriptionItem { Price = new Price { Id = "enterprise-annually" } }] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }, + Customer = new Customer { Id = "cus_123" }, + Metadata = new Dictionary(), + LatestInvoiceId = "inv_latest" + }; + var customer = new Customer + { + Id = "cus_123", + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" } + }; + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.EnterpriseAnnually + }; + var plan = new FamiliesPlan(); + + var paymentMethod = new Card { Last4 = "4242", Brand = "visa" }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade + .GetCustomer(invoice.CustomerId, Arg.Any()) + .Returns(customer); + + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + + _organizationRepository + .GetByIdAsync(_organizationId) + .Returns(organization); + + _pricingClient + .GetPlanOrThrow(organization.PlanType) + .Returns(plan); + + _stripeEventUtilityService + .IsSponsoredSubscription(subscription) + .Returns(false); + + _stripeFacade + .GetInvoice(subscription.LatestInvoiceId) + .Returns(invoice); + + _getPaymentMethodQuery.Run(organization).Returns(MaskedPaymentMethod.From(paymentMethod)); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _organizationRepository.Received(1).GetByIdAsync(_organizationId); + _stripeEventUtilityService.Received(1).IsSponsoredSubscription(subscription); + + // Should not validate sponsorship for non-sponsored subscription + await _validateSponsorshipCommand.DidNotReceive().ValidateSponsorshipAsync(Arg.Any()); + + await _mailService.Received(1).SendInvoiceUpcoming( + Arg.Is>(emails => emails.Contains("org@example.com")), + Arg.Is(amount => amount == invoice.AmountDue / 100M), + Arg.Is(dueDate => dueDate == invoice.NextPaymentAttempt.Value), + Arg.Is>(items => items.Count == invoice.Lines.Data.Count), + Arg.Is(b => b == true)); + } + + + [Fact] + public async Task HandleAsync_WhenValidProviderSubscription_SendsEmail() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var invoice = new Invoice + { + CustomerId = "cus_123", + AmountDue = 10000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = "cus_123", + Items = new StripeList(), + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }, + Customer = new Customer { Id = "cus_123" }, + Metadata = new Dictionary(), + CollectionMethod = "charge_automatically" + }; + var customer = new Customer + { + Id = "cus_123", + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "UK" }, + TaxExempt = TaxExempt.None + }; + var provider = new Provider { Id = _providerId, BillingEmail = "provider@example.com" }; + + var paymentMethod = new Card { Last4 = "4242", Brand = "visa" }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(invoice.CustomerId, Arg.Any()).Returns(customer); + + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(null, null, _providerId)); + + _providerRepository.GetByIdAsync(_providerId).Returns(provider); + _getPaymentMethodQuery.Run(provider).Returns(MaskedPaymentMethod.From(paymentMethod)); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _providerRepository.Received(2).GetByIdAsync(_providerId); + + // Verify tax exempt was set to reverse for non-US providers + await _stripeFacade.Received(1).UpdateCustomer( + Arg.Is("cus_123"), + Arg.Is(o => o.TaxExempt == TaxExempt.Reverse)); + + // Verify automatic tax was enabled + await _stripeFacade.Received(1).UpdateSubscription( + Arg.Is("sub_123"), + Arg.Is(o => o.AutomaticTax.Enabled == true)); + + // Verify provider invoice email was sent + await _mailService.Received(1).SendProviderInvoiceUpcoming( + Arg.Is>(e => e.Contains("provider@example.com")), + Arg.Is(amount => amount == invoice.AmountDue / 100M), + Arg.Is(dueDate => dueDate == invoice.NextPaymentAttempt.Value), + Arg.Is>(items => items.Count == invoice.Lines.Data.Count), + Arg.Is(s => s == subscription.CollectionMethod), + Arg.Is(b => b == true), + Arg.Is(s => s == $"{paymentMethod.Brand} ending in {paymentMethod.Last4}")); + } + + [Fact] + public async Task HandleAsync_WhenUpdateSubscriptionItemPriceIdFails_LogsErrorAndSendsTraditionalEmail() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var customerId = "cus_123"; + var priceSubscriptionId = "sub-1"; + var priceId = "price-id-2"; + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 10000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = customerId, + Items = new StripeList + { + Data = [new() { Id = priceSubscriptionId, Price = new Price { Id = Prices.PremiumAnnually } }] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Customer = new Customer + { + Id = customerId, + Tax = new CustomerTax { AutomaticTax = AutomaticTaxStatus.Supported } + }, + Metadata = new Dictionary() + }; + var user = new User { Id = _userId, Email = "user@example.com", Premium = true }; + var plan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new Purchasable { Price = 10M, StripePriceId = priceId }, + Storage = new Purchasable { Price = 4M, StripePriceId = Prices.StoragePlanPersonal } + }; + var customer = new Customer + { + Id = customerId, + Subscriptions = new StripeList { Data = [subscription] } + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(invoice.CustomerId, Arg.Any()).Returns(customer); + + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(null, _userId, null)); + + _userRepository.GetByIdAsync(_userId).Returns(user); + + _featureService + .IsEnabled(FeatureFlagKeys.PM23341_Milestone_2) + .Returns(true); + + _pricingClient.GetAvailablePremiumPlan().Returns(plan); + + // Setup exception when updating subscription + _stripeFacade + .UpdateSubscription(Arg.Any(), Arg.Any()) + .ThrowsAsync(new Exception()); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + _logger.Received(1).Log( + LogLevel.Error, + Arg.Any(), + Arg.Is(o => + o.ToString() + .Contains( + $"Failed to update user's ({user.Id}) subscription price id while processing event with ID {parsedEvent.Id}")), + Arg.Any(), + Arg.Any>()); + + // Verify that traditional email was sent when update fails + await _mailService.Received(1).SendInvoiceUpcoming( + Arg.Is>(emails => emails.Contains("user@example.com")), + Arg.Is(amount => amount == invoice.AmountDue / 100M), + Arg.Is(dueDate => dueDate == invoice.NextPaymentAttempt.Value), + Arg.Is>(items => items.Count == invoice.Lines.Data.Count), + Arg.Is(b => b == true)); + + // Verify renewal email was NOT sent + await _mailer.DidNotReceive().SendEmail(Arg.Any()); + } + + [Fact] + public async Task HandleAsync_WhenOrganizationNotFound_DoesNothing() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var invoice = new Invoice + { + CustomerId = "cus_123", + AmountDue = 10000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = "cus_123", + Items = new StripeList(), + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }, + Customer = new Customer { Id = "cus_123" }, + Metadata = new Dictionary() + }; + var customer = new Customer + { + Id = "cus_123", + Subscriptions = new StripeList { Data = [subscription] } + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade + .GetCustomer(invoice.CustomerId, Arg.Any()) + .Returns(customer); + + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + + // Organization not found + _organizationRepository.GetByIdAsync(_organizationId).Returns((Organization)null); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _organizationRepository.Received(1).GetByIdAsync(_organizationId); + + // Verify no emails were sent + await _mailService.DidNotReceive().SendInvoiceUpcoming( + Arg.Any>(), + Arg.Any(), + Arg.Any(), + Arg.Any>(), + Arg.Any()); + } + + [Fact] + public async Task HandleAsync_WhenZeroAmountInvoice_DoesNothing() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var invoice = new Invoice + { + CustomerId = "cus_123", + AmountDue = 0, // Zero amount due + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Free Item" }] + } + }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = "cus_123", + Items = new StripeList(), + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }, + Customer = new Customer { Id = "cus_123" }, + Metadata = new Dictionary() + }; + var user = new User { Id = _userId, Email = "user@example.com", Premium = true }; + var customer = new Customer + { + Id = "cus_123", + Subscriptions = new StripeList { Data = [subscription] } + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade + .GetCustomer(invoice.CustomerId, Arg.Any()) + .Returns(customer); + + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(null, _userId, null)); + + _userRepository.GetByIdAsync(_userId).Returns(user); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _userRepository.Received(1).GetByIdAsync(_userId); + + // Should not + await _mailService.DidNotReceive().SendInvoiceUpcoming( + Arg.Any>(), + Arg.Any(), + Arg.Any(), + Arg.Any>(), + Arg.Any()); + } + + [Fact] + public async Task HandleAsync_WhenUserNotFound_DoesNothing() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var invoice = new Invoice + { + CustomerId = "cus_123", + AmountDue = 10000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = "cus_123", + Items = new StripeList(), + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }, + Customer = new Customer { Id = "cus_123" }, + Metadata = new Dictionary() + }; + var customer = new Customer + { + Id = "cus_123", + Subscriptions = new StripeList { Data = [subscription] } + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade + .GetCustomer(invoice.CustomerId, Arg.Any()) + .Returns(customer); + + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(null, _userId, null)); + + // User not found + _userRepository.GetByIdAsync(_userId).Returns((User)null); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _userRepository.Received(1).GetByIdAsync(_userId); + + // Verify no emails were sent + await _mailService.DidNotReceive().SendInvoiceUpcoming( + Arg.Any>(), + Arg.Any(), + Arg.Any(), + Arg.Any>(), + Arg.Any()); + + await _mailer.DidNotReceive().SendEmail(Arg.Any()); + } + + [Fact] + public async Task HandleAsync_WhenProviderNotFound_DoesNothing() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var invoice = new Invoice + { + CustomerId = "cus_123", + AmountDue = 10000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = "cus_123", + Items = new StripeList(), + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }, + Customer = new Customer { Id = "cus_123" }, + Metadata = new Dictionary() + }; + var customer = new Customer + { + Id = "cus_123", + Subscriptions = new StripeList { Data = [subscription] } + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade + .GetCustomer(invoice.CustomerId, Arg.Any()) + .Returns(customer); + + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(null, null, _providerId)); + + // Provider not found + _providerRepository.GetByIdAsync(_providerId).Returns((Provider)null); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _providerRepository.Received(1).GetByIdAsync(_providerId); + + // Verify no provider emails were sent + await _mailService.DidNotReceive().SendProviderInvoiceUpcoming( + Arg.Any>(), + Arg.Any(), + Arg.Any(), + Arg.Any>(), + Arg.Any(), + Arg.Any(), + Arg.Any()); + } + + [Fact] + public async Task HandleAsync_WhenMilestone3Enabled_AndFamilies2019Plan_UpdatesSubscriptionAndOrganization() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123", Type = "invoice.upcoming" }; + var customerId = "cus_123"; + var subscriptionId = "sub_123"; + var passwordManagerItemId = "si_pm_123"; + var premiumAccessItemId = "si_premium_123"; + + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 40000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + + var families2019Plan = new Families2019Plan(); + var familiesPlan = new FamiliesPlan(); + + var subscription = new Subscription + { + Id = subscriptionId, + CustomerId = customerId, + Items = new StripeList + { + Data = + [ + new() + { + Id = passwordManagerItemId, + Price = new Price { Id = families2019Plan.PasswordManager.StripePlanId } + }, + new() + { + Id = premiumAccessItemId, + Price = new Price { Id = families2019Plan.PasswordManager.StripePremiumAccessPlanId } + } + ] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Metadata = new Dictionary() + }; + + var customer = new Customer + { + Id = customerId, + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" } + }; + + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.FamiliesAnnually2019 + }; + + var coupon = new Coupon { PercentOff = 25, Id = CouponIDs.Milestone3SubscriptionDiscount }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeFacade.GetCoupon(CouponIDs.Milestone3SubscriptionDiscount).Returns(coupon); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + _organizationRepository.GetByIdAsync(_organizationId).Returns(organization); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019).Returns(families2019Plan); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually).Returns(familiesPlan); + _featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); + _stripeEventUtilityService.IsSponsoredSubscription(subscription).Returns(false); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription( + Arg.Is(subscriptionId), + Arg.Is(o => + o.Items.Count == 2 && + o.Items[0].Id == passwordManagerItemId && + o.Items[0].Price == familiesPlan.PasswordManager.StripePlanId && + o.Items[1].Id == premiumAccessItemId && + o.Items[1].Deleted == true && + o.Discounts.Count == 1 && + o.Discounts[0].Coupon == CouponIDs.Milestone3SubscriptionDiscount && + o.ProrationBehavior == ProrationBehavior.None)); + + await _stripeFacade.Received(1).GetCoupon(CouponIDs.Milestone3SubscriptionDiscount); + + await _organizationRepository.Received(1).ReplaceAsync( + Arg.Is(org => + org.Id == _organizationId && + org.PlanType == PlanType.FamiliesAnnually && + org.Plan == familiesPlan.Name && + org.UsersGetPremium == familiesPlan.UsersGetPremium && + org.Seats == familiesPlan.PasswordManager.BaseSeats)); + + await _mailer.Received(1).SendEmail( + Arg.Is(email => + email.ToEmails.Contains("org@example.com") && + email.Subject == "Your Bitwarden Families renewal is updating" && + email.View.BaseMonthlyRenewalPrice == (familiesPlan.PasswordManager.BasePrice / 12).ToString("C", new CultureInfo("en-US")) && + email.View.BaseAnnualRenewalPrice == familiesPlan.PasswordManager.BasePrice.ToString("C", new CultureInfo("en-US")) && + email.View.DiscountAmount == $"{coupon.PercentOff}%" + )); + } + + [Fact] + public async Task HandleAsync_WhenMilestone3Enabled_AndFamilies2019Plan_WithoutPremiumAccess_UpdatesSubscriptionAndOrganization() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123", Type = "invoice.upcoming" }; + var customerId = "cus_123"; + var subscriptionId = "sub_123"; + var passwordManagerItemId = "si_pm_123"; + + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 40000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + + var families2019Plan = new Families2019Plan(); + var familiesPlan = new FamiliesPlan(); + + var subscription = new Subscription + { + Id = subscriptionId, + CustomerId = customerId, + Items = new StripeList + { + Data = + [ + new() + { + Id = passwordManagerItemId, + Price = new Price { Id = families2019Plan.PasswordManager.StripePlanId } + } + ] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Metadata = new Dictionary() + }; + + var customer = new Customer + { + Id = customerId, + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" } + }; + + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.FamiliesAnnually2019 + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + _organizationRepository.GetByIdAsync(_organizationId).Returns(organization); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019).Returns(families2019Plan); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually).Returns(familiesPlan); + _featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); + _stripeEventUtilityService.IsSponsoredSubscription(subscription).Returns(false); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription( + Arg.Is(subscriptionId), + Arg.Is(o => + o.Items.Count == 1 && + o.Items[0].Id == passwordManagerItemId && + o.Items[0].Price == familiesPlan.PasswordManager.StripePlanId && + o.Discounts.Count == 1 && + o.Discounts[0].Coupon == CouponIDs.Milestone3SubscriptionDiscount && + o.ProrationBehavior == ProrationBehavior.None)); + + await _organizationRepository.Received(1).ReplaceAsync( + Arg.Is(org => + org.Id == _organizationId && + org.PlanType == PlanType.FamiliesAnnually && + org.Plan == familiesPlan.Name && + org.UsersGetPremium == familiesPlan.UsersGetPremium && + org.Seats == familiesPlan.PasswordManager.BaseSeats)); + } + + [Fact] + public async Task HandleAsync_WhenMilestone3Disabled_AndFamilies2019Plan_DoesNotUpdateSubscription() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123", Type = "invoice.upcoming" }; + var customerId = "cus_123"; + var subscriptionId = "sub_123"; + var passwordManagerItemId = "si_pm_123"; + + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 40000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + + var families2019Plan = new Families2019Plan(); + + var subscription = new Subscription + { + Id = subscriptionId, + CustomerId = customerId, + Items = new StripeList + { + Data = + [ + new() + { + Id = passwordManagerItemId, + Price = new Price { Id = families2019Plan.PasswordManager.StripePlanId } + } + ] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Metadata = new Dictionary() + }; + + var customer = new Customer + { + Id = customerId, + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" } + }; + + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.FamiliesAnnually2019 + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + _organizationRepository.GetByIdAsync(_organizationId).Returns(organization); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019).Returns(families2019Plan); + _featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(false); + _stripeEventUtilityService.IsSponsoredSubscription(subscription).Returns(false); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert - should not update subscription or organization when feature flag is disabled + await _stripeFacade.DidNotReceive().UpdateSubscription( + Arg.Any(), + Arg.Is(o => o.Discounts != null)); + + await _organizationRepository.DidNotReceive().ReplaceAsync( + Arg.Is(org => org.PlanType == PlanType.FamiliesAnnually)); + } + + [Fact] + public async Task HandleAsync_WhenMilestone3Enabled_ButNotFamilies2019Plan_DoesNotUpdateSubscription() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123", Type = "invoice.upcoming" }; + var customerId = "cus_123"; + var subscriptionId = "sub_123"; + + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 40000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + + var familiesPlan = new FamiliesPlan(); + + var subscription = new Subscription + { + Id = subscriptionId, + CustomerId = customerId, + Items = new StripeList + { + Data = + [ + new() { Id = "si_pm_123", Price = new Price { Id = familiesPlan.PasswordManager.StripePlanId } } + ] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Metadata = new Dictionary() + }; + + var customer = new Customer + { + Id = customerId, + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" } + }; + + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.FamiliesAnnually // Already on the new plan + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + _organizationRepository.GetByIdAsync(_organizationId).Returns(organization); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually).Returns(familiesPlan); + _featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); + _stripeEventUtilityService.IsSponsoredSubscription(subscription).Returns(false); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert - should not update subscription when not on FamiliesAnnually2019 plan + await _stripeFacade.DidNotReceive().UpdateSubscription( + Arg.Any(), + Arg.Is(o => o.Discounts != null)); + + await _organizationRepository.DidNotReceive().ReplaceAsync(Arg.Any()); + } + + [Fact] + public async Task HandleAsync_WhenMilestone3Enabled_AndPasswordManagerItemNotFound_LogsWarning() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123", Type = "invoice.upcoming" }; + var customerId = "cus_123"; + var subscriptionId = "sub_123"; + + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 40000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + + var families2019Plan = new Families2019Plan(); + + var subscription = new Subscription + { + Id = subscriptionId, + CustomerId = customerId, + Items = new StripeList + { + Data = + [ + new() { Id = "si_different_item", Price = new Price { Id = "different-price-id" } } + ] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Metadata = new Dictionary() + }; + + var customer = new Customer + { + Id = customerId, + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" } + }; + + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.FamiliesAnnually2019 + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + _organizationRepository.GetByIdAsync(_organizationId).Returns(organization); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019).Returns(families2019Plan); + _featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); + _stripeEventUtilityService.IsSponsoredSubscription(subscription).Returns(false); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + _logger.Received(1).Log( + LogLevel.Warning, + Arg.Any(), + Arg.Is(o => + o.ToString().Contains($"Could not find Organization's ({_organizationId}) password manager item") && + o.ToString().Contains(parsedEvent.Id)), + Arg.Any(), + Arg.Any>()); + + // Should not update subscription or organization when password manager item not found + await _stripeFacade.DidNotReceive().UpdateSubscription( + Arg.Any(), + Arg.Is(o => o.Discounts != null)); + + await _organizationRepository.DidNotReceive().ReplaceAsync(Arg.Any()); + } + + [Fact] + public async Task HandleAsync_WhenMilestone3Enabled_AndUpdateFails_LogsErrorAndSendsTraditionalEmail() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123", Type = "invoice.upcoming" }; + var customerId = "cus_123"; + var subscriptionId = "sub_123"; + var passwordManagerItemId = "si_pm_123"; + + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 40000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + + var families2019Plan = new Families2019Plan(); + var familiesPlan = new FamiliesPlan(); + + var subscription = new Subscription + { + Id = subscriptionId, + CustomerId = customerId, + Items = new StripeList + { + Data = + [ + new() + { + Id = passwordManagerItemId, + Price = new Price { Id = families2019Plan.PasswordManager.StripePlanId } + } + ] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Metadata = new Dictionary() + }; + + var customer = new Customer + { + Id = customerId, + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" } + }; + + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.FamiliesAnnually2019 + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + _organizationRepository.GetByIdAsync(_organizationId).Returns(organization); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019).Returns(families2019Plan); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually).Returns(familiesPlan); + _featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); + _stripeEventUtilityService.IsSponsoredSubscription(subscription).Returns(false); + + // Simulate update failure + _stripeFacade + .UpdateSubscription(Arg.Any(), Arg.Any()) + .ThrowsAsync(new Exception("Stripe API error")); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + _logger.Received(1).Log( + LogLevel.Error, + Arg.Any(), + Arg.Is(o => + o.ToString().Contains($"Failed to align subscription concerns for Organization ({_organizationId})") && + o.ToString().Contains(parsedEvent.Type) && + o.ToString().Contains(parsedEvent.Id)), + Arg.Any(), + Arg.Any>()); + + // Should send traditional email when update fails + await _mailService.Received(1).SendInvoiceUpcoming( + Arg.Is>(emails => emails.Contains("org@example.com")), + Arg.Is(amount => amount == invoice.AmountDue / 100M), + Arg.Is(dueDate => dueDate == invoice.NextPaymentAttempt.Value), + Arg.Is>(items => items.Count == invoice.Lines.Data.Count), + Arg.Is(b => b == true)); + + // Verify renewal email was NOT sent + await _mailer.DidNotReceive().SendEmail(Arg.Any()); + } + + [Fact] + public async Task HandleAsync_WhenMilestone3Enabled_AndCouponNotFound_LogsErrorAndSendsTraditionalEmail() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123", Type = "invoice.upcoming" }; + var customerId = "cus_123"; + var subscriptionId = "sub_123"; + var passwordManagerItemId = "si_pm_123"; + + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 40000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + + var families2019Plan = new Families2019Plan(); + var familiesPlan = new FamiliesPlan(); + + var subscription = new Subscription + { + Id = subscriptionId, + CustomerId = customerId, + Items = new StripeList + { + Data = + [ + new() + { + Id = passwordManagerItemId, + Price = new Price { Id = families2019Plan.PasswordManager.StripePlanId } + } + ] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Metadata = new Dictionary() + }; + + var customer = new Customer + { + Id = customerId, + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" } + }; + + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.FamiliesAnnually2019 + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + _organizationRepository.GetByIdAsync(_organizationId).Returns(organization); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019).Returns(families2019Plan); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually).Returns(familiesPlan); + _featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); + _stripeEventUtilityService.IsSponsoredSubscription(subscription).Returns(false); + _stripeFacade.GetCoupon(CouponIDs.Milestone3SubscriptionDiscount).Returns((Coupon)null); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert - Exception is caught, error is logged, and traditional email is sent + _logger.Received(1).Log( + LogLevel.Error, + Arg.Any(), + Arg.Is(o => + o.ToString().Contains($"Failed to align subscription concerns for Organization ({_organizationId})") && + o.ToString().Contains(parsedEvent.Type) && + o.ToString().Contains(parsedEvent.Id)), + Arg.Is(e => e is InvalidOperationException && e.Message.Contains("Coupon for sending families 2019 email")), + Arg.Any>()); + + await _mailer.DidNotReceive().SendEmail(Arg.Any()); + + await _mailService.Received(1).SendInvoiceUpcoming( + Arg.Is>(emails => emails.Contains("org@example.com")), + Arg.Is(amount => amount == invoice.AmountDue / 100M), + Arg.Is(dueDate => dueDate == invoice.NextPaymentAttempt.Value), + Arg.Is>(items => items.Count == invoice.Lines.Data.Count), + Arg.Is(b => b == true)); + } + + [Fact] + public async Task HandleAsync_WhenMilestone3Enabled_AndCouponPercentOffIsNull_LogsErrorAndSendsTraditionalEmail() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123", Type = "invoice.upcoming" }; + var customerId = "cus_123"; + var subscriptionId = "sub_123"; + var passwordManagerItemId = "si_pm_123"; + + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 40000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + + var families2019Plan = new Families2019Plan(); + var familiesPlan = new FamiliesPlan(); + + var subscription = new Subscription + { + Id = subscriptionId, + CustomerId = customerId, + Items = new StripeList + { + Data = + [ + new() + { + Id = passwordManagerItemId, + Price = new Price { Id = families2019Plan.PasswordManager.StripePlanId } + } + ] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Metadata = new Dictionary() + }; + + var customer = new Customer + { + Id = customerId, + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" } + }; + + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.FamiliesAnnually2019 + }; + + var coupon = new Coupon + { + Id = CouponIDs.Milestone3SubscriptionDiscount, + PercentOff = null + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + _organizationRepository.GetByIdAsync(_organizationId).Returns(organization); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019).Returns(families2019Plan); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually).Returns(familiesPlan); + _featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); + _stripeEventUtilityService.IsSponsoredSubscription(subscription).Returns(false); + _stripeFacade.GetCoupon(CouponIDs.Milestone3SubscriptionDiscount).Returns(coupon); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert - Exception is caught, error is logged, and traditional email is sent + _logger.Received(1).Log( + LogLevel.Error, + Arg.Any(), + Arg.Is(o => + o.ToString().Contains($"Failed to align subscription concerns for Organization ({_organizationId})") && + o.ToString().Contains(parsedEvent.Type) && + o.ToString().Contains(parsedEvent.Id)), + Arg.Is(e => e is InvalidOperationException && e.Message.Contains("coupon.PercentOff")), + Arg.Any>()); + + await _mailer.DidNotReceive().SendEmail(Arg.Any()); + + await _mailService.Received(1).SendInvoiceUpcoming( + Arg.Is>(emails => emails.Contains("org@example.com")), + Arg.Is(amount => amount == invoice.AmountDue / 100M), + Arg.Is(dueDate => dueDate == invoice.NextPaymentAttempt.Value), + Arg.Is>(items => items.Count == invoice.Lines.Data.Count), + Arg.Is(b => b == true)); + } + + [Fact] + public async Task HandleAsync_WhenMilestone3Enabled_AndSeatAddOnExists_DeletesItem() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123", Type = "invoice.upcoming" }; + var customerId = "cus_123"; + var subscriptionId = "sub_123"; + var passwordManagerItemId = "si_pm_123"; + var seatAddOnItemId = "si_seat_123"; + + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 40000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + + var families2019Plan = new Families2019Plan(); + var familiesPlan = new FamiliesPlan(); + + var subscription = new Subscription + { + Id = subscriptionId, + CustomerId = customerId, + Items = new StripeList + { + Data = + [ + new() + { + Id = passwordManagerItemId, + Price = new Price { Id = families2019Plan.PasswordManager.StripePlanId } + }, + + new() + { + Id = seatAddOnItemId, + Price = new Price { Id = "personal-org-seat-annually" }, + Quantity = 3 + } + ] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Metadata = new Dictionary() + }; + + var customer = new Customer + { + Id = customerId, + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" } + }; + + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.FamiliesAnnually2019 + }; + + var coupon = new Coupon { PercentOff = 25, Id = CouponIDs.Milestone3SubscriptionDiscount }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeFacade.GetCoupon(CouponIDs.Milestone3SubscriptionDiscount).Returns(coupon); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + _organizationRepository.GetByIdAsync(_organizationId).Returns(organization); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019).Returns(families2019Plan); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually).Returns(familiesPlan); + _featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); + _stripeEventUtilityService.IsSponsoredSubscription(subscription).Returns(false); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription( + Arg.Is(subscriptionId), + Arg.Is(o => + o.Items.Count == 2 && + o.Items[0].Id == passwordManagerItemId && + o.Items[0].Price == familiesPlan.PasswordManager.StripePlanId && + o.Items[1].Id == seatAddOnItemId && + o.Items[1].Deleted == true && + o.Discounts.Count == 1 && + o.Discounts[0].Coupon == CouponIDs.Milestone3SubscriptionDiscount && + o.ProrationBehavior == ProrationBehavior.None)); + + await _stripeFacade.Received(1).GetCoupon(CouponIDs.Milestone3SubscriptionDiscount); + + await _organizationRepository.Received(1).ReplaceAsync( + Arg.Is(org => + org.Id == _organizationId && + org.PlanType == PlanType.FamiliesAnnually && + org.Plan == familiesPlan.Name && + org.UsersGetPremium == familiesPlan.UsersGetPremium && + org.Seats == familiesPlan.PasswordManager.BaseSeats)); + + await _mailer.Received(1).SendEmail( + Arg.Is(email => + email.ToEmails.Contains("org@example.com") && + email.Subject == "Your Bitwarden Families renewal is updating" && + email.View.BaseMonthlyRenewalPrice == (familiesPlan.PasswordManager.BasePrice / 12).ToString("C", new CultureInfo("en-US")) && + email.View.BaseAnnualRenewalPrice == familiesPlan.PasswordManager.BasePrice.ToString("C", new CultureInfo("en-US")) && + email.View.DiscountAmount == $"{coupon.PercentOff}%" + )); + } + + [Fact] + public async Task HandleAsync_WhenMilestone3Enabled_AndSeatAddOnWithQuantityOne_DeletesItem() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123", Type = "invoice.upcoming" }; + var customerId = "cus_123"; + var subscriptionId = "sub_123"; + var passwordManagerItemId = "si_pm_123"; + var seatAddOnItemId = "si_seat_123"; + + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 40000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + + var families2019Plan = new Families2019Plan(); + var familiesPlan = new FamiliesPlan(); + + var subscription = new Subscription + { + Id = subscriptionId, + CustomerId = customerId, + Items = new StripeList + { + Data = + [ + new() + { + Id = passwordManagerItemId, + Price = new Price { Id = families2019Plan.PasswordManager.StripePlanId } + }, + + new() + { + Id = seatAddOnItemId, + Price = new Price { Id = "personal-org-seat-annually" }, + Quantity = 1 + } + ] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Metadata = new Dictionary() + }; + + var customer = new Customer + { + Id = customerId, + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" } + }; + + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.FamiliesAnnually2019 + }; + + var coupon = new Coupon { PercentOff = 25, Id = CouponIDs.Milestone3SubscriptionDiscount }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeFacade.GetCoupon(CouponIDs.Milestone3SubscriptionDiscount).Returns(coupon); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + _organizationRepository.GetByIdAsync(_organizationId).Returns(organization); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019).Returns(families2019Plan); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually).Returns(familiesPlan); + _featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); + _stripeEventUtilityService.IsSponsoredSubscription(subscription).Returns(false); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription( + Arg.Is(subscriptionId), + Arg.Is(o => + o.Items.Count == 2 && + o.Items[0].Id == passwordManagerItemId && + o.Items[0].Price == familiesPlan.PasswordManager.StripePlanId && + o.Items[1].Id == seatAddOnItemId && + o.Items[1].Deleted == true && + o.Discounts.Count == 1 && + o.Discounts[0].Coupon == CouponIDs.Milestone3SubscriptionDiscount && + o.ProrationBehavior == ProrationBehavior.None)); + + await _stripeFacade.Received(1).GetCoupon(CouponIDs.Milestone3SubscriptionDiscount); + + await _organizationRepository.Received(1).ReplaceAsync( + Arg.Is(org => + org.Id == _organizationId && + org.PlanType == PlanType.FamiliesAnnually && + org.Plan == familiesPlan.Name && + org.UsersGetPremium == familiesPlan.UsersGetPremium && + org.Seats == familiesPlan.PasswordManager.BaseSeats)); + + await _mailer.Received(1).SendEmail( + Arg.Is(email => + email.ToEmails.Contains("org@example.com") && + email.Subject == "Your Bitwarden Families renewal is updating" && + email.View.BaseMonthlyRenewalPrice == (familiesPlan.PasswordManager.BasePrice / 12).ToString("C", new CultureInfo("en-US")) && + email.View.BaseAnnualRenewalPrice == familiesPlan.PasswordManager.BasePrice.ToString("C", new CultureInfo("en-US")) && + email.View.DiscountAmount == $"{coupon.PercentOff}%" + )); + } + + [Fact] + public async Task HandleAsync_WhenMilestone3Enabled_WithPremiumAccessAndSeatAddOn_UpdatesBothItems() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123", Type = "invoice.upcoming" }; + var customerId = "cus_123"; + var subscriptionId = "sub_123"; + var passwordManagerItemId = "si_pm_123"; + var premiumAccessItemId = "si_premium_123"; + var seatAddOnItemId = "si_seat_123"; + + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 40000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + + var families2019Plan = new Families2019Plan(); + var familiesPlan = new FamiliesPlan(); + + var subscription = new Subscription + { + Id = subscriptionId, + CustomerId = customerId, + Items = new StripeList + { + Data = + [ + new() + { + Id = passwordManagerItemId, + Price = new Price { Id = families2019Plan.PasswordManager.StripePlanId } + }, + + new() + { + Id = premiumAccessItemId, + Price = new Price { Id = families2019Plan.PasswordManager.StripePremiumAccessPlanId } + }, + + new() + { + Id = seatAddOnItemId, + Price = new Price { Id = "personal-org-seat-annually" }, + Quantity = 2 + } + ] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Metadata = new Dictionary() + }; + + var customer = new Customer + { + Id = customerId, + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" } + }; + + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.FamiliesAnnually2019 + }; + + var coupon = new Coupon { PercentOff = 25, Id = CouponIDs.Milestone3SubscriptionDiscount }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeFacade.GetCoupon(CouponIDs.Milestone3SubscriptionDiscount).Returns(coupon); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + _organizationRepository.GetByIdAsync(_organizationId).Returns(organization); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2019).Returns(families2019Plan); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually).Returns(familiesPlan); + _featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); + _stripeEventUtilityService.IsSponsoredSubscription(subscription).Returns(false); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription( + Arg.Is(subscriptionId), + Arg.Is(o => + o.Items.Count == 3 && + o.Items[0].Id == passwordManagerItemId && + o.Items[0].Price == familiesPlan.PasswordManager.StripePlanId && + o.Items[1].Id == premiumAccessItemId && + o.Items[1].Deleted == true && + o.Items[2].Id == seatAddOnItemId && + o.Items[2].Deleted == true && + o.Discounts.Count == 1 && + o.Discounts[0].Coupon == CouponIDs.Milestone3SubscriptionDiscount && + o.ProrationBehavior == ProrationBehavior.None)); + + await _stripeFacade.Received(1).GetCoupon(CouponIDs.Milestone3SubscriptionDiscount); + + await _organizationRepository.Received(1).ReplaceAsync( + Arg.Is(org => + org.Id == _organizationId && + org.PlanType == PlanType.FamiliesAnnually && + org.Plan == familiesPlan.Name && + org.UsersGetPremium == familiesPlan.UsersGetPremium && + org.Seats == familiesPlan.PasswordManager.BaseSeats)); + + await _mailer.Received(1).SendEmail( + Arg.Is(email => + email.ToEmails.Contains("org@example.com") && + email.Subject == "Your Bitwarden Families renewal is updating" && + email.View.BaseMonthlyRenewalPrice == (familiesPlan.PasswordManager.BasePrice / 12).ToString("C", new CultureInfo("en-US")) && + email.View.BaseAnnualRenewalPrice == familiesPlan.PasswordManager.BasePrice.ToString("C", new CultureInfo("en-US")) && + email.View.DiscountAmount == $"{coupon.PercentOff}%" + )); + } + + [Fact] + public async Task HandleAsync_WhenMilestone3Enabled_AndFamilies2025Plan_UpdatesSubscriptionOnlyNoAddons() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123", Type = "invoice.upcoming" }; + var customerId = "cus_123"; + var subscriptionId = "sub_123"; + var passwordManagerItemId = "si_pm_123"; + + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 40000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + + var families2025Plan = new Families2025Plan(); + var familiesPlan = new FamiliesPlan(); + + var subscription = new Subscription + { + Id = subscriptionId, + CustomerId = customerId, + Items = new StripeList + { + Data = + [ + new() + { + Id = passwordManagerItemId, + Price = new Price { Id = families2025Plan.PasswordManager.StripePlanId } + } + ] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Metadata = new Dictionary() + }; + + var customer = new Customer + { + Id = customerId, + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" } + }; + + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.FamiliesAnnually2025 + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + _organizationRepository.GetByIdAsync(_organizationId).Returns(organization); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2025).Returns(families2025Plan); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually).Returns(familiesPlan); + _featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); + _stripeEventUtilityService.IsSponsoredSubscription(subscription).Returns(false); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _stripeFacade.Received(1).UpdateSubscription( + Arg.Is(subscriptionId), + Arg.Is(o => + o.Items.Count == 1 && + o.Items[0].Id == passwordManagerItemId && + o.Items[0].Price == familiesPlan.PasswordManager.StripePlanId && + o.Discounts == null && + o.ProrationBehavior == ProrationBehavior.None)); + + await _organizationRepository.Received(1).ReplaceAsync( + Arg.Is(org => + org.Id == _organizationId && + org.PlanType == PlanType.FamiliesAnnually && + org.Plan == familiesPlan.Name && + org.UsersGetPremium == familiesPlan.UsersGetPremium && + org.Seats == familiesPlan.PasswordManager.BaseSeats)); + + await _mailer.Received(1).SendEmail( + Arg.Is(email => + email.ToEmails.Contains("org@example.com") && + email.Subject == "Your Bitwarden Families renewal is updating" && + email.View.MonthlyRenewalPrice == (familiesPlan.PasswordManager.BasePrice / 12).ToString("C", new CultureInfo("en-US")))); + } + + [Fact] + public async Task HandleAsync_WhenMilestone3Disabled_AndFamilies2025Plan_DoesNotUpdateSubscription() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123", Type = "invoice.upcoming" }; + var customerId = "cus_123"; + var subscriptionId = "sub_123"; + var passwordManagerItemId = "si_pm_123"; + + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 40000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + + var families2025Plan = new Families2025Plan(); + + var subscription = new Subscription + { + Id = subscriptionId, + CustomerId = customerId, + Items = new StripeList + { + Data = + [ + new() + { + Id = passwordManagerItemId, + Price = new Price { Id = families2025Plan.PasswordManager.StripePlanId } + } + ] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, + Metadata = new Dictionary() + }; + + var customer = new Customer + { + Id = customerId, + Subscriptions = new StripeList { Data = [subscription] }, + Address = new Address { Country = "US" } + }; + + var organization = new Organization + { + Id = _organizationId, + BillingEmail = "org@example.com", + PlanType = PlanType.FamiliesAnnually2025 + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService + .GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(_organizationId, null, null)); + _organizationRepository.GetByIdAsync(_organizationId).Returns(organization); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually2025).Returns(families2025Plan); + _featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(false); + _stripeEventUtilityService.IsSponsoredSubscription(subscription).Returns(false); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert - should not update subscription or organization when feature flag is disabled + await _stripeFacade.DidNotReceive().UpdateSubscription( + Arg.Any(), + Arg.Any()); + + await _organizationRepository.DidNotReceive().ReplaceAsync( + Arg.Is(org => org.PlanType == PlanType.FamiliesAnnually)); + } + + #region Premium Renewal Email Tests + + [Fact] + public async Task HandleAsync_WhenMilestone2Enabled_AndCouponNotFound_LogsErrorAndSendsTraditionalEmail() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var customerId = "cus_123"; + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 10000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = customerId, + Items = new StripeList + { + Data = [new() { Id = "si_123", Price = new Price { Id = Prices.PremiumAnnually } }] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }, + Customer = new Customer { Id = customerId }, + Metadata = new Dictionary() + }; + var user = new User { Id = _userId, Email = "user@example.com", Premium = true }; + var plan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new Purchasable { Price = 10M, StripePriceId = Prices.PremiumAnnually }, + Storage = new Purchasable { Price = 4M, StripePriceId = Prices.StoragePlanPersonal } + }; + var customer = new Customer + { + Id = customerId, + Tax = new CustomerTax { AutomaticTax = AutomaticTaxStatus.Supported }, + Subscriptions = new StripeList { Data = [subscription] } + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(null, _userId, null)); + _userRepository.GetByIdAsync(_userId).Returns(user); + _pricingClient.GetAvailablePremiumPlan().Returns(plan); + _featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true); + _stripeFacade.GetCoupon(CouponIDs.Milestone2SubscriptionDiscount).Returns((Coupon)null); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert - Exception is caught, error is logged, and traditional email is sent + _logger.Received(1).Log( + LogLevel.Error, + Arg.Any(), + Arg.Is(o => + o.ToString().Contains($"Failed to update user's ({user.Id}) subscription price id") && + o.ToString().Contains(parsedEvent.Id)), + Arg.Is(e => e is InvalidOperationException + && e.Message == $"Coupon for sending premium renewal email id:{CouponIDs.Milestone2SubscriptionDiscount} not found"), + Arg.Any>()); + + await _mailer.DidNotReceive().SendEmail(Arg.Any()); + + await _mailService.Received(1).SendInvoiceUpcoming( + Arg.Is>(emails => emails.Contains("user@example.com")), + Arg.Is(amount => amount == invoice.AmountDue / 100M), + Arg.Is(dueDate => dueDate == invoice.NextPaymentAttempt.Value), + Arg.Is>(items => items.Count == invoice.Lines.Data.Count), + Arg.Is(b => b == true)); + } + + [Fact] + public async Task HandleAsync_WhenMilestone2Enabled_AndCouponPercentOffIsNull_LogsErrorAndSendsTraditionalEmail() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var customerId = "cus_123"; + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 10000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = customerId, + Items = new StripeList + { + Data = [new() { Id = "si_123", Price = new Price { Id = Prices.PremiumAnnually } }] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }, + Customer = new Customer { Id = customerId }, + Metadata = new Dictionary() + }; + var user = new User { Id = _userId, Email = "user@example.com", Premium = true }; + var plan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new Purchasable { Price = 10M, StripePriceId = Prices.PremiumAnnually }, + Storage = new Purchasable { Price = 4M, StripePriceId = Prices.StoragePlanPersonal } + }; + var customer = new Customer + { + Id = customerId, + Tax = new CustomerTax { AutomaticTax = AutomaticTaxStatus.Supported }, + Subscriptions = new StripeList { Data = [subscription] } + }; + var coupon = new Coupon + { + Id = CouponIDs.Milestone2SubscriptionDiscount, + PercentOff = null + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(null, _userId, null)); + _userRepository.GetByIdAsync(_userId).Returns(user); + _pricingClient.GetAvailablePremiumPlan().Returns(plan); + _featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true); + _stripeFacade.GetCoupon(CouponIDs.Milestone2SubscriptionDiscount).Returns(coupon); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert - Exception is caught, error is logged, and traditional email is sent + _logger.Received(1).Log( + LogLevel.Error, + Arg.Any(), + Arg.Is(o => + o.ToString().Contains($"Failed to update user's ({user.Id}) subscription price id") && + o.ToString().Contains(parsedEvent.Id)), + Arg.Is(e => e is InvalidOperationException + && e.Message == $"coupon.PercentOff for sending premium renewal email id:{CouponIDs.Milestone2SubscriptionDiscount} is null"), + Arg.Any>()); + + await _mailer.DidNotReceive().SendEmail(Arg.Any()); + + await _mailService.Received(1).SendInvoiceUpcoming( + Arg.Is>(emails => emails.Contains("user@example.com")), + Arg.Is(amount => amount == invoice.AmountDue / 100M), + Arg.Is(dueDate => dueDate == invoice.NextPaymentAttempt.Value), + Arg.Is>(items => items.Count == invoice.Lines.Data.Count), + Arg.Is(b => b == true)); + } + + [Fact] + public async Task HandleAsync_WhenMilestone2Enabled_AndValidCoupon_SendsPremiumRenewalEmail() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var customerId = "cus_123"; + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 10000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = customerId, + Items = new StripeList + { + Data = [new() { Id = "si_123", Price = new Price { Id = Prices.PremiumAnnually } }] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }, + Customer = new Customer { Id = customerId }, + Metadata = new Dictionary() + }; + var user = new User { Id = _userId, Email = "user@example.com", Premium = true }; + var plan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new Purchasable { Price = 10M, StripePriceId = Prices.PremiumAnnually }, + Storage = new Purchasable { Price = 4M, StripePriceId = Prices.StoragePlanPersonal } + }; + var customer = new Customer + { + Id = customerId, + Tax = new CustomerTax { AutomaticTax = AutomaticTaxStatus.Supported }, + Subscriptions = new StripeList { Data = [subscription] } + }; + var coupon = new Coupon + { + Id = CouponIDs.Milestone2SubscriptionDiscount, + PercentOff = 30 + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(null, _userId, null)); + _userRepository.GetByIdAsync(_userId).Returns(user); + _pricingClient.GetAvailablePremiumPlan().Returns(plan); + _featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true); + _stripeFacade.GetCoupon(CouponIDs.Milestone2SubscriptionDiscount).Returns(coupon); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + var expectedDiscountedPrice = plan.Seat.Price * (100 - coupon.PercentOff.Value) / 100; + await _mailer.Received(1).SendEmail( + Arg.Is(email => + email.ToEmails.Contains("user@example.com") && + email.Subject == "Your Bitwarden Premium renewal is updating" && + email.View.BaseMonthlyRenewalPrice == (plan.Seat.Price / 12).ToString("C", new CultureInfo("en-US")) && + email.View.DiscountAmount == "30%" && + email.View.DiscountedMonthlyRenewalPrice == (expectedDiscountedPrice / 12).ToString("C", new CultureInfo("en-US")) + )); + + await _mailService.DidNotReceive().SendInvoiceUpcoming( + Arg.Any>(), + Arg.Any(), + Arg.Any(), + Arg.Any>(), + Arg.Any()); + } + + [Fact] + public async Task HandleAsync_WhenMilestone2Enabled_AndGetCouponThrowsException_LogsErrorAndSendsTraditionalEmail() + { + // Arrange + var parsedEvent = new Event { Id = "evt_123" }; + var customerId = "cus_123"; + var invoice = new Invoice + { + CustomerId = customerId, + AmountDue = 10000, + NextPaymentAttempt = DateTime.UtcNow.AddDays(7), + Lines = new StripeList + { + Data = [new() { Description = "Test Item" }] + } + }; + var subscription = new Subscription + { + Id = "sub_123", + CustomerId = customerId, + Items = new StripeList + { + Data = [new() { Id = "si_123", Price = new Price { Id = Prices.PremiumAnnually } }] + }, + AutomaticTax = new SubscriptionAutomaticTax { Enabled = false }, + Customer = new Customer { Id = customerId }, + Metadata = new Dictionary() + }; + var user = new User { Id = _userId, Email = "user@example.com", Premium = true }; + var plan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new Purchasable { Price = 10M, StripePriceId = Prices.PremiumAnnually }, + Storage = new Purchasable { Price = 4M, StripePriceId = Prices.StoragePlanPersonal } + }; + var customer = new Customer + { + Id = customerId, + Tax = new CustomerTax { AutomaticTax = AutomaticTaxStatus.Supported }, + Subscriptions = new StripeList { Data = [subscription] } + }; + + _stripeEventService.GetInvoice(parsedEvent).Returns(invoice); + _stripeFacade.GetCustomer(customerId, Arg.Any()).Returns(customer); + _stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata) + .Returns(new Tuple(null, _userId, null)); + _userRepository.GetByIdAsync(_userId).Returns(user); + _pricingClient.GetAvailablePremiumPlan().Returns(plan); + _featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2).Returns(true); + _stripeFacade.GetCoupon(CouponIDs.Milestone2SubscriptionDiscount) + .ThrowsAsync(new StripeException("Stripe API error")); + _stripeFacade.UpdateSubscription(Arg.Any(), Arg.Any()) + .Returns(subscription); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert - Exception is caught, error is logged, and traditional email is sent + _logger.Received(1).Log( + LogLevel.Error, + Arg.Any(), + Arg.Is(o => + o.ToString().Contains($"Failed to update user's ({user.Id}) subscription price id") && + o.ToString().Contains(parsedEvent.Id)), + Arg.Is(e => e is StripeException), + Arg.Any>()); + + await _mailer.DidNotReceive().SendEmail(Arg.Any()); + + await _mailService.Received(1).SendInvoiceUpcoming( + Arg.Is>(emails => emails.Contains("user@example.com")), + Arg.Is(amount => amount == invoice.AmountDue / 100M), + Arg.Is(dueDate => dueDate == invoice.NextPaymentAttempt.Value), + Arg.Is>(items => items.Count == invoice.Lines.Data.Count), + Arg.Is(b => b == true)); + } + + #endregion +} diff --git a/test/Billing.Test/Utilities/StripeTestEvents.cs b/test/Billing.Test/Utilities/StripeTestEvents.cs deleted file mode 100644 index 86792af812..0000000000 --- a/test/Billing.Test/Utilities/StripeTestEvents.cs +++ /dev/null @@ -1,35 +0,0 @@ -using Stripe; - -namespace Bit.Billing.Test.Utilities; - -public enum StripeEventType -{ - ChargeSucceeded, - CustomerSubscriptionUpdated, - CustomerUpdated, - InvoiceCreated, - InvoiceFinalized, - InvoiceUpcoming, - PaymentMethodAttached -} - -public static class StripeTestEvents -{ - public static async Task GetAsync(StripeEventType eventType) - { - var fileName = eventType switch - { - StripeEventType.ChargeSucceeded => "charge.succeeded.json", - StripeEventType.CustomerSubscriptionUpdated => "customer.subscription.updated.json", - StripeEventType.CustomerUpdated => "customer.updated.json", - StripeEventType.InvoiceCreated => "invoice.created.json", - StripeEventType.InvoiceFinalized => "invoice.finalized.json", - StripeEventType.InvoiceUpcoming => "invoice.upcoming.json", - StripeEventType.PaymentMethodAttached => "payment_method.attached.json" - }; - - var resource = await EmbeddedResourceReader.ReadAsync("Events", fileName); - - return EventUtility.ParseEvent(resource); - } -} diff --git a/test/Core.IntegrationTest/Core.IntegrationTest.csproj b/test/Core.IntegrationTest/Core.IntegrationTest.csproj index 21b746c2fb..133793d3d8 100644 --- a/test/Core.IntegrationTest/Core.IntegrationTest.csproj +++ b/test/Core.IntegrationTest/Core.IntegrationTest.csproj @@ -11,11 +11,11 @@ - - + + - - + + diff --git a/test/Core.IntegrationTest/MailKitSmtpMailDeliveryServiceTests.cs b/test/Core.IntegrationTest/MailKitSmtpMailDeliveryServiceTests.cs index 06f333b05c..1883036f9c 100644 --- a/test/Core.IntegrationTest/MailKitSmtpMailDeliveryServiceTests.cs +++ b/test/Core.IntegrationTest/MailKitSmtpMailDeliveryServiceTests.cs @@ -1,7 +1,7 @@ using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; using Bit.Core.Models.Mail; -using Bit.Core.Services; +using Bit.Core.Platform.Mail.Delivery; using Bit.Core.Settings; using MailKit.Security; using Microsoft.Extensions.Logging; diff --git a/test/Core.Test/AdminConsole/AutoFixture/CurrentContextOrganizationFixtures.cs b/test/Core.Test/AdminConsole/AutoFixture/CurrentContextOrganizationFixtures.cs index 080b8ec62e..1c809f604d 100644 --- a/test/Core.Test/AdminConsole/AutoFixture/CurrentContextOrganizationFixtures.cs +++ b/test/Core.Test/AdminConsole/AutoFixture/CurrentContextOrganizationFixtures.cs @@ -1,4 +1,6 @@ -using AutoFixture; +using System.Reflection; +using AutoFixture; +using AutoFixture.Xunit2; using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Models.Data; @@ -23,6 +25,7 @@ public class CurrentContextOrganizationCustomization : ICustomization } } +[AttributeUsage(AttributeTargets.Method)] public class CurrentContextOrganizationCustomizeAttribute : BitCustomizeAttribute { public Guid Id { get; set; } @@ -38,3 +41,19 @@ public class CurrentContextOrganizationCustomizeAttribute : BitCustomizeAttribut AccessSecretsManager = AccessSecretsManager }; } + +public class CurrentContextOrganizationAttribute : CustomizeAttribute +{ + public Guid Id { get; set; } + public OrganizationUserType Type { get; set; } = OrganizationUserType.User; + public Permissions Permissions { get; set; } = new(); + public bool AccessSecretsManager { get; set; } = false; + + public override ICustomization GetCustomization(ParameterInfo _) => new CurrentContextOrganizationCustomization + { + Id = Id, + Type = Type, + Permissions = Permissions, + AccessSecretsManager = AccessSecretsManager + }; +} diff --git a/test/Core.Test/AdminConsole/AutoFixture/OrganizationFixtures.cs b/test/Core.Test/AdminConsole/AutoFixture/OrganizationFixtures.cs index e906862e3f..c874fe58d8 100644 --- a/test/Core.Test/AdminConsole/AutoFixture/OrganizationFixtures.cs +++ b/test/Core.Test/AdminConsole/AutoFixture/OrganizationFixtures.cs @@ -1,6 +1,8 @@ -using System.Text.Json; +using System.Reflection; +using System.Text.Json; using AutoFixture; using AutoFixture.Kernel; +using AutoFixture.Xunit2; using Bit.Core.AdminConsole.Entities; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models; @@ -9,7 +11,7 @@ using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Business; using Bit.Core.Models.Data; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.AspNetCore.DataProtection; @@ -20,12 +22,24 @@ public class OrganizationCustomization : ICustomization { public bool UseGroups { get; set; } public PlanType PlanType { get; set; } + public bool UseAutomaticUserConfirmation { get; set; } + + public OrganizationCustomization() + { + + } + + public OrganizationCustomization(bool useAutomaticUserConfirmation, PlanType planType) + { + UseAutomaticUserConfirmation = useAutomaticUserConfirmation; + PlanType = planType; + } public void Customize(IFixture fixture) { var organizationId = Guid.NewGuid(); var maxCollections = (short)new Random().Next(10, short.MaxValue); - var plan = StaticStore.Plans.FirstOrDefault(p => p.Type == PlanType); + var plan = MockPlans.Plans.FirstOrDefault(p => p.Type == PlanType); var seats = (short)new Random().Next(plan.PasswordManager.BaseSeats, plan.PasswordManager.MaxSeats ?? short.MaxValue); var smSeats = plan.SupportsSecretsManager ? (short?)new Random().Next(plan.SecretsManager.BaseSeats, plan.SecretsManager.MaxSeats ?? short.MaxValue) @@ -37,7 +51,8 @@ public class OrganizationCustomization : ICustomization .With(o => o.UseGroups, UseGroups) .With(o => o.PlanType, PlanType) .With(o => o.Seats, seats) - .With(o => o.SmSeats, smSeats)); + .With(o => o.SmSeats, smSeats) + .With(o => o.UseAutomaticUserConfirmation, UseAutomaticUserConfirmation)); fixture.Customize(composer => composer @@ -77,7 +92,7 @@ internal class PaidOrganization : ICustomization public PlanType CheckedPlanType { get; set; } public void Customize(IFixture fixture) { - var validUpgradePlans = StaticStore.Plans.Where(p => p.Type != PlanType.Free && p.LegacyYear == null).OrderBy(p => p.UpgradeSortOrder).Select(p => p.Type).ToList(); + var validUpgradePlans = MockPlans.Plans.Where(p => p.Type != PlanType.Free && p.LegacyYear == null).OrderBy(p => p.UpgradeSortOrder).Select(p => p.Type).ToList(); var lowestActivePaidPlan = validUpgradePlans.First(); CheckedPlanType = CheckedPlanType.Equals(PlanType.Free) ? lowestActivePaidPlan : CheckedPlanType; validUpgradePlans.Remove(lowestActivePaidPlan); @@ -105,7 +120,7 @@ internal class FreeOrganizationUpgrade : ICustomization .With(o => o.PlanType, PlanType.Free)); var plansToIgnore = new List { PlanType.Free, PlanType.Custom }; - var selectedPlan = StaticStore.Plans.Last(p => !plansToIgnore.Contains(p.Type) && !p.Disabled); + var selectedPlan = MockPlans.Plans.Last(p => !plansToIgnore.Contains(p.Type) && !p.Disabled); fixture.Customize(composer => composer .With(ou => ou.Plan, selectedPlan.Type) @@ -153,7 +168,7 @@ public class SecretsManagerOrganizationCustomization : ICustomization .With(o => o.Id, organizationId) .With(o => o.UseSecretsManager, true) .With(o => o.PlanType, planType) - .With(o => o.Plan, StaticStore.GetPlan(planType).Name) + .With(o => o.Plan, MockPlans.Get(planType).Name) .With(o => o.MaxAutoscaleSmSeats, (int?)null) .With(o => o.MaxAutoscaleSmServiceAccounts, (int?)null)); } @@ -277,3 +292,9 @@ internal class EphemeralDataProtectionAutoDataAttribute : CustomAutoDataAttribut public EphemeralDataProtectionAutoDataAttribute() : base(new SutProviderCustomization(), new EphemeralDataProtectionCustomization()) { } } + +internal class OrganizationAttribute(bool useAutomaticUserConfirmation = false, PlanType planType = PlanType.Free) : CustomizeAttribute +{ + public override ICustomization GetCustomization(ParameterInfo parameter) => + new OrganizationCustomization(useAutomaticUserConfirmation, planType); +} diff --git a/test/Core.Test/AdminConsole/AutoFixture/OrganizationUserPolicyDetailsFixtures.cs b/test/Core.Test/AdminConsole/AutoFixture/OrganizationUserPolicyDetailsFixtures.cs index 634b234e70..53511de550 100644 --- a/test/Core.Test/AdminConsole/AutoFixture/OrganizationUserPolicyDetailsFixtures.cs +++ b/test/Core.Test/AdminConsole/AutoFixture/OrganizationUserPolicyDetailsFixtures.cs @@ -2,6 +2,7 @@ using AutoFixture; using AutoFixture.Xunit2; using Bit.Core.AdminConsole.Enums; +using Bit.Core.Enums; using Bit.Core.Models.Data.Organizations.OrganizationUsers; namespace Bit.Core.Test.AdminConsole.AutoFixture; @@ -9,10 +10,16 @@ namespace Bit.Core.Test.AdminConsole.AutoFixture; internal class OrganizationUserPolicyDetailsCustomization : ICustomization { public PolicyType Type { get; set; } + public OrganizationUserStatusType Status { get; set; } + public OrganizationUserType UserType { get; set; } + public bool IsProvider { get; set; } - public OrganizationUserPolicyDetailsCustomization(PolicyType type) + public OrganizationUserPolicyDetailsCustomization(PolicyType type, OrganizationUserStatusType status, OrganizationUserType userType, bool isProvider) { Type = type; + Status = status; + UserType = userType; + IsProvider = isProvider; } public void Customize(IFixture fixture) @@ -20,6 +27,9 @@ internal class OrganizationUserPolicyDetailsCustomization : ICustomization fixture.Customize(composer => composer .With(o => o.OrganizationId, Guid.NewGuid()) .With(o => o.PolicyType, Type) + .With(o => o.OrganizationUserStatus, Status) + .With(o => o.OrganizationUserType, UserType) + .With(o => o.IsProvider, IsProvider) .With(o => o.PolicyEnabled, true)); } } @@ -27,14 +37,25 @@ internal class OrganizationUserPolicyDetailsCustomization : ICustomization public class OrganizationUserPolicyDetailsAttribute : CustomizeAttribute { private readonly PolicyType _type; + private readonly OrganizationUserStatusType _status; + private readonly OrganizationUserType _userType; + private readonly bool _isProvider; - public OrganizationUserPolicyDetailsAttribute(PolicyType type) + public OrganizationUserPolicyDetailsAttribute(PolicyType type) : this(type, OrganizationUserStatusType.Accepted, OrganizationUserType.User, false) { _type = type; } + public OrganizationUserPolicyDetailsAttribute(PolicyType type, OrganizationUserStatusType status, OrganizationUserType userType, bool isProvider) + { + _type = type; + _status = status; + _userType = userType; + _isProvider = isProvider; + } + public override ICustomization GetCustomization(ParameterInfo parameter) { - return new OrganizationUserPolicyDetailsCustomization(_type); + return new OrganizationUserPolicyDetailsCustomization(_type, _status, _userType, _isProvider); } } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommandTests.cs new file mode 100644 index 0000000000..88025301b6 --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommandTests.cs @@ -0,0 +1,296 @@ +using AutoFixture; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Platform.Push; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Test.AutoFixture.OrganizationUserFixtures; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.AspNetCore.Identity; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.AccountRecovery; + +[SutProviderCustomize] +public class AdminRecoverAccountCommandTests +{ + [Theory] + [BitAutoData] + public async Task RecoverAccountAsync_Success( + string newMasterPassword, + string key, + Organization organization, + OrganizationUser organizationUser, + User user, + SutProvider sutProvider) + { + // Arrange + SetupValidOrganization(sutProvider, organization); + SetupValidPolicy(sutProvider, organization); + SetupValidOrganizationUser(organizationUser, organization.Id); + SetupValidUser(sutProvider, user, organizationUser); + SetupSuccessfulPasswordUpdate(sutProvider, user, newMasterPassword); + + // Act + var result = await sutProvider.Sut.RecoverAccountAsync(organization.Id, organizationUser, newMasterPassword, key); + + // Assert + Assert.True(result.Succeeded); + await AssertSuccessAsync(sutProvider, user, key, organization, organizationUser); + } + + [Theory] + [BitAutoData] + public async Task RecoverAccountAsync_OrganizationDoesNotExist_ThrowsBadRequest( + [OrganizationUser] OrganizationUser organizationUser, + string newMasterPassword, + string key, + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + sutProvider.GetDependency() + .GetByIdAsync(orgId) + .Returns((Organization)null); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RecoverAccountAsync(orgId, organizationUser, newMasterPassword, key)); + Assert.Equal("Organization does not allow password reset.", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task RecoverAccountAsync_OrganizationDoesNotAllowResetPassword_ThrowsBadRequest( + string newMasterPassword, + string key, + Organization organization, + [OrganizationUser] OrganizationUser organizationUser, + SutProvider sutProvider) + { + // Arrange + organization.UseResetPassword = false; + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RecoverAccountAsync(organization.Id, organizationUser, newMasterPassword, key)); + Assert.Equal("Organization does not allow password reset.", exception.Message); + } + + public static IEnumerable InvalidPolicies => new object[][] + { + [new Policy { Type = PolicyType.ResetPassword, Enabled = false }], [null] + }; + + [Theory] + [BitMemberAutoData(nameof(InvalidPolicies))] + public async Task RecoverAccountAsync_InvalidPolicy_ThrowsBadRequest( + Policy resetPasswordPolicy, + string newMasterPassword, + string key, + Organization organization, + SutProvider sutProvider) + { + // Arrange + SetupValidOrganization(sutProvider, organization); + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword) + .Returns(resetPasswordPolicy); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RecoverAccountAsync(organization.Id, new OrganizationUser { Id = Guid.NewGuid() }, + newMasterPassword, key)); + Assert.Equal("Organization does not have the password reset policy enabled.", exception.Message); + } + + public static IEnumerable InvalidOrganizationUsers() + { + // Make an organization so we can use its Id + var organization = new Fixture().Create(); + + var nonConfirmed = new OrganizationUser + { + Id = Guid.NewGuid(), + OrganizationId = organization.Id, + Status = OrganizationUserStatusType.Invited + }; + yield return [nonConfirmed, organization]; + + var wrongOrganization = new OrganizationUser + { + Status = OrganizationUserStatusType.Confirmed, + OrganizationId = Guid.NewGuid(), // Different org + ResetPasswordKey = "test-key", + UserId = Guid.NewGuid(), + }; + yield return [wrongOrganization, organization]; + + var nullResetPasswordKey = new OrganizationUser + { + Status = OrganizationUserStatusType.Confirmed, + OrganizationId = organization.Id, + ResetPasswordKey = null, + UserId = Guid.NewGuid(), + }; + yield return [nullResetPasswordKey, organization]; + + var emptyResetPasswordKey = new OrganizationUser + { + Status = OrganizationUserStatusType.Confirmed, + OrganizationId = organization.Id, + ResetPasswordKey = "", + UserId = Guid.NewGuid(), + }; + yield return [emptyResetPasswordKey, organization]; + + var nullUserId = new OrganizationUser + { + Status = OrganizationUserStatusType.Confirmed, + OrganizationId = organization.Id, + ResetPasswordKey = "test-key", + UserId = null, + }; + yield return [nullUserId, organization]; + } + + [Theory] + [BitMemberAutoData(nameof(InvalidOrganizationUsers))] + public async Task RecoverAccountAsync_OrganizationUserIsInvalid_ThrowsBadRequest( + OrganizationUser organizationUser, + Organization organization, + string newMasterPassword, + string key, + SutProvider sutProvider) + { + // Arrange + SetupValidOrganization(sutProvider, organization); + SetupValidPolicy(sutProvider, organization); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RecoverAccountAsync(organization.Id, organizationUser, newMasterPassword, key)); + Assert.Equal("Organization User not valid", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task RecoverAccountAsync_UserDoesNotExist_ThrowsNotFoundException( + string newMasterPassword, + string key, + Organization organization, + OrganizationUser organizationUser, + SutProvider sutProvider) + { + // Arrange + SetupValidOrganization(sutProvider, organization); + SetupValidPolicy(sutProvider, organization); + SetupValidOrganizationUser(organizationUser, organization.Id); + sutProvider.GetDependency() + .GetUserByIdAsync(organizationUser.UserId!.Value) + .Returns((User)null); + + // Act & Assert + await Assert.ThrowsAsync(() => + sutProvider.Sut.RecoverAccountAsync(organization.Id, organizationUser, newMasterPassword, key)); + } + + [Theory] + [BitAutoData] + public async Task RecoverAccountAsync_UserUsesKeyConnector_ThrowsBadRequest( + string newMasterPassword, + string key, + Organization organization, + OrganizationUser organizationUser, + User user, + SutProvider sutProvider) + { + // Arrange + SetupValidOrganization(sutProvider, organization); + SetupValidPolicy(sutProvider, organization); + SetupValidOrganizationUser(organizationUser, organization.Id); + user.UsesKeyConnector = true; + sutProvider.GetDependency() + .GetUserByIdAsync(organizationUser.UserId!.Value) + .Returns(user); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RecoverAccountAsync(organization.Id, organizationUser, newMasterPassword, key)); + Assert.Equal("Cannot reset password of a user with Key Connector.", exception.Message); + } + + private static void SetupValidOrganization(SutProvider sutProvider, Organization organization) + { + organization.UseResetPassword = true; + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + } + + private static void SetupValidPolicy(SutProvider sutProvider, Organization organization) + { + var policy = new Policy { Type = PolicyType.ResetPassword, Enabled = true }; + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword) + .Returns(policy); + } + + private static void SetupValidOrganizationUser(OrganizationUser organizationUser, Guid orgId) + { + organizationUser.Status = OrganizationUserStatusType.Confirmed; + organizationUser.OrganizationId = orgId; + organizationUser.ResetPasswordKey = "test-key"; + organizationUser.Type = OrganizationUserType.User; + } + + private static void SetupValidUser(SutProvider sutProvider, User user, OrganizationUser organizationUser) + { + user.Id = organizationUser.UserId!.Value; + user.UsesKeyConnector = false; + sutProvider.GetDependency() + .GetUserByIdAsync(user.Id) + .Returns(user); + } + + private static void SetupSuccessfulPasswordUpdate(SutProvider sutProvider, User user, string newMasterPassword) + { + sutProvider.GetDependency() + .UpdatePasswordHash(user, newMasterPassword) + .Returns(IdentityResult.Success); + } + + private static async Task AssertSuccessAsync(SutProvider sutProvider, User user, string key, + Organization organization, OrganizationUser organizationUser) + { + await sutProvider.GetDependency().Received(1).ReplaceAsync( + Arg.Is(u => + u.Id == user.Id && + u.Key == key && + u.ForcePasswordReset == true && + u.RevisionDate == u.AccountRevisionDate && + u.LastPasswordChangeDate == u.RevisionDate)); + + await sutProvider.GetDependency().Received(1).SendAdminResetPasswordEmailAsync( + Arg.Is(user.Email), + Arg.Is(user.Name), + Arg.Is(organization.DisplayName())); + + await sutProvider.GetDependency().Received(1).LogOrganizationUserEventAsync( + Arg.Is(organizationUser), + Arg.Is(EventType.OrganizationUser_AdminResetPassword)); + + await sutProvider.GetDependency().Received(1).PushLogOutAsync( + Arg.Is(user.Id)); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Import/ImportOrganizationUsersAndGroupsCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Import/ImportOrganizationUsersAndGroupsCommandTests.cs index 933bcbc3a1..efcd57b6ad 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Import/ImportOrganizationUsersAndGroupsCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Import/ImportOrganizationUsersAndGroupsCommandTests.cs @@ -1,6 +1,7 @@ using Bit.Core.AdminConsole.Models.Business; using Bit.Core.AdminConsole.OrganizationFeatures.Import; using Bit.Core.Auth.Models.Business.Tokenables; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -57,7 +58,7 @@ public class ImportOrganizationUsersAndGroupsCommandTests var organizationUserRepository = sutProvider.GetDependency(); SetupOrgUserRepositoryCreateManyAsyncMock(organizationUserRepository); - sutProvider.GetDependency().HasSecretsManagerStandalone(org).Returns(true); + sutProvider.GetDependency().HasSecretsManagerStandalone(org).Returns(true); sutProvider.GetDependency().GetManyDetailsByOrganizationAsync(org.Id).Returns(existingUsers); sutProvider.GetDependency().GetOccupiedSeatCountByOrganizationIdAsync(org.Id).Returns( new OrganizationSeatCounts diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommandTests.cs index b0774927e3..ef4c2c941e 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommandTests.cs @@ -2,8 +2,8 @@ using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains; -using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -182,15 +182,42 @@ public class VerifyOrganizationDomainCommandTests _ = await sutProvider.Sut.UserVerifyOrganizationDomainAsync(domain); - await sutProvider.GetDependency() + await sutProvider.GetDependency() .Received(1) - .SaveAsync(Arg.Is(x => x.Type == PolicyType.SingleOrg && - x.OrganizationId == domain.OrganizationId && - x.Enabled && + .SaveAsync(Arg.Is(x => x.PolicyUpdate.Type == PolicyType.SingleOrg && + x.PolicyUpdate.OrganizationId == domain.OrganizationId && + x.PolicyUpdate.Enabled && x.PerformedBy is StandardUser && x.PerformedBy.UserId == userId)); } + [Theory, BitAutoData] + public async Task UserVerifyOrganizationDomainAsync_UsesVNextSavePolicyCommand( + OrganizationDomain domain, Guid userId, SutProvider sutProvider) + { + sutProvider.GetDependency() + .GetClaimedDomainsByDomainNameAsync(domain.DomainName) + .Returns([]); + + sutProvider.GetDependency() + .ResolveAsync(domain.DomainName, domain.Txt) + .Returns(true); + + sutProvider.GetDependency() + .UserId.Returns(userId); + + _ = await sutProvider.Sut.UserVerifyOrganizationDomainAsync(domain); + + await sutProvider.GetDependency() + .Received(1) + .SaveAsync(Arg.Is(m => + m.PolicyUpdate.Type == PolicyType.SingleOrg && + m.PolicyUpdate.OrganizationId == domain.OrganizationId && + m.PolicyUpdate.Enabled && + m.PerformedBy is StandardUser && + m.PerformedBy.UserId == userId)); + } + [Theory, BitAutoData] public async Task UserVerifyOrganizationDomainAsync_WhenDomainIsNotVerified_ThenSingleOrgPolicyShouldNotBeEnabled( OrganizationDomain domain, SutProvider sutProvider) @@ -208,9 +235,9 @@ public class VerifyOrganizationDomainCommandTests _ = await sutProvider.Sut.UserVerifyOrganizationDomainAsync(domain); - await sutProvider.GetDependency() + await sutProvider.GetDependency() .DidNotReceive() - .SaveAsync(Arg.Any()); + .SaveAsync(Arg.Any()); } [Theory, BitAutoData] diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AcceptOrgUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AcceptOrgUserCommandTests.cs index 540bac4d1c..82d4eceaed 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AcceptOrgUserCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AcceptOrgUserCommandTests.cs @@ -1,7 +1,9 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.Models.Business.Tokenables; @@ -24,6 +26,7 @@ using Bit.Test.Common.Fakes; using Microsoft.AspNetCore.DataProtection; using NSubstitute; using Xunit; +using static Bit.Core.AdminConsole.Utilities.v2.Validation.ValidationResultHelpers; namespace Bit.Core.Test.OrganizationFeatures.OrganizationUsers; @@ -673,6 +676,79 @@ public class AcceptOrgUserCommandTests Assert.Equal("User not found within organization.", exception.Message); } + // Auto-confirm policy validation tests -------------------------------------------------------------------------- + + [Theory] + [BitAutoData] + public async Task AcceptOrgUserAsync_WithAutoConfirmIsNotEnabled_DoesNotCheckCompliance( + SutProvider sutProvider, + User user, Organization org, OrganizationUser orgUser, OrganizationUserUserDetails adminUserDetails) + { + // Arrange + SetupCommonAcceptOrgUserMocks(sutProvider, user, org, orgUser, adminUserDetails); + + // Act + var resultOrgUser = await sutProvider.Sut.AcceptOrgUserAsync(orgUser, user, _userService); + + // Assert + AssertValidAcceptedOrgUser(resultOrgUser, orgUser, user); + + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .IsCompliantAsync(Arg.Any()); + } + + [Theory] + [BitAutoData] + public async Task AcceptOrgUserAsync_WithUserThatIsCompliantWithAutoConfirm_AcceptsUser( + SutProvider sutProvider, + User user, Organization org, OrganizationUser orgUser, OrganizationUserUserDetails adminUserDetails) + { + // Arrange + SetupCommonAcceptOrgUserMocks(sutProvider, user, org, orgUser, adminUserDetails); + + // Mock auto-confirm enforcement query to return valid (no auto-confirm restrictions) + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Any()) + .Returns(Valid(new AutomaticUserConfirmationPolicyEnforcementRequest(org.Id, [orgUser], user))); + + // Act + var resultOrgUser = await sutProvider.Sut.AcceptOrgUserAsync(orgUser, user, _userService); + + // Assert + AssertValidAcceptedOrgUser(resultOrgUser, orgUser, user); + + await sutProvider.GetDependency().Received(1).ReplaceAsync( + Arg.Is(ou => ou.Id == orgUser.Id && ou.Status == OrganizationUserStatusType.Accepted)); + } + + [Theory] + [BitAutoData] + public async Task AcceptOrgUserAsync_WithAutoConfirmIsEnabledAndFailsCompliance_ThrowsBadRequestException( + SutProvider sutProvider, + User user, Organization org, OrganizationUser orgUser, OrganizationUserUserDetails adminUserDetails, + OrganizationUser otherOrgUser) + { + // Arrange + SetupCommonAcceptOrgUserMocks(sutProvider, user, org, orgUser, adminUserDetails); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Any()) + .Returns(Invalid( + new AutomaticUserConfirmationPolicyEnforcementRequest(org.Id, [orgUser, otherOrgUser], user), + new UserCannotBelongToAnotherOrganization())); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.AcceptOrgUserAsync(orgUser, user, _userService)); + + // Should get auto-confirm error + Assert.Equal(new UserCannotBelongToAnotherOrganization().Message, exception.Message); + } + // Private helpers ------------------------------------------------------------------------------------------------- /// @@ -716,7 +792,7 @@ public class AcceptOrgUserCommandTests /// - Provides mock data for an admin to validate email functionality. /// - Returns the corresponding organization for the given org ID. /// - private void SetupCommonAcceptOrgUserMocks(SutProvider sutProvider, User user, + private static void SetupCommonAcceptOrgUserMocks(SutProvider sutProvider, User user, Organization org, OrganizationUser orgUser, OrganizationUserUserDetails adminUserDetails) { @@ -729,18 +805,12 @@ public class AcceptOrgUserCommandTests // User is not part of any other orgs sutProvider.GetDependency() .GetManyByUserAsync(user.Id) - .Returns( - Task.FromResult>(new List()) - ); + .Returns([]); // Org they are trying to join does not have single org policy sutProvider.GetDependency() .GetPoliciesApplicableToUserAsync(user.Id, PolicyType.SingleOrg, OrganizationUserStatusType.Invited) - .Returns( - Task.FromResult>( - new List() - ) - ); + .Returns([]); // User is not part of any organization that applies the single org policy sutProvider.GetDependency() @@ -750,20 +820,24 @@ public class AcceptOrgUserCommandTests // Org does not require 2FA sutProvider.GetDependency().GetPoliciesApplicableToUserAsync(user.Id, PolicyType.TwoFactorAuthentication, OrganizationUserStatusType.Invited) - .Returns(Task.FromResult>( - new List())); + .Returns([]); // Provide at least 1 admin to test email functionality sutProvider.GetDependency() .GetManyByMinimumRoleAsync(orgUser.OrganizationId, OrganizationUserType.Admin) - .Returns(Task.FromResult>( - new List() { adminUserDetails } - )); + .Returns([adminUserDetails]); // Return org sutProvider.GetDependency() .GetByIdAsync(org.Id) - .Returns(Task.FromResult(org)); + .Returns(org); + + // Auto-confirm enforcement query returns valid by default (no restrictions) + var request = new AutomaticUserConfirmationPolicyEnforcementRequest(org.Id, [orgUser], user); + + sutProvider.GetDependency() + .IsCompliantAsync(request) + .Returns(Valid(request)); } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmOrganizationUsersValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmOrganizationUsersValidatorTests.cs new file mode 100644 index 0000000000..c3fb52ecbe --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmOrganizationUsersValidatorTests.cs @@ -0,0 +1,639 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; +using Bit.Core.Billing.Enums; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Test.AdminConsole.AutoFixture; +using Bit.Core.Test.AutoFixture.OrganizationFixtures; +using Bit.Core.Test.AutoFixture.OrganizationUserFixtures; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; +using static Bit.Core.AdminConsole.Utilities.v2.Validation.ValidationResultHelpers; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUsers; + +[SutProviderCustomize] +public class AutomaticallyConfirmOrganizationUsersValidatorTests +{ + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithNullOrganizationUser_ReturnsUserNotFoundError( + SutProvider sutProvider, + Organization organization) + { + // Arrange + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = null, + OrganizationUserId = Guid.NewGuid(), + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithNullUserId_ReturnsUserNotFoundError( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser) + { + // Arrange + organizationUser.UserId = null; + organizationUser.OrganizationId = organization.Id; + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithNullOrganization_ReturnsOrganizationNotFoundError( + SutProvider sutProvider, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + Guid userId) + { + // Arrange + organizationUser.UserId = userId; + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = null, + OrganizationId = organizationUser.OrganizationId, + Key = "test-key" + }; + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithValidAcceptedUser_ReturnsValidResult( + SutProvider sutProvider, + [Organization(useAutomaticUserConfirmation: true, planType: PlanType.EnterpriseAnnually)] Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + [Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation) + .Returns(autoConfirmPolicy); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns([(user.Id, true)]); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([organizationUser]); + + sutProvider.GetDependency() + .GetUserByIdAsync(user.Id) + .Returns(user); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Any()) + .Returns(Valid( + new AutomaticUserConfirmationPolicyEnforcementRequest(organization.Id, + [organizationUser], + user))); + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsValid); + Assert.Equal(request, result.Request); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithMismatchedOrganizationId_ReturnsOrganizationUserIdIsInvalidError( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + Guid userId) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.OrganizationId = Guid.NewGuid(); // Different from organization.Id + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns([organizationUser]); + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData(OrganizationUserStatusType.Invited)] + [BitAutoData(OrganizationUserStatusType.Revoked)] + [BitAutoData(OrganizationUserStatusType.Confirmed)] + public async Task ValidateAsync_WithNotAcceptedStatus_ReturnsUserIsNotAcceptedError( + OrganizationUserStatusType statusType, + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser, + Guid userId) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.OrganizationId = organization.Id; + organizationUser.Status = statusType; + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData(OrganizationUserType.Owner)] + [BitAutoData(OrganizationUserType.Custom)] + [BitAutoData(OrganizationUserType.Admin)] + public async Task ValidateAsync_WithNonUserType_ReturnsUserIsNotUserTypeError( + OrganizationUserType userType, + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + Guid userId) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.OrganizationId = organization.Id; + organizationUser.Type = userType; + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_UserWithout2FA_And2FARequired_ReturnsError( + SutProvider sutProvider, + [Organization(useAutomaticUserConfirmation: true)] Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + Guid userId, + [Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.OrganizationId = organization.Id; + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + var twoFactorPolicyDetails = new PolicyDetails + { + OrganizationId = organization.Id, + PolicyType = PolicyType.TwoFactorAuthentication + }; + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation) + .Returns(autoConfirmPolicy); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns([(userId, false)]); + + sutProvider.GetDependency() + .GetAsync(userId) + .Returns(new RequireTwoFactorPolicyRequirement([twoFactorPolicyDetails])); + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns([organizationUser]); + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_UserWith2FA_ReturnsValidResult( + SutProvider sutProvider, + [Organization(useAutomaticUserConfirmation: true)] Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + [Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation) + .Returns(autoConfirmPolicy); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns([(user.Id, true)]); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([organizationUser]); + + sutProvider.GetDependency() + .GetUserByIdAsync(user.Id) + .Returns(user); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Any()) + .Returns(Valid( + new AutomaticUserConfirmationPolicyEnforcementRequest(organization.Id, + [organizationUser], + user))); + + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsValid); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_UserWithout2FA_And2FANotRequired_ReturnsValidResult( + SutProvider sutProvider, + [Organization(useAutomaticUserConfirmation: true)] Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + [Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation) + .Returns(autoConfirmPolicy); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns([(user.Id, false)]); + + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(new RequireTwoFactorPolicyRequirement([])); // No 2FA policy + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([organizationUser]); + + sutProvider.GetDependency() + .GetUserByIdAsync(user.Id) + .Returns(user); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Any()) + .Returns(Valid( + new AutomaticUserConfirmationPolicyEnforcementRequest(organization.Id, + [organizationUser], + user))); + + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsValid); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_UserInSingleOrg_ReturnsValidResult( + SutProvider sutProvider, + [Organization(useAutomaticUserConfirmation: true)] Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + [Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation) + .Returns(autoConfirmPolicy); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns([(user.Id, true)]); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([organizationUser]); // Single org + + sutProvider.GetDependency() + .GetUserByIdAsync(user.Id) + .Returns(user); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Any()) + .Returns(Valid( + new AutomaticUserConfirmationPolicyEnforcementRequest(organization.Id, + [organizationUser], + user))); + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsValid); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithAutoConfirmPolicyDisabled_ReturnsAutoConfirmPolicyNotEnabledError( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + Guid userId) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.OrganizationId = organization.Id; + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation) + .Returns((Policy)null); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns([(userId, true)]); + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns([organizationUser]); + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithOrganizationUseAutomaticUserConfirmationDisabled_ReturnsAutoConfirmPolicyNotEnabledError( + SutProvider sutProvider, + [Organization(useAutomaticUserConfirmation: false)] Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + Guid userId, + [Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.OrganizationId = organization.Id; + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation) + .Returns(autoConfirmPolicy); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns([(userId, true)]); + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns([organizationUser]); + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithNonProviderUser_ReturnsValidResult( + SutProvider sutProvider, + [Organization(useAutomaticUserConfirmation: true)] Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + [Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + + var request = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = Substitute.For(), + DefaultUserCollectionName = "test-collection", + OrganizationUser = organizationUser, + OrganizationUserId = organizationUser.Id, + Organization = organization, + OrganizationId = organization.Id, + Key = "test-key" + }; + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation) + .Returns(autoConfirmPolicy); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns([(user.Id, true)]); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([organizationUser]); + + sutProvider.GetDependency() + .GetUserByIdAsync(user.Id) + .Returns(user); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Any()) + .Returns(Valid( + new AutomaticUserConfirmationPolicyEnforcementRequest(organization.Id, + [organizationUser], + user))); + + + // Act + var result = await sutProvider.Sut.ValidateAsync(request); + + // Assert + Assert.True(result.IsValid); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmUsersCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmUsersCommandTests.cs new file mode 100644 index 0000000000..1035d5c578 --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmUsersCommandTests.cs @@ -0,0 +1,730 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.Models.Data.OrganizationUsers; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; +using Bit.Core.AdminConsole.Utilities.v2; +using Bit.Core.AdminConsole.Utilities.v2.Validation; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Models.Data; +using Bit.Core.Platform.Push; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Test.AutoFixture.OrganizationUserFixtures; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.Extensions.Logging; +using NSubstitute; +using NSubstitute.ExceptionExtensions; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUsers; + +[SutProviderCustomize] +public class AutomaticallyConfirmUsersCommandTests +{ + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_WithValidRequest_ConfirmsUserSuccessfully( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + Guid performingUserId, + string key, + string defaultCollectionName) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + + var request = new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organization.Id, + Key = key, + DefaultUserCollectionName = defaultCollectionName, + PerformedBy = new StandardUser(performingUserId, true) + }; + + SetupRepositoryMocks(sutProvider, organizationUser, organization, user); + SetupValidatorMock(sutProvider, request, organizationUser, organization, true); + + sutProvider.GetDependency() + .ConfirmOrganizationUserAsync(Arg.Is(o => + o.OrganizationUserId == organizationUser.Id && o.Key == request.Key)) + .Returns(true); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(request); + + // Assert + Assert.True(result.IsSuccess); + + await sutProvider.GetDependency() + .Received(1) + .ConfirmOrganizationUserAsync(Arg.Is(o => + o.OrganizationUserId == organizationUser.Id && o.Key == request.Key)); + + await AssertSuccessfulOperationsAsync(sutProvider, organizationUser, organization, user, key); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_WithInvalidUserOrgId_ReturnsOrganizationUserIdIsInvalidError( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + Guid performingUserId, + string key, + string defaultCollectionName) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = Guid.NewGuid(); // User belongs to another organization + var request = new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organization.Id, + Key = key, + DefaultUserCollectionName = defaultCollectionName, + PerformedBy = new StandardUser(performingUserId, true) + }; + + SetupRepositoryMocks(sutProvider, organizationUser, organization, user); + SetupValidatorMock(sutProvider, request, organizationUser, organization, false, new OrganizationUserIdIsInvalid()); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + + await sutProvider.GetDependency() + .DidNotReceive() + .ConfirmOrganizationUserAsync(Arg.Any()); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_WhenAlreadyConfirmed_ReturnsNoneSuccess( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + Guid performingUserId, + string key, + string defaultCollectionName) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + var request = new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organization.Id, + Key = key, + DefaultUserCollectionName = defaultCollectionName, + PerformedBy = new StandardUser(performingUserId, true) + }; + + SetupRepositoryMocks(sutProvider, organizationUser, organization, user); + SetupValidatorMock(sutProvider, request, organizationUser, organization, true); + + // Return false to indicate the user is already confirmed + sutProvider.GetDependency() + .ConfirmOrganizationUserAsync(Arg.Is(x => + x.OrganizationUserId == organizationUser.Id && x.Key == request.Key)) + .Returns(false); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(request); + + // Assert + Assert.True(result.IsSuccess); + + await sutProvider.GetDependency() + .Received(1) + .ConfirmOrganizationUserAsync(Arg.Is(x => + x.OrganizationUserId == organizationUser.Id && x.Key == request.Key)); + + // Verify no side effects occurred + await sutProvider.GetDependency() + .DidNotReceive() + .LogOrganizationUserEventAsync(Arg.Any(), Arg.Any(), Arg.Any()); + + await sutProvider.GetDependency() + .DidNotReceive() + .PushSyncOrgKeysAsync(Arg.Any()); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_WithDefaultCollectionEnabled_CreatesDefaultCollection( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + Guid performingUserId, + string key, + string defaultCollectionName) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + var request = new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organization.Id, + Key = key, + DefaultUserCollectionName = defaultCollectionName, // Non-empty to trigger creation + PerformedBy = new StandardUser(performingUserId, true) + }; + + SetupRepositoryMocks(sutProvider, organizationUser, organization, user); + SetupValidatorMock(sutProvider, request, organizationUser, organization, true); + SetupPolicyRequirementMock(sutProvider, user.Id, organization.Id, true); // Policy requires collection + + sutProvider.GetDependency().ConfirmOrganizationUserAsync( + Arg.Is(o => + o.OrganizationUserId == organizationUser.Id && o.Key == request.Key)) + .Returns(true); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(request); + + // Assert + Assert.True(result.IsSuccess); + + await sutProvider.GetDependency() + .Received(1) + .CreateAsync( + Arg.Is(c => + c.OrganizationId == organization.Id && + c.Name == defaultCollectionName && + c.Type == CollectionType.DefaultUserCollection), + Arg.Is>(groups => groups == null), + Arg.Is>(access => + access.FirstOrDefault(x => x.Id == organizationUser.Id && x.Manage) != null)); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_WithDefaultCollectionDisabled_DoesNotCreateCollection( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + Guid performingUserId, + string key) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + var request = new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organization.Id, + Key = key, + DefaultUserCollectionName = string.Empty, // Empty, so the collection won't be created + PerformedBy = new StandardUser(performingUserId, true) + }; + + SetupRepositoryMocks(sutProvider, organizationUser, organization, user); + SetupValidatorMock(sutProvider, request, organizationUser, organization, true); + SetupPolicyRequirementMock(sutProvider, user.Id, organization.Id, false); // Policy doesn't require + + sutProvider.GetDependency() + .ConfirmOrganizationUserAsync(Arg.Is(o => + o.OrganizationUserId == organizationUser.Id && o.Key == request.Key)) + .Returns(true); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(request); + + // Assert + Assert.True(result.IsSuccess); + + await sutProvider.GetDependency() + .DidNotReceive() + .CreateAsync(Arg.Any(), + Arg.Any>(), + Arg.Any>()); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_WhenCreateDefaultCollectionFails_LogsErrorButReturnsSuccess( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + Guid performingUserId, + string key, + string defaultCollectionName) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + var request = new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organization.Id, + Key = key, + DefaultUserCollectionName = defaultCollectionName, // Non-empty to trigger creation + PerformedBy = new StandardUser(performingUserId, true) + }; + + SetupRepositoryMocks(sutProvider, organizationUser, organization, user); + SetupValidatorMock(sutProvider, request, organizationUser, organization, true); + SetupPolicyRequirementMock(sutProvider, user.Id, organization.Id, true); + + sutProvider.GetDependency() + .ConfirmOrganizationUserAsync(Arg.Is(o => + o.OrganizationUserId == organizationUser.Id && o.Key == request.Key)).Returns(true); + + var collectionException = new Exception("Collection creation failed"); + sutProvider.GetDependency() + .CreateAsync(Arg.Any(), + Arg.Any>(), + Arg.Any>()) + .ThrowsAsync(collectionException); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(request); + + // Assert - side effects are fire-and-forget, so command returns success even if collection creation fails + Assert.True(result.IsSuccess); + + sutProvider.GetDependency>() + .Received(1) + .Log( + LogLevel.Error, + Arg.Any(), + Arg.Is(o => o.ToString()!.Contains("Failed to create default collection for user")), + collectionException, + Arg.Any>()); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_WhenEventLogFails_LogsErrorButReturnsSuccess( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + Guid performingUserId, + string key, + string defaultCollectionName) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + var request = new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organization.Id, + Key = key, + DefaultUserCollectionName = defaultCollectionName, + PerformedBy = new StandardUser(performingUserId, true) + }; + + SetupRepositoryMocks(sutProvider, organizationUser, organization, user); + SetupValidatorMock(sutProvider, request, organizationUser, organization, true); + + sutProvider.GetDependency() + .ConfirmOrganizationUserAsync(Arg.Is(o => + o.OrganizationUserId == organizationUser.Id && o.Key == request.Key)) + .Returns(true); + + var eventException = new Exception("Event logging failed"); + sutProvider.GetDependency() + .LogOrganizationUserEventAsync(Arg.Any(), + EventType.OrganizationUser_AutomaticallyConfirmed, + Arg.Any()) + .ThrowsAsync(eventException); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(request); + + // Assert - side effects are fire-and-forget, so command returns success even if event log fails + Assert.True(result.IsSuccess); + + sutProvider.GetDependency>() + .Received(1) + .Log( + LogLevel.Error, + Arg.Any(), + Arg.Is(o => o.ToString()!.Contains("Failed to log OrganizationUser_AutomaticallyConfirmed event")), + eventException, + Arg.Any>()); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_WhenSendEmailFails_LogsErrorButReturnsSuccess( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + Guid performingUserId, + string key, + string defaultCollectionName) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + var request = new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organization.Id, + Key = key, + DefaultUserCollectionName = defaultCollectionName, + PerformedBy = new StandardUser(performingUserId, true) + }; + + SetupRepositoryMocks(sutProvider, organizationUser, organization, user); + SetupValidatorMock(sutProvider, request, organizationUser, organization, true); + + sutProvider.GetDependency() + .ConfirmOrganizationUserAsync(Arg.Is(o => + o.OrganizationUserId == organizationUser.Id && o.Key == request.Key)) + .Returns(true); + + var emailException = new Exception("Email sending failed"); + sutProvider.GetDependency() + .SendOrganizationConfirmedEmailAsync(organization.Name, user.Email, organizationUser.AccessSecretsManager) + .ThrowsAsync(emailException); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(request); + + // Assert - side effects are fire-and-forget, so command returns success even if email fails + Assert.True(result.IsSuccess); + + sutProvider.GetDependency>() + .Received(1) + .Log( + LogLevel.Error, + Arg.Any(), + Arg.Is(o => o.ToString()!.Contains("Failed to send OrganizationUserConfirmed")), + emailException, + Arg.Any>()); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_WhenUserNotFoundForEmail_LogsErrorButReturnsSuccess( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + Guid performingUserId, + string key, + string defaultCollectionName) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + var request = new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organization.Id, + Key = key, + DefaultUserCollectionName = defaultCollectionName, + PerformedBy = new StandardUser(performingUserId, true) + }; + + SetupRepositoryMocks(sutProvider, organizationUser, organization, user); + SetupValidatorMock(sutProvider, request, organizationUser, organization, true); + + sutProvider.GetDependency() + .ConfirmOrganizationUserAsync(Arg.Is(o => + o.OrganizationUserId == organizationUser.Id && o.Key == request.Key)) + .Returns(true); + + // Return null when retrieving user for email + sutProvider.GetDependency() + .GetByIdAsync(user.Id) + .Returns((User)null!); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(request); + + // Assert - side effects are fire-and-forget, so command returns success even if user not found for email + Assert.True(result.IsSuccess); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_WhenDeleteDeviceRegistrationFails_LogsErrorButReturnsSuccess( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + Guid performingUserId, + string key, + string defaultCollectionName, + Device device) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + device.UserId = user.Id; + device.PushToken = "test-push-token"; + var request = new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organization.Id, + Key = key, + DefaultUserCollectionName = defaultCollectionName, + PerformedBy = new StandardUser(performingUserId, true) + }; + + SetupRepositoryMocks(sutProvider, organizationUser, organization, user); + SetupValidatorMock(sutProvider, request, organizationUser, organization, true); + + sutProvider.GetDependency() + .ConfirmOrganizationUserAsync(Arg.Is(o => + o.OrganizationUserId == organizationUser.Id && o.Key == request.Key)) + .Returns(true); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(user.Id) + .Returns(new List { device }); + + var deviceException = new Exception("Device registration deletion failed"); + sutProvider.GetDependency() + .DeleteUserRegistrationOrganizationAsync(Arg.Any>(), organization.Id.ToString()) + .ThrowsAsync(deviceException); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(request); + + // Assert - side effects are fire-and-forget, so command returns success even if device registration deletion fails + Assert.True(result.IsSuccess); + + sutProvider.GetDependency>() + .Received(1) + .Log( + LogLevel.Error, + Arg.Any(), + Arg.Is(o => o.ToString()!.Contains("Failed to delete device registration")), + deviceException, + Arg.Any>()); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_WhenPushSyncOrgKeysFails_LogsErrorButReturnsSuccess( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + Guid performingUserId, + string key, + string defaultCollectionName) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + var request = new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organization.Id, + Key = key, + DefaultUserCollectionName = defaultCollectionName, + PerformedBy = new StandardUser(performingUserId, true) + }; + + SetupRepositoryMocks(sutProvider, organizationUser, organization, user); + SetupValidatorMock(sutProvider, request, organizationUser, organization, true); + + sutProvider.GetDependency() + .ConfirmOrganizationUserAsync(Arg.Is(o => + o.OrganizationUserId == organizationUser.Id && o.Key == request.Key)) + .Returns(true); + + var pushException = new Exception("Push sync failed"); + sutProvider.GetDependency() + .PushSyncOrgKeysAsync(user.Id) + .ThrowsAsync(pushException); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(request); + + // Assert - side effects are fire-and-forget, so command returns success even if push sync fails + Assert.True(result.IsSuccess); + + sutProvider.GetDependency>() + .Received(1) + .Log( + LogLevel.Error, + Arg.Any(), + Arg.Is(o => o.ToString()!.Contains("Failed to push organization keys")), + pushException, + Arg.Any>()); + } + + [Theory] + [BitAutoData] + public async Task AutomaticallyConfirmOrganizationUserAsync_WithDevicesWithoutPushToken_FiltersCorrectly( + SutProvider sutProvider, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, + User user, + Guid performingUserId, + string key, + string defaultCollectionName, + Device deviceWithToken, + Device deviceWithoutToken) + { + // Arrange + organizationUser.UserId = user.Id; + organizationUser.OrganizationId = organization.Id; + deviceWithToken.UserId = user.Id; + deviceWithToken.PushToken = "test-token"; + deviceWithoutToken.UserId = user.Id; + deviceWithoutToken.PushToken = null; + var request = new AutomaticallyConfirmOrganizationUserRequest + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organization.Id, + Key = key, + DefaultUserCollectionName = defaultCollectionName, + PerformedBy = new StandardUser(performingUserId, true) + }; + + SetupRepositoryMocks(sutProvider, organizationUser, organization, user); + SetupValidatorMock(sutProvider, request, organizationUser, organization, true); + + sutProvider.GetDependency() + .ConfirmOrganizationUserAsync(Arg.Is(o => + o.OrganizationUserId == organizationUser.Id && o.Key == request.Key)) + .Returns(true); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(user.Id) + .Returns(new List { deviceWithToken, deviceWithoutToken }); + + // Act + var result = await sutProvider.Sut.AutomaticallyConfirmOrganizationUserAsync(request); + + // Assert + Assert.True(result.IsSuccess); + + await sutProvider.GetDependency() + .Received(1) + .DeleteUserRegistrationOrganizationAsync( + Arg.Is>(devices => + devices.Count(d => deviceWithToken.Id.ToString() == d) == 1), + organization.Id.ToString()); + } + + private static void SetupRepositoryMocks( + SutProvider sutProvider, + OrganizationUser organizationUser, + Organization organization, + User user) + { + sutProvider.GetDependency() + .GetByIdAsync(organizationUser.Id) + .Returns(organizationUser); + + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + + sutProvider.GetDependency() + .GetByIdAsync(user.Id) + .Returns(user); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(user.Id) + .Returns(new List()); + } + + private static void SetupValidatorMock( + SutProvider sutProvider, + AutomaticallyConfirmOrganizationUserRequest originalRequest, + OrganizationUser organizationUser, + Organization organization, + bool isValid, + Error? error = null) + { + var validationRequest = new AutomaticallyConfirmOrganizationUserValidationRequest + { + PerformedBy = originalRequest.PerformedBy, + DefaultUserCollectionName = originalRequest.DefaultUserCollectionName, + OrganizationUserId = originalRequest.OrganizationUserId, + OrganizationUser = organizationUser, + OrganizationId = originalRequest.OrganizationId, + Organization = organization, + Key = originalRequest.Key + }; + + var validationResult = isValid + ? ValidationResultHelpers.Valid(validationRequest) + : ValidationResultHelpers.Invalid(validationRequest, error ?? new UserIsNotAccepted()); + + sutProvider.GetDependency() + .ValidateAsync(Arg.Any()) + .Returns(validationResult); + } + + private static void SetupPolicyRequirementMock( + SutProvider sutProvider, + Guid userId, + Guid organizationId, + bool requiresDefaultCollection) + { + var policyDetails = requiresDefaultCollection + ? new List { new() { OrganizationId = organizationId } } + : new List(); + + var policyRequirement = new OrganizationDataOwnershipPolicyRequirement( + requiresDefaultCollection ? OrganizationDataOwnershipState.Enabled : OrganizationDataOwnershipState.Disabled, + policyDetails); + + sutProvider.GetDependency() + .GetAsync(userId) + .Returns(policyRequirement); + } + + private static async Task AssertSuccessfulOperationsAsync( + SutProvider sutProvider, + OrganizationUser organizationUser, + Organization organization, + User user, + string key) + { + await sutProvider.GetDependency() + .Received(1) + .LogOrganizationUserEventAsync( + Arg.Is(x => x.Id == organizationUser.Id), + EventType.OrganizationUser_AutomaticallyConfirmed, + Arg.Any()); + + await sutProvider.GetDependency() + .Received(1) + .SendOrganizationConfirmedEmailAsync( + organization.Name, + user.Email, + organizationUser.AccessSecretsManager); + + await sutProvider.GetDependency() + .Received(1) + .PushSyncOrgKeysAsync(user.Id); + + await sutProvider.GetDependency() + .Received(1) + .DeleteUserRegistrationOrganizationAsync( + Arg.Any>(), + organization.Id.ToString()); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs index 76b8bd92f7..5528ecb2a2 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs @@ -2,7 +2,9 @@ using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; @@ -21,6 +23,7 @@ using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; +using static Bit.Core.AdminConsole.Utilities.v2.Validation.ValidationResultHelpers; namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.OrganizationUsers; @@ -97,6 +100,8 @@ public class ConfirmOrganizationUserCommandTests [BitAutoData(PlanType.EnterpriseMonthly2019, OrganizationUserType.Owner)] [BitAutoData(PlanType.FamiliesAnnually, OrganizationUserType.Admin)] [BitAutoData(PlanType.FamiliesAnnually, OrganizationUserType.Owner)] + [BitAutoData(PlanType.FamiliesAnnually2025, OrganizationUserType.Admin)] + [BitAutoData(PlanType.FamiliesAnnually2025, OrganizationUserType.Owner)] [BitAutoData(PlanType.FamiliesAnnually2019, OrganizationUserType.Admin)] [BitAutoData(PlanType.FamiliesAnnually2019, OrganizationUserType.Owner)] [BitAutoData(PlanType.TeamsAnnually, OrganizationUserType.Admin)] @@ -557,4 +562,256 @@ public class ConfirmOrganizationUserCommandTests .DidNotReceive() .UpsertDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } + + [Theory, BitAutoData] + public async Task ConfirmUserAsync_WithAutoConfirmEnabledAndUserBelongsToAnotherOrg_ThrowsBadRequest( + Organization org, OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, + OrganizationUser otherOrgUser, string key, SutProvider sutProvider) + { + org.PlanType = PlanType.EnterpriseAnnually; + orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; + orgUser.UserId = user.Id; + otherOrgUser.UserId = user.Id; + otherOrgUser.OrganizationId = Guid.NewGuid(); // Different org + + sutProvider.GetDependency() + .GetManyAsync([]).ReturnsForAnyArgs([orgUser]); + sutProvider.GetDependency() + .GetManyByManyUsersAsync([]) + .ReturnsForAnyArgs([orgUser, otherOrgUser]); + sutProvider.GetDependency().GetByIdAsync(org.Id).Returns(org); + sutProvider.GetDependency().GetManyAsync([]).ReturnsForAnyArgs([user]); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Any()) + .Returns(Invalid( + new AutomaticUserConfirmationPolicyEnforcementRequest(orgUser.Id, [orgUser, otherOrgUser], user), + new UserCannotBelongToAnotherOrganization())); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id)); + + Assert.Equal(new UserCannotBelongToAnotherOrganization().Message, exception.Message); + } + + [Theory, BitAutoData] + public async Task ConfirmUserAsync_WithAutoConfirmEnabledForOtherOrg_ThrowsBadRequest( + Organization org, OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, + OrganizationUser otherOrgUser, string key, SutProvider sutProvider) + { + // Arrange + org.PlanType = PlanType.EnterpriseAnnually; + orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; + orgUser.UserId = user.Id; + otherOrgUser.UserId = user.Id; + otherOrgUser.OrganizationId = Guid.NewGuid(); + + sutProvider.GetDependency() + .GetManyAsync([]).ReturnsForAnyArgs([orgUser]); + sutProvider.GetDependency() + .GetManyByManyUsersAsync([]) + .ReturnsForAnyArgs([orgUser, otherOrgUser]); + sutProvider.GetDependency().GetByIdAsync(org.Id).Returns(org); + sutProvider.GetDependency().GetManyAsync([]).ReturnsForAnyArgs([user]); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Any()) + .Returns(Invalid( + new AutomaticUserConfirmationPolicyEnforcementRequest(org.Id, [orgUser, otherOrgUser], user), + new OtherOrganizationDoesNotAllowOtherMembership())); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id)); + + Assert.Equal(new OtherOrganizationDoesNotAllowOtherMembership().Message, exception.Message); + } + + [Theory, BitAutoData] + public async Task ConfirmUserAsync_WithAutoConfirmEnabledAndUserIsProvider_ThrowsBadRequest( + Organization org, OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, + string key, SutProvider sutProvider) + { + // Arrange + org.PlanType = PlanType.EnterpriseAnnually; + orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; + orgUser.UserId = user.Id; + + sutProvider.GetDependency() + .GetManyAsync([]).ReturnsForAnyArgs([orgUser]); + sutProvider.GetDependency() + .GetManyByManyUsersAsync([]) + .ReturnsForAnyArgs([orgUser]); + sutProvider.GetDependency().GetByIdAsync(org.Id).Returns(org); + sutProvider.GetDependency().GetManyAsync([]).ReturnsForAnyArgs([user]); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Any()) + .Returns(Invalid( + new AutomaticUserConfirmationPolicyEnforcementRequest(org.Id, [orgUser], user), + new ProviderUsersCannotJoin())); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id)); + + Assert.Equal(new ProviderUsersCannotJoin().Message, exception.Message); + } + + [Theory, BitAutoData] + public async Task ConfirmUserAsync_WithAutoConfirmNotApplicable_Succeeds( + Organization org, OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, + string key, SutProvider sutProvider) + { + // Arrange + org.PlanType = PlanType.EnterpriseAnnually; + orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; + orgUser.UserId = user.Id; + + sutProvider.GetDependency() + .GetManyAsync([]).ReturnsForAnyArgs([orgUser]); + sutProvider.GetDependency() + .GetManyByManyUsersAsync([]) + .ReturnsForAnyArgs([orgUser]); + sutProvider.GetDependency().GetByIdAsync(org.Id).Returns(org); + sutProvider.GetDependency().GetManyAsync([]).ReturnsForAnyArgs([user]); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Any()) + .Returns(Valid(new AutomaticUserConfirmationPolicyEnforcementRequest(org.Id, [orgUser], user))); + + // Act + await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id); + + // Assert + await sutProvider.GetDependency() + .Received(1).LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_Confirmed); + await sutProvider.GetDependency() + .Received(1).SendOrganizationConfirmedEmailAsync(org.DisplayName(), user.Email, orgUser.AccessSecretsManager); + } + + [Theory, BitAutoData] + public async Task ConfirmUserAsync_WithAutoConfirmValidationBeforeSingleOrgPolicy_ChecksAutoConfirmFirst( + Organization org, OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, + OrganizationUser otherOrgUser, + [OrganizationUserPolicyDetails(PolicyType.SingleOrg)] OrganizationUserPolicyDetails singleOrgPolicy, + string key, SutProvider sutProvider) + { + // Arrange - Setup conditions that would fail BOTH auto-confirm AND single org policy + org.PlanType = PlanType.EnterpriseAnnually; + orgUser.OrganizationId = confirmingUser.OrganizationId = org.Id; + orgUser.UserId = user.Id; + otherOrgUser.UserId = user.Id; + otherOrgUser.OrganizationId = Guid.NewGuid(); + + sutProvider.GetDependency() + .GetManyAsync([]).ReturnsForAnyArgs([orgUser]); + sutProvider.GetDependency() + .GetManyByManyUsersAsync([]) + .ReturnsForAnyArgs([orgUser, otherOrgUser]); + sutProvider.GetDependency().GetByIdAsync(org.Id).Returns(org); + sutProvider.GetDependency().GetManyAsync([]).ReturnsForAnyArgs([user]); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + singleOrgPolicy.OrganizationId = org.Id; + sutProvider.GetDependency() + .GetPoliciesApplicableToUserAsync(user.Id, PolicyType.SingleOrg) + .Returns([singleOrgPolicy]); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Any()) + .Returns(Invalid( + new AutomaticUserConfirmationPolicyEnforcementRequest(org.Id, [orgUser, otherOrgUser], user), + new UserCannotBelongToAnotherOrganization())); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id)); + + Assert.Equal(new UserCannotBelongToAnotherOrganization().Message, exception.Message); + Assert.NotEqual("Cannot confirm this member to the organization until they leave or remove all other organizations.", + exception.Message); + } + + [Theory, BitAutoData] + public async Task ConfirmUsersAsync_WithAutoConfirmEnabled_MixedResults( + Organization org, OrganizationUser confirmingUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser1, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser2, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser3, + OrganizationUser otherOrgUser, User user1, User user2, User user3, + string key, SutProvider sutProvider) + { + // Arrange + org.PlanType = PlanType.EnterpriseAnnually; + orgUser1.OrganizationId = orgUser2.OrganizationId = orgUser3.OrganizationId = confirmingUser.OrganizationId = org.Id; + orgUser1.UserId = user1.Id; + orgUser2.UserId = user2.Id; + orgUser3.UserId = user3.Id; + otherOrgUser.UserId = user3.Id; + otherOrgUser.OrganizationId = Guid.NewGuid(); + + var orgUsers = new[] { orgUser1, orgUser2, orgUser3 }; + sutProvider.GetDependency() + .GetManyAsync([]).ReturnsForAnyArgs(orgUsers); + sutProvider.GetDependency().GetByIdAsync(org.Id).Returns(org); + sutProvider.GetDependency() + .GetManyAsync([]).ReturnsForAnyArgs([user1, user2, user3]); + sutProvider.GetDependency() + .GetManyByManyUsersAsync([]) + .ReturnsForAnyArgs([orgUser1, orgUser2, orgUser3, otherOrgUser]); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers) + .Returns(true); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Is(r => r.User.Id == user1.Id)) + .Returns(Valid(new AutomaticUserConfirmationPolicyEnforcementRequest(org.Id, [orgUser1], user1))); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Is(r => r.User.Id == user2.Id)) + .Returns(Valid(new AutomaticUserConfirmationPolicyEnforcementRequest(org.Id, [orgUser2], user2))); + + sutProvider.GetDependency() + .IsCompliantAsync(Arg.Is(r => r.User.Id == user3.Id)) + .Returns(Invalid( + new AutomaticUserConfirmationPolicyEnforcementRequest(org.Id, [orgUser3, otherOrgUser], user3), + new OtherOrganizationDoesNotAllowOtherMembership())); + + var keys = orgUsers.ToDictionary(ou => ou.Id, _ => key); + + // Act + var result = await sutProvider.Sut.ConfirmUsersAsync(confirmingUser.OrganizationId, keys, confirmingUser.Id); + + // Assert + Assert.Equal(3, result.Count); + Assert.Empty(result[0].Item2); + Assert.Empty(result[1].Item2); + Assert.Equal(new OtherOrganizationDoesNotAllowOtherMembership().Message, result[2].Item2); + } } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccountvNext/DeleteClaimedOrganizationUserAccountCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccountvNext/DeleteClaimedOrganizationUserAccountCommandTests.cs index c223520a04..dfb1b35be0 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccountvNext/DeleteClaimedOrganizationUserAccountCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/DeleteClaimedAccountvNext/DeleteClaimedOrganizationUserAccountCommandTests.cs @@ -1,5 +1,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.Utilities.v2; +using Bit.Core.AdminConsole.Utilities.v2.Validation; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/BulkResendOrganizationInvitesCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/BulkResendOrganizationInvitesCommandTests.cs new file mode 100644 index 0000000000..caae3a3b12 --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/BulkResendOrganizationInvitesCommandTests.cs @@ -0,0 +1,113 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; + +[SutProviderCustomize] +public class BulkResendOrganizationInvitesCommandTests +{ + [Theory] + [BitAutoData] + public async Task BulkResendInvitesAsync_ValidatesUsersAndSendsBatchInvite( + Organization organization, + OrganizationUser validUser1, + OrganizationUser validUser2, + OrganizationUser acceptedUser, + OrganizationUser wrongOrgUser, + SutProvider sutProvider) + { + validUser1.OrganizationId = organization.Id; + validUser1.Status = OrganizationUserStatusType.Invited; + validUser2.OrganizationId = organization.Id; + validUser2.Status = OrganizationUserStatusType.Invited; + acceptedUser.OrganizationId = organization.Id; + acceptedUser.Status = OrganizationUserStatusType.Accepted; + wrongOrgUser.OrganizationId = Guid.NewGuid(); + wrongOrgUser.Status = OrganizationUserStatusType.Invited; + + var users = new List { validUser1, validUser2, acceptedUser, wrongOrgUser }; + var userIds = users.Select(u => u.Id).ToList(); + + sutProvider.GetDependency().GetManyAsync(userIds).Returns(users); + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + + var result = (await sutProvider.Sut.BulkResendInvitesAsync(organization.Id, null, userIds)).ToList(); + + Assert.Equal(4, result.Count); + Assert.Equal(2, result.Count(r => string.IsNullOrEmpty(r.Item2))); + Assert.Equal(2, result.Count(r => r.Item2 == "User invalid.")); + + await sutProvider.GetDependency() + .Received(1) + .SendInvitesAsync(Arg.Is(req => + req.Organization == organization && + req.Users.Length == 2 && + req.InitOrganization == false)); + } + + [Theory] + [BitAutoData] + public async Task BulkResendInvitesAsync_AllInvalidUsers_DoesNotSendInvites( + Organization organization, + List organizationUsers, + SutProvider sutProvider) + { + foreach (var user in organizationUsers) + { + user.OrganizationId = organization.Id; + user.Status = OrganizationUserStatusType.Confirmed; + } + + var userIds = organizationUsers.Select(u => u.Id).ToList(); + sutProvider.GetDependency().GetManyAsync(userIds).Returns(organizationUsers); + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + + var result = (await sutProvider.Sut.BulkResendInvitesAsync(organization.Id, null, userIds)).ToList(); + + Assert.Equal(organizationUsers.Count, result.Count); + Assert.All(result, r => Assert.Equal("User invalid.", r.Item2)); + await sutProvider.GetDependency().DidNotReceive() + .SendInvitesAsync(Arg.Any()); + } + + [Theory] + [BitAutoData] + public async Task BulkResendInvitesAsync_OrganizationNotFound_ThrowsNotFoundException( + Guid organizationId, + List userIds, + List organizationUsers, + SutProvider sutProvider) + { + sutProvider.GetDependency().GetManyAsync(userIds).Returns(organizationUsers); + sutProvider.GetDependency().GetByIdAsync(organizationId).Returns((Organization?)null); + + await Assert.ThrowsAsync(() => + sutProvider.Sut.BulkResendInvitesAsync(organizationId, null, userIds)); + } + + [Theory] + [BitAutoData] + public async Task BulkResendInvitesAsync_EmptyUserList_ReturnsEmpty( + Organization organization, + SutProvider sutProvider) + { + var emptyUserIds = new List(); + sutProvider.GetDependency().GetManyAsync(emptyUserIds).Returns(new List()); + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + + var result = await sutProvider.Sut.BulkResendInvitesAsync(organization.Id, null, emptyUserIds); + + Assert.Empty(result); + await sutProvider.GetDependency().DidNotReceive() + .SendInvitesAsync(Arg.Any()); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUserCommandTests.cs index 10dcff9e2a..5d82f0717d 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUserCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/InviteOrganizationUserCommandTests.cs @@ -13,7 +13,6 @@ using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Utilities.Commands; using Bit.Core.AdminConsole.Utilities.Errors; using Bit.Core.AdminConsole.Utilities.Validation; -using Bit.Core.Billing.Models.StaticStore.Plans; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Business; @@ -22,6 +21,7 @@ using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.OrganizationFeatures.OrganizationSubscriptions.Interface; using Bit.Core.Repositories; using Bit.Core.Services; +using Bit.Core.Test.Billing.Mocks.Plans; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.Extensions.Time.Testing; @@ -29,6 +29,7 @@ using NSubstitute; using NSubstitute.ExceptionExtensions; using Xunit; using static Bit.Core.Test.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Helpers.InviteUserOrganizationValidationRequestHelpers; +using Enterprise2019Plan = Bit.Core.Test.Billing.Mocks.Plans.Enterprise2019Plan; namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteOrganizationUsersValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteOrganizationUsersValidatorTests.cs index a5b220b94a..e26d9ce978 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteOrganizationUsersValidatorTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteOrganizationUsersValidatorTests.cs @@ -3,12 +3,12 @@ using Bit.Core.AdminConsole.Models.Business; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation; using Bit.Core.AdminConsole.Utilities.Validation; -using Bit.Core.Billing.Models.StaticStore.Plans; +using Bit.Core.Billing.Services; using Bit.Core.Exceptions; using Bit.Core.Models.Business; using Bit.Core.OrganizationFeatures.OrganizationSubscriptions.Interface; using Bit.Core.Repositories; -using Bit.Core.Services; +using Bit.Core.Test.Billing.Mocks.Plans; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -50,7 +50,7 @@ public class InviteOrganizationUsersValidatorTests OccupiedSmSeats = 9 }; - sutProvider.GetDependency() + sutProvider.GetDependency() .HasSecretsManagerStandalone(request.InviteOrganization) .Returns(true); @@ -96,7 +96,7 @@ public class InviteOrganizationUsersValidatorTests OccupiedSmSeats = 9 }; - sutProvider.GetDependency() + sutProvider.GetDependency() .HasSecretsManagerStandalone(request.InviteOrganization) .Returns(true); @@ -140,7 +140,7 @@ public class InviteOrganizationUsersValidatorTests OccupiedSmSeats = 4 }; - sutProvider.GetDependency() + sutProvider.GetDependency() .HasSecretsManagerStandalone(request.InviteOrganization) .Returns(true); diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteUserOrganizationValidationTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteUserOrganizationValidationTests.cs index be5586f8a6..482b369780 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteUserOrganizationValidationTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteUserOrganizationValidationTests.cs @@ -2,7 +2,7 @@ using Bit.Core.AdminConsole.Models.Business; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.Organization; using Bit.Core.AdminConsole.Utilities.Validation; -using Bit.Core.Billing.Models.StaticStore.Plans; +using Bit.Core.Test.Billing.Mocks.Plans; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteUserPaymentValidationTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteUserPaymentValidationTests.cs index 738ae71298..72a146205b 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteUserPaymentValidationTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/InviteUserPaymentValidationTests.cs @@ -5,7 +5,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.V using Bit.Core.AdminConsole.Utilities.Validation; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models.StaticStore.Plans; +using Bit.Core.Test.Billing.Mocks.Plans; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/PasswordManagerInviteUserValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/PasswordManagerInviteUserValidatorTests.cs index 571832d675..46ca37522f 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/PasswordManagerInviteUserValidatorTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/Validation/PasswordManagerInviteUserValidatorTests.cs @@ -3,7 +3,7 @@ using Bit.Core.AdminConsole.Models.Business; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Validation.PasswordManager; using Bit.Core.AdminConsole.Utilities.Validation; using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models.StaticStore.Plans; +using Bit.Core.Test.Billing.Mocks.Plans; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeOrganizationUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeOrganizationUserCommandTests.cs index b16a80d7a2..3c2868d9e3 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeOrganizationUserCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeOrganizationUserCommandTests.cs @@ -1,6 +1,6 @@ using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v1; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUserCommandTests.cs new file mode 100644 index 0000000000..a74135794f --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUserCommandTests.cs @@ -0,0 +1,215 @@ +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; +using Bit.Core.AdminConsole.Utilities.v2.Validation; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Platform.Push; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Test.AutoFixture.OrganizationUserFixtures; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.Extensions.Logging; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; + +[SutProviderCustomize] +public class RevokeOrganizationUserCommandTests +{ + [Theory] + [BitAutoData] + public async Task RevokeUsersAsync_WithValidUsers_RevokesUsersAndLogsEvents( + SutProvider sutProvider, + Guid organizationId, + Guid actingUserId, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.User)] OrganizationUser orgUser1, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.User)] OrganizationUser orgUser2) + { + // Arrange + orgUser1.OrganizationId = orgUser2.OrganizationId = organizationId; + orgUser1.UserId = Guid.NewGuid(); + orgUser2.UserId = Guid.NewGuid(); + + var actingUser = CreateActingUser(actingUserId, false, null); + var request = new RevokeOrganizationUsersRequest( + organizationId, + [orgUser1.Id, orgUser2.Id], + actingUser); + + SetupRepositoryMocks(sutProvider, [orgUser1, orgUser2]); + SetupValidatorMock(sutProvider, [ + ValidationResultHelpers.Valid(orgUser1), + ValidationResultHelpers.Valid(orgUser2) + ]); + + // Act + var results = (await sutProvider.Sut.RevokeUsersAsync(request)).ToList(); + + // Assert + Assert.Equal(2, results.Count); + Assert.All(results, r => Assert.True(r.Result.IsSuccess)); + + await sutProvider.GetDependency() + .Received(1) + .RevokeManyByIdAsync(Arg.Is>(ids => + ids.Contains(orgUser1.Id) && ids.Contains(orgUser2.Id))); + + await sutProvider.GetDependency() + .Received(1) + .LogOrganizationUserEventsAsync(Arg.Is>( + events => events.Count() == 2)); + + await sutProvider.GetDependency() + .Received(1) + .PushSyncOrgKeysAsync(orgUser1.UserId!.Value); + + await sutProvider.GetDependency() + .Received(1) + .PushSyncOrgKeysAsync(orgUser2.UserId!.Value); + } + + [Theory] + [BitAutoData] + public async Task RevokeUsersAsync_WithSystemUser_LogsEventsWithSystemUserType( + SutProvider sutProvider, + Guid organizationId, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.User)] OrganizationUser orgUser) + { + // Arrange + orgUser.OrganizationId = organizationId; + orgUser.UserId = Guid.NewGuid(); + + var actingUser = CreateActingUser(null, false, EventSystemUser.SCIM); + + var request = new RevokeOrganizationUsersRequest( + organizationId, + [orgUser.Id], + actingUser); + + SetupRepositoryMocks(sutProvider, [orgUser]); + SetupValidatorMock(sutProvider, [ValidationResultHelpers.Valid(orgUser)]); + + // Act + await sutProvider.Sut.RevokeUsersAsync(request); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .LogOrganizationUserEventsAsync(Arg.Is>( + events => events.All(e => e.Item3 == EventSystemUser.SCIM))); + } + + [Theory] + [BitAutoData] + public async Task RevokeUsersAsync_WithValidationErrors_ReturnsErrorResults( + SutProvider sutProvider, + Guid organizationId, + Guid actingUserId, + [OrganizationUser(OrganizationUserStatusType.Revoked, OrganizationUserType.User)] OrganizationUser orgUser1, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.User)] OrganizationUser orgUser2) + { + // Arrange + orgUser1.OrganizationId = orgUser2.OrganizationId = organizationId; + + var actingUser = CreateActingUser(actingUserId, false, null); + + var request = new RevokeOrganizationUsersRequest( + organizationId, + [orgUser1.Id, orgUser2.Id], + actingUser); + + SetupRepositoryMocks(sutProvider, [orgUser1, orgUser2]); + SetupValidatorMock(sutProvider, [ + ValidationResultHelpers.Invalid(orgUser1, new UserAlreadyRevoked()), + ValidationResultHelpers.Valid(orgUser2) + ]); + + // Act + var results = (await sutProvider.Sut.RevokeUsersAsync(request)).ToList(); + + // Assert + Assert.Equal(2, results.Count); + var result1 = results.Single(r => r.Id == orgUser1.Id); + var result2 = results.Single(r => r.Id == orgUser2.Id); + + Assert.True(result1.Result.IsError); + Assert.True(result2.Result.IsSuccess); + + // Only the valid user should be revoked + await sutProvider.GetDependency() + .Received(1) + .RevokeManyByIdAsync(Arg.Is>(ids => + ids.Count() == 1 && ids.Contains(orgUser2.Id))); + } + + [Theory] + [BitAutoData] + public async Task RevokeUsersAsync_WhenPushNotificationFails_ContinuesProcessing( + SutProvider sutProvider, + Guid organizationId, + Guid actingUserId, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.User)] OrganizationUser orgUser) + { + // Arrange + orgUser.OrganizationId = organizationId; + orgUser.UserId = Guid.NewGuid(); + + var actingUser = CreateActingUser(actingUserId, false, null); + + var request = new RevokeOrganizationUsersRequest( + organizationId, + [orgUser.Id], + actingUser); + + SetupRepositoryMocks(sutProvider, [orgUser]); + SetupValidatorMock(sutProvider, [ValidationResultHelpers.Valid(orgUser)]); + + sutProvider.GetDependency() + .PushSyncOrgKeysAsync(orgUser.UserId!.Value) + .Returns(Task.FromException(new Exception("Push notification failed"))); + + // Act + var results = (await sutProvider.Sut.RevokeUsersAsync(request)).ToList(); + + // Assert + Assert.Single(results); + Assert.True(results[0].Result.IsSuccess); + + // Should log warning but continue + sutProvider.GetDependency>() + .Received() + .Log( + LogLevel.Warning, + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any>()); + } + + private static IActingUser CreateActingUser(Guid? userId, bool isOwnerOrProvider, EventSystemUser? systemUserType) => + (userId, systemUserType) switch + { + ({ } id, _) => new StandardUser(id, isOwnerOrProvider), + (null, { } type) => new SystemUser(type) + }; + + private static void SetupRepositoryMocks( + SutProvider sutProvider, + ICollection organizationUsers) + { + sutProvider.GetDependency() + .GetManyAsync(Arg.Any>()) + .Returns(organizationUsers); + } + + private static void SetupValidatorMock( + SutProvider sutProvider, + ICollection> validationResults) + { + sutProvider.GetDependency() + .ValidateAsync(Arg.Any()) + .Returns(validationResults); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUsersValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUsersValidatorTests.cs new file mode 100644 index 0000000000..fe5802b00b --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeUser/v2/RevokeOrganizationUsersValidatorTests.cs @@ -0,0 +1,325 @@ +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Test.AutoFixture.OrganizationUserFixtures; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; + +[SutProviderCustomize] +public class RevokeOrganizationUsersValidatorTests +{ + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithValidUsers_ReturnsSuccess( + SutProvider sutProvider, + Guid organizationId, + Guid actingUserId, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.User)] OrganizationUser orgUser1, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.User)] OrganizationUser orgUser2) + { + // Arrange + orgUser1.OrganizationId = orgUser2.OrganizationId = organizationId; + orgUser1.UserId = Guid.NewGuid(); + orgUser2.UserId = Guid.NewGuid(); + + var actingUser = CreateActingUser(actingUserId, false, null); + var request = CreateValidationRequest( + organizationId, + [orgUser1, orgUser2], + actingUser); + + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(organizationId, Arg.Any>()) + .Returns(true); + + // Act + var results = (await sutProvider.Sut.ValidateAsync(request)).ToList(); + + // Assert + Assert.Equal(2, results.Count); + Assert.All(results, r => Assert.True(r.IsValid)); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithRevokedUser_ReturnsErrorForThatUser( + SutProvider sutProvider, + Guid organizationId, + Guid actingUserId, + [OrganizationUser(OrganizationUserStatusType.Revoked, OrganizationUserType.User)] OrganizationUser revokedUser) + { + // Arrange + revokedUser.OrganizationId = organizationId; + revokedUser.UserId = Guid.NewGuid(); + + var actingUser = CreateActingUser(actingUserId, false, null); + var request = CreateValidationRequest( + organizationId, + [revokedUser], + actingUser); + + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(organizationId, Arg.Any>()) + .Returns(true); + + // Act + var results = (await sutProvider.Sut.ValidateAsync(request)).ToList(); + + // Assert + Assert.Single(results); + Assert.True(results.First().IsError); + Assert.IsType(results.First().AsError); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WhenRevokingSelf_ReturnsErrorForThatUser( + SutProvider sutProvider, + Guid organizationId, + Guid actingUserId, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.User)] OrganizationUser orgUser) + { + // Arrange + orgUser.OrganizationId = organizationId; + orgUser.UserId = actingUserId; + + var actingUser = CreateActingUser(actingUserId, false, null); + var request = CreateValidationRequest( + organizationId, + [orgUser], + actingUser); + + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(organizationId, Arg.Any>()) + .Returns(true); + + // Act + var results = (await sutProvider.Sut.ValidateAsync(request)).ToList(); + + // Assert + Assert.Single(results); + Assert.True(results.First().IsError); + Assert.IsType(results.First().AsError); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WhenNonOwnerRevokesOwner_ReturnsErrorForThatUser( + SutProvider sutProvider, + Guid organizationId, + Guid actingUserId, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser ownerUser) + { + // Arrange + ownerUser.OrganizationId = organizationId; + ownerUser.UserId = Guid.NewGuid(); + + var actingUser = CreateActingUser(actingUserId, false, null); + var request = CreateValidationRequest( + organizationId, + [ownerUser], + actingUser); + + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(organizationId, Arg.Any>()) + .Returns(true); + + // Act + var results = (await sutProvider.Sut.ValidateAsync(request)).ToList(); + + // Assert + Assert.Single(results); + Assert.True(results.First().IsError); + Assert.IsType(results.First().AsError); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WhenOwnerRevokesOwner_ReturnsSuccess( + SutProvider sutProvider, + Guid organizationId, + Guid actingUserId, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser ownerUser) + { + // Arrange + ownerUser.OrganizationId = organizationId; + ownerUser.UserId = Guid.NewGuid(); + + var actingUser = CreateActingUser(actingUserId, true, null); + var request = CreateValidationRequest( + organizationId, + [ownerUser], + actingUser); + + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(organizationId, Arg.Any>()) + .Returns(true); + + // Act + var results = (await sutProvider.Sut.ValidateAsync(request)).ToList(); + + // Assert + Assert.Single(results); + Assert.True(results.First().IsValid); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithMultipleUsers_SomeValid_ReturnsMixedResults( + SutProvider sutProvider, + Guid organizationId, + Guid actingUserId, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.User)] OrganizationUser validUser, + [OrganizationUser(OrganizationUserStatusType.Revoked, OrganizationUserType.User)] OrganizationUser revokedUser) + { + // Arrange + validUser.OrganizationId = revokedUser.OrganizationId = organizationId; + validUser.UserId = Guid.NewGuid(); + revokedUser.UserId = Guid.NewGuid(); + + var actingUser = CreateActingUser(actingUserId, false, null); + var request = CreateValidationRequest( + organizationId, + [validUser, revokedUser], + actingUser); + + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(organizationId, Arg.Any>()) + .Returns(true); + + // Act + var results = (await sutProvider.Sut.ValidateAsync(request)).ToList(); + + // Assert + Assert.Equal(2, results.Count); + + var validResult = results.Single(r => r.Request.Id == validUser.Id); + var errorResult = results.Single(r => r.Request.Id == revokedUser.Id); + + Assert.True(validResult.IsValid); + Assert.True(errorResult.IsError); + Assert.IsType(errorResult.AsError); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithSystemUser_DoesNotRequireActingUserId( + SutProvider sutProvider, + Guid organizationId, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.User)] OrganizationUser orgUser) + { + // Arrange + orgUser.OrganizationId = organizationId; + orgUser.UserId = Guid.NewGuid(); + + var actingUser = CreateActingUser(null, false, EventSystemUser.SCIM); + var request = CreateValidationRequest( + organizationId, + [orgUser], + actingUser); + + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(organizationId, Arg.Any>()) + .Returns(true); + + // Act + var results = (await sutProvider.Sut.ValidateAsync(request)).ToList(); + + // Assert + Assert.Single(results); + Assert.True(results.First().IsValid); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WhenRevokingLastOwner_ReturnsErrorForThatUser( + SutProvider sutProvider, + Guid organizationId, + Guid actingUserId, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser lastOwner) + { + // Arrange + lastOwner.OrganizationId = organizationId; + lastOwner.UserId = Guid.NewGuid(); + + var actingUser = CreateActingUser(actingUserId, true, null); // Is an owner + var request = CreateValidationRequest( + organizationId, + [lastOwner], + actingUser); + + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(organizationId, Arg.Any>()) + .Returns(false); + + // Act + var results = (await sutProvider.Sut.ValidateAsync(request)).ToList(); + + // Assert + Assert.Single(results); + Assert.True(results.First().IsError); + Assert.IsType(results.First().AsError); + } + + [Theory] + [BitAutoData] + public async Task ValidateAsync_WithMultipleValidationErrors_ReturnsAllErrors( + SutProvider sutProvider, + Guid organizationId, + Guid actingUserId, + [OrganizationUser(OrganizationUserStatusType.Revoked, OrganizationUserType.User)] OrganizationUser revokedUser, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser ownerUser) + { + // Arrange + revokedUser.OrganizationId = ownerUser.OrganizationId = organizationId; + revokedUser.UserId = Guid.NewGuid(); + ownerUser.UserId = Guid.NewGuid(); + + var actingUser = CreateActingUser(actingUserId, false, null); // Not an owner + var request = CreateValidationRequest( + organizationId, + [revokedUser, ownerUser], + actingUser); + + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(organizationId, Arg.Any>()) + .Returns(true); + + // Act + var results = (await sutProvider.Sut.ValidateAsync(request)).ToList(); + + // Assert + Assert.Equal(2, results.Count); + Assert.All(results, r => Assert.True(r.IsError)); + + Assert.Contains(results, r => r.AsError is UserAlreadyRevoked); + Assert.Contains(results, r => r.AsError is OnlyOwnersCanRevokeOwners); + } + + private static IActingUser CreateActingUser(Guid? userId, bool isOwnerOrProvider, EventSystemUser? systemUserType) => + (userId, systemUserType) switch + { + ({ } id, _) => new StandardUser(id, isOwnerOrProvider), + (null, { } type) => new SystemUser(type) + }; + + + private static RevokeOrganizationUsersValidationRequest CreateValidationRequest( + Guid organizationId, + ICollection organizationUsers, + IActingUser actingUser) + { + return new RevokeOrganizationUsersValidationRequest( + organizationId, + organizationUsers.Select(u => u.Id).ToList(), + actingUser, + organizationUsers + ); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/GetOrganizationSubscriptionsToUpdateQueryTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/GetOrganizationSubscriptionsToUpdateQueryTests.cs index af6b5a17f7..f1c4797de8 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/GetOrganizationSubscriptionsToUpdateQueryTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/GetOrganizationSubscriptionsToUpdateQueryTests.cs @@ -1,9 +1,9 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models.StaticStore.Plans; using Bit.Core.Billing.Pricing; using Bit.Core.Repositories; +using Bit.Core.Test.Billing.Mocks.Plans; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/CloudOrganizationSignUpCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/CloudOrganizationSignUpCommandTests.cs index 7e6f5dc9bc..c1fea1455e 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/CloudOrganizationSignUpCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/CloudOrganizationSignUpCommandTests.cs @@ -10,7 +10,7 @@ using Bit.Core.Exceptions; using Bit.Core.Models.Business; using Bit.Core.Models.Data; using Bit.Core.Repositories; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -23,11 +23,12 @@ public class CloudICloudOrganizationSignUpCommandTests { [Theory] [BitAutoData(PlanType.FamiliesAnnually)] + [BitAutoData(PlanType.FamiliesAnnually2025)] public async Task SignUp_PM_Family_Passes(PlanType planType, OrganizationSignup signup, SutProvider sutProvider) { signup.Plan = planType; - var plan = StaticStore.GetPlan(signup.Plan); + var plan = MockPlans.Get(signup.Plan); signup.AdditionalSeats = 0; signup.PaymentMethodType = PaymentMethodType.Card; @@ -36,7 +37,7 @@ public class CloudICloudOrganizationSignUpCommandTests signup.IsFromSecretsManagerTrial = false; signup.IsFromProvider = false; - sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(StaticStore.GetPlan(signup.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(MockPlans.Get(signup.Plan)); var result = await sutProvider.Sut.SignUpOrganizationAsync(signup); @@ -65,6 +66,7 @@ public class CloudICloudOrganizationSignUpCommandTests [Theory] [BitAutoData(PlanType.FamiliesAnnually)] + [BitAutoData(PlanType.FamiliesAnnually2025)] public async Task SignUp_AssignsOwnerToDefaultCollection (PlanType planType, OrganizationSignup signup, SutProvider sutProvider) { @@ -75,7 +77,7 @@ public class CloudICloudOrganizationSignUpCommandTests signup.UseSecretsManager = false; signup.IsFromProvider = false; - sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(StaticStore.GetPlan(signup.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(MockPlans.Get(signup.Plan)); // Extract orgUserId when created Guid? orgUserId = null; @@ -110,7 +112,7 @@ public class CloudICloudOrganizationSignUpCommandTests { signup.Plan = planType; - var plan = StaticStore.GetPlan(signup.Plan); + var plan = MockPlans.Get(signup.Plan); signup.UseSecretsManager = true; signup.AdditionalSeats = 15; @@ -121,7 +123,7 @@ public class CloudICloudOrganizationSignUpCommandTests signup.IsFromSecretsManagerTrial = false; signup.IsFromProvider = false; - sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(StaticStore.GetPlan(signup.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(MockPlans.Get(signup.Plan)); var result = await sutProvider.Sut.SignUpOrganizationAsync(signup); @@ -162,7 +164,7 @@ public class CloudICloudOrganizationSignUpCommandTests signup.PremiumAccessAddon = false; signup.IsFromProvider = true; - sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(StaticStore.GetPlan(signup.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(MockPlans.Get(signup.Plan)); var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.SignUpOrganizationAsync(signup)); Assert.Contains("Organizations with a Managed Service Provider do not support Secrets Manager.", exception.Message); @@ -182,7 +184,7 @@ public class CloudICloudOrganizationSignUpCommandTests signup.AdditionalStorageGb = 0; signup.IsFromProvider = false; - sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(StaticStore.GetPlan(signup.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(MockPlans.Get(signup.Plan)); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.SignUpOrganizationAsync(signup)); @@ -202,7 +204,7 @@ public class CloudICloudOrganizationSignUpCommandTests signup.AdditionalServiceAccounts = 10; signup.IsFromProvider = false; - sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(StaticStore.GetPlan(signup.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(MockPlans.Get(signup.Plan)); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.SignUpOrganizationAsync(signup)); @@ -222,7 +224,7 @@ public class CloudICloudOrganizationSignUpCommandTests signup.AdditionalServiceAccounts = -10; signup.IsFromProvider = false; - sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(StaticStore.GetPlan(signup.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(MockPlans.Get(signup.Plan)); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.SignUpOrganizationAsync(signup)); @@ -242,7 +244,7 @@ public class CloudICloudOrganizationSignUpCommandTests Owner = new User { Id = Guid.NewGuid() } }; - sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(StaticStore.GetPlan(signup.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(signup.Plan).Returns(MockPlans.Get(signup.Plan)); sutProvider.GetDependency() .GetCountByFreeOrganizationAdminUserAsync(signup.Owner.Id) diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/ProviderClientOrganizationSignUpCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/ProviderClientOrganizationSignUpCommandTests.cs index 881f134b4c..5385b4cdea 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/ProviderClientOrganizationSignUpCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/ProviderClientOrganizationSignUpCommandTests.cs @@ -10,7 +10,7 @@ using Bit.Core.Models.Data; using Bit.Core.Models.StaticStore; using Bit.Core.Repositories; using Bit.Core.Services; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -36,7 +36,7 @@ public class ProviderClientOrganizationSignUpCommandTests signup.AdditionalSeats = 15; signup.CollectionName = collectionName; - var plan = StaticStore.GetPlan(signup.Plan); + var plan = MockPlans.Get(signup.Plan); sutProvider.GetDependency() .GetPlanOrThrow(signup.Plan) .Returns(plan); @@ -112,7 +112,7 @@ public class ProviderClientOrganizationSignUpCommandTests signup.Plan = PlanType.TeamsMonthly; signup.AdditionalSeats = -5; - var plan = StaticStore.GetPlan(signup.Plan); + var plan = MockPlans.Get(signup.Plan); sutProvider.GetDependency() .GetPlanOrThrow(signup.Plan) .Returns(plan); @@ -132,7 +132,7 @@ public class ProviderClientOrganizationSignUpCommandTests { signup.Plan = planType; - var plan = StaticStore.GetPlan(signup.Plan); + var plan = MockPlans.Get(signup.Plan); sutProvider.GetDependency() .GetPlanOrThrow(signup.Plan) .Returns(plan); diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/ResellerClientOrganizationSignUpCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/ResellerClientOrganizationSignUpCommandTests.cs index 55e5698ad4..69f69b1d02 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/ResellerClientOrganizationSignUpCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationSignUp/ResellerClientOrganizationSignUpCommandTests.cs @@ -2,6 +2,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Repositories; @@ -172,7 +173,7 @@ public class ResellerClientOrganizationSignUpCommandTests private static async Task AssertCleanupIsPerformed(SutProvider sutProvider) { - await sutProvider.GetDependency() + await sutProvider.GetDependency() .Received(1) .CancelAndRecoverChargesAsync(Arg.Any()); await sutProvider.GetDependency() diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationUpdateCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationUpdateCommandTests.cs new file mode 100644 index 0000000000..997076e7ef --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/OrganizationUpdateCommandTests.cs @@ -0,0 +1,418 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Update; +using Bit.Core.Billing.Organizations.Services; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Settings; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Organizations; + +[SutProviderCustomize] +public class OrganizationUpdateCommandTests +{ + [Theory, BitAutoData] + public async Task UpdateAsync_WhenValidOrganization_UpdatesOrganization( + Guid organizationId, + string name, + string billingEmail, + Organization organization, + SutProvider sutProvider) + { + // Arrange + var organizationRepository = sutProvider.GetDependency(); + var organizationService = sutProvider.GetDependency(); + var organizationBillingService = sutProvider.GetDependency(); + + organization.Id = organizationId; + organization.GatewayCustomerId = null; // No Stripe customer, but billing update is still called + + organizationRepository + .GetByIdAsync(organizationId) + .Returns(organization); + + var request = new OrganizationUpdateRequest + { + OrganizationId = organizationId, + Name = name, + BillingEmail = billingEmail + }; + + // Act + var result = await sutProvider.Sut.UpdateAsync(request); + + // Assert + Assert.NotNull(result); + Assert.Equal(organizationId, result.Id); + Assert.Equal(name, result.Name); + Assert.Equal(billingEmail.ToLowerInvariant().Trim(), result.BillingEmail); + + await organizationRepository + .Received(1) + .GetByIdAsync(Arg.Is(id => id == organizationId)); + await organizationService + .Received(1) + .ReplaceAndUpdateCacheAsync( + result, + EventType.Organization_Updated); + await organizationBillingService + .Received(1) + .UpdateOrganizationNameAndEmail(result); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_WhenOrganizationNotFound_ThrowsNotFoundException( + Guid organizationId, + string name, + string billingEmail, + SutProvider sutProvider) + { + // Arrange + var organizationRepository = sutProvider.GetDependency(); + + organizationRepository + .GetByIdAsync(organizationId) + .Returns((Organization)null); + + var request = new OrganizationUpdateRequest + { + OrganizationId = organizationId, + Name = name, + BillingEmail = billingEmail + }; + + // Act/Assert + await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateAsync(request)); + } + + [Theory] + [BitAutoData("")] + [BitAutoData((string)null)] + public async Task UpdateAsync_WhenGatewayCustomerIdIsNullOrEmpty_CallsBillingUpdateButHandledGracefully( + string gatewayCustomerId, + Guid organizationId, + Organization organization, + SutProvider sutProvider) + { + // Arrange + var organizationRepository = sutProvider.GetDependency(); + var organizationService = sutProvider.GetDependency(); + var organizationBillingService = sutProvider.GetDependency(); + + organization.Id = organizationId; + organization.Name = "Old Name"; + organization.GatewayCustomerId = gatewayCustomerId; + + organizationRepository + .GetByIdAsync(organizationId) + .Returns(organization); + + var request = new OrganizationUpdateRequest + { + OrganizationId = organizationId, + Name = "New Name", + BillingEmail = organization.BillingEmail + }; + + // Act + var result = await sutProvider.Sut.UpdateAsync(request); + + // Assert + Assert.NotNull(result); + Assert.Equal(organizationId, result.Id); + Assert.Equal("New Name", result.Name); + + await organizationService + .Received(1) + .ReplaceAndUpdateCacheAsync( + result, + EventType.Organization_Updated); + await organizationBillingService + .Received(1) + .UpdateOrganizationNameAndEmail(result); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_WhenKeysProvided_AndNotAlreadySet_SetsKeys( + Guid organizationId, + string publicKey, + string encryptedPrivateKey, + Organization organization, + SutProvider sutProvider) + { + // Arrange + var organizationRepository = sutProvider.GetDependency(); + var organizationService = sutProvider.GetDependency(); + + organization.Id = organizationId; + organization.PublicKey = null; + organization.PrivateKey = null; + + organizationRepository + .GetByIdAsync(organizationId) + .Returns(organization); + + var request = new OrganizationUpdateRequest + { + OrganizationId = organizationId, + Name = organization.Name, + BillingEmail = organization.BillingEmail, + Keys = new PublicKeyEncryptionKeyPairData( + wrappedPrivateKey: encryptedPrivateKey, + publicKey: publicKey) + }; + + // Act + var result = await sutProvider.Sut.UpdateAsync(request); + + // Assert + Assert.NotNull(result); + Assert.Equal(organizationId, result.Id); + Assert.Equal(publicKey, result.PublicKey); + Assert.Equal(encryptedPrivateKey, result.PrivateKey); + + await organizationService + .Received(1) + .ReplaceAndUpdateCacheAsync( + result, + EventType.Organization_Updated); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_WhenKeysProvided_AndAlreadySet_DoesNotOverwriteKeys( + Guid organizationId, + string newPublicKey, + string newEncryptedPrivateKey, + Organization organization, + SutProvider sutProvider) + { + // Arrange + var organizationRepository = sutProvider.GetDependency(); + var organizationService = sutProvider.GetDependency(); + + organization.Id = organizationId; + var existingPublicKey = organization.PublicKey; + var existingPrivateKey = organization.PrivateKey; + + organizationRepository + .GetByIdAsync(organizationId) + .Returns(organization); + + var request = new OrganizationUpdateRequest + { + OrganizationId = organizationId, + Name = organization.Name, + BillingEmail = organization.BillingEmail, + Keys = new PublicKeyEncryptionKeyPairData( + wrappedPrivateKey: newEncryptedPrivateKey, + publicKey: newPublicKey) + }; + + // Act + var result = await sutProvider.Sut.UpdateAsync(request); + + // Assert + Assert.NotNull(result); + Assert.Equal(organizationId, result.Id); + Assert.Equal(existingPublicKey, result.PublicKey); + Assert.Equal(existingPrivateKey, result.PrivateKey); + + await organizationService + .Received(1) + .ReplaceAndUpdateCacheAsync( + result, + EventType.Organization_Updated); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_UpdatingNameOnly_UpdatesNameAndNotBillingEmail( + Guid organizationId, + string newName, + Organization organization, + SutProvider sutProvider) + { + // Arrange + var organizationRepository = sutProvider.GetDependency(); + var organizationService = sutProvider.GetDependency(); + var organizationBillingService = sutProvider.GetDependency(); + + organization.Id = organizationId; + organization.Name = "Old Name"; + var originalBillingEmail = organization.BillingEmail; + + organizationRepository + .GetByIdAsync(organizationId) + .Returns(organization); + + var request = new OrganizationUpdateRequest + { + OrganizationId = organizationId, + Name = newName, + BillingEmail = null + }; + + // Act + var result = await sutProvider.Sut.UpdateAsync(request); + + // Assert + Assert.NotNull(result); + Assert.Equal(organizationId, result.Id); + Assert.Equal(newName, result.Name); + Assert.Equal(originalBillingEmail, result.BillingEmail); + + await organizationService + .Received(1) + .ReplaceAndUpdateCacheAsync( + result, + EventType.Organization_Updated); + await organizationBillingService + .Received(1) + .UpdateOrganizationNameAndEmail(result); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_UpdatingBillingEmailOnly_UpdatesBillingEmailAndNotName( + Guid organizationId, + string newBillingEmail, + Organization organization, + SutProvider sutProvider) + { + // Arrange + var organizationRepository = sutProvider.GetDependency(); + var organizationService = sutProvider.GetDependency(); + var organizationBillingService = sutProvider.GetDependency(); + + organization.Id = organizationId; + organization.BillingEmail = "old@example.com"; + var originalName = organization.Name; + + organizationRepository + .GetByIdAsync(organizationId) + .Returns(organization); + + var request = new OrganizationUpdateRequest + { + OrganizationId = organizationId, + Name = null, + BillingEmail = newBillingEmail + }; + + // Act + var result = await sutProvider.Sut.UpdateAsync(request); + + // Assert + Assert.NotNull(result); + Assert.Equal(organizationId, result.Id); + Assert.Equal(originalName, result.Name); + Assert.Equal(newBillingEmail.ToLowerInvariant().Trim(), result.BillingEmail); + + await organizationService + .Received(1) + .ReplaceAndUpdateCacheAsync( + result, + EventType.Organization_Updated); + await organizationBillingService + .Received(1) + .UpdateOrganizationNameAndEmail(result); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_WhenNoChanges_PreservesBothFields( + Guid organizationId, + Organization organization, + SutProvider sutProvider) + { + // Arrange + var organizationRepository = sutProvider.GetDependency(); + var organizationService = sutProvider.GetDependency(); + var organizationBillingService = sutProvider.GetDependency(); + + organization.Id = organizationId; + var originalName = organization.Name; + var originalBillingEmail = organization.BillingEmail; + + organizationRepository + .GetByIdAsync(organizationId) + .Returns(organization); + + var request = new OrganizationUpdateRequest + { + OrganizationId = organizationId, + Name = null, + BillingEmail = null + }; + + // Act + var result = await sutProvider.Sut.UpdateAsync(request); + + // Assert + Assert.NotNull(result); + Assert.Equal(organizationId, result.Id); + Assert.Equal(originalName, result.Name); + Assert.Equal(originalBillingEmail, result.BillingEmail); + + await organizationService + .Received(1) + .ReplaceAndUpdateCacheAsync( + result, + EventType.Organization_Updated); + await organizationBillingService + .DidNotReceiveWithAnyArgs() + .UpdateOrganizationNameAndEmail(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_SelfHosted_OnlyUpdatesKeysNotOrganizationDetails( + Guid organizationId, + string newName, + string newBillingEmail, + string publicKey, + string encryptedPrivateKey, + Organization organization, + SutProvider sutProvider) + { + // Arrange + var organizationBillingService = sutProvider.GetDependency(); + var globalSettings = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + + globalSettings.SelfHosted.Returns(true); + + organization.Id = organizationId; + organization.Name = "Original Name"; + organization.BillingEmail = "original@example.com"; + organization.PublicKey = null; + organization.PrivateKey = null; + + organizationRepository.GetByIdAsync(organizationId).Returns(organization); + + var request = new OrganizationUpdateRequest + { + OrganizationId = organizationId, + Name = newName, // Should be ignored + BillingEmail = newBillingEmail, // Should be ignored + Keys = new PublicKeyEncryptionKeyPairData( + wrappedPrivateKey: encryptedPrivateKey, + publicKey: publicKey) + }; + + // Act + var result = await sutProvider.Sut.UpdateAsync(request); + + // Assert + Assert.Equal("Original Name", result.Name); // Not changed + Assert.Equal("original@example.com", result.BillingEmail); // Not changed + Assert.Equal(publicKey, result.PublicKey); // Changed + Assert.Equal(encryptedPrivateKey, result.PrivateKey); // Changed + + await organizationBillingService + .DidNotReceiveWithAnyArgs() + .UpdateOrganizationNameAndEmail(Arg.Any()); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/UpdateOrganizationSubscriptionCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/UpdateOrganizationSubscriptionCommandTests.cs index 37a5627919..47872cc6ab 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/UpdateOrganizationSubscriptionCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Organizations/UpdateOrganizationSubscriptionCommandTests.cs @@ -2,10 +2,10 @@ using Bit.Core.AdminConsole.Models.Data.Organizations; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models.StaticStore.Plans; +using Bit.Core.Billing.Services; using Bit.Core.Models.StaticStore; using Bit.Core.Repositories; -using Bit.Core.Services; +using Bit.Core.Test.Billing.Mocks.Plans; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -28,7 +28,7 @@ public class UpdateOrganizationSubscriptionCommandTests // Act await sutProvider.Sut.UpdateOrganizationSubscriptionAsync(subscriptionsToUpdate); - await sutProvider.GetDependency() + await sutProvider.GetDependency() .DidNotReceive() .AdjustSeatsAsync(Arg.Any(), Arg.Any(), Arg.Any()); @@ -53,7 +53,7 @@ public class UpdateOrganizationSubscriptionCommandTests // Act await sutProvider.Sut.UpdateOrganizationSubscriptionAsync(subscriptionsToUpdate); - await sutProvider.GetDependency() + await sutProvider.GetDependency() .Received(1) .AdjustSeatsAsync( Arg.Is(x => x.Id == organization.Id), @@ -81,7 +81,7 @@ public class UpdateOrganizationSubscriptionCommandTests OrganizationSubscriptionUpdate[] subscriptionsToUpdate = [new() { Organization = organization, Plan = new Enterprise2023Plan(true) }]; - sutProvider.GetDependency() + sutProvider.GetDependency() .AdjustSeatsAsync( Arg.Is(x => x.Id == organization.Id), Arg.Is(x => x.Type == organization.PlanType), @@ -115,7 +115,7 @@ public class UpdateOrganizationSubscriptionCommandTests new() { Organization = failedOrganization, Plan = new Enterprise2023Plan(true) } ]; - sutProvider.GetDependency() + sutProvider.GetDependency() .AdjustSeatsAsync( Arg.Is(x => x.Id == failedOrganization.Id), Arg.Is(x => x.Type == failedOrganization.PlanType), @@ -124,7 +124,7 @@ public class UpdateOrganizationSubscriptionCommandTests // Act await sutProvider.Sut.UpdateOrganizationSubscriptionAsync(subscriptionsToUpdate); - await sutProvider.GetDependency() + await sutProvider.GetDependency() .Received(1) .AdjustSeatsAsync( Arg.Is(x => x.Id == successfulOrganization.Id), diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/Enforcement/AutoConfirm/AutomaticUserConfirmationPolicyEnforcementValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/Enforcement/AutoConfirm/AutomaticUserConfirmationPolicyEnforcementValidatorTests.cs new file mode 100644 index 0000000000..f2e6adbfa9 --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/Enforcement/AutoConfirm/AutomaticUserConfirmationPolicyEnforcementValidatorTests.cs @@ -0,0 +1,306 @@ +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.AutoConfirmUser; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Entities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; + +[SutProviderCustomize] +public class AutomaticUserConfirmationPolicyEnforcementValidatorTests +{ + [Theory] + [BitAutoData] + public async Task IsCompliantAsync_WithPolicyEnabledAndUserIsProviderMember_ReturnsProviderUsersCannotJoinError( + SutProvider sutProvider, + OrganizationUser organizationUser, + ProviderUser providerUser, + User user) + { + // Arrange + organizationUser.UserId = providerUser.UserId = user.Id; + + var policyDetails = new PolicyDetails + { + OrganizationId = organizationUser.OrganizationId, + PolicyType = PolicyType.AutomaticUserConfirmation + }; + + var request = new AutomaticUserConfirmationPolicyEnforcementRequest( + organizationUser.OrganizationId, + [organizationUser], + user); + + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(new AutomaticUserConfirmationPolicyRequirement([policyDetails])); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([providerUser]); + + // Act + var result = await sutProvider.Sut.IsCompliantAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData] + public async Task IsCompliantAsync_WithPolicyEnabledOnOtherOrganization_ReturnsOtherOrganizationDoesNotAllowOtherMembershipError( + SutProvider sutProvider, + OrganizationUser organizationUser, + OrganizationUser otherOrganizationUser, + User user) + { + // Arrange + organizationUser.UserId = user.Id; + otherOrganizationUser.UserId = user.Id; + + var otherOrgId = Guid.NewGuid(); + var policyDetails = new PolicyDetails + { + OrganizationId = otherOrgId, // Different from organizationUser.OrganizationId + PolicyType = PolicyType.AutomaticUserConfirmation + }; + + var request = new AutomaticUserConfirmationPolicyEnforcementRequest( + organizationUser.OrganizationId, + [organizationUser, otherOrganizationUser], + user); + + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(new AutomaticUserConfirmationPolicyRequirement([policyDetails])); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([]); + + // Act + var result = await sutProvider.Sut.IsCompliantAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData] + public async Task IsCompliantAsync_WithPolicyDisabledUserIsAMemberOfAnotherOrgReturnsValid( + SutProvider sutProvider, + OrganizationUser organizationUser, + OrganizationUser otherOrgUser, + User user) + { + // Arrange + organizationUser.UserId = user.Id; + otherOrgUser.UserId = user.Id; + + var request = new AutomaticUserConfirmationPolicyEnforcementRequest( + organizationUser.OrganizationId, + [organizationUser, otherOrgUser], + user); + + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(new AutomaticUserConfirmationPolicyRequirement([])); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([]); + + // Act + var result = await sutProvider.Sut.IsCompliantAsync(request); + + // Assert + Assert.True(result.IsValid); + } + + [Theory] + [BitAutoData] + public async Task IsCompliantAsync_WithPolicyEnabledUserIsAMemberOfAnotherOrg_ReturnsCannotBeMemberOfAnotherOrgError( + SutProvider sutProvider, + OrganizationUser organizationUser, + OrganizationUser otherOrgUser, + User user) + { + // Arrange + organizationUser.UserId = user.Id; + otherOrgUser.UserId = user.Id; + + var request = new AutomaticUserConfirmationPolicyEnforcementRequest( + organizationUser.OrganizationId, + [organizationUser, otherOrgUser], + user); + + var policyDetails = new PolicyDetails + { + OrganizationId = organizationUser.OrganizationId, + PolicyType = PolicyType.AutomaticUserConfirmation + }; + + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(new AutomaticUserConfirmationPolicyRequirement([policyDetails])); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([]); + + // Act + var result = await sutProvider.Sut.IsCompliantAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData] + public async Task IsCompliantAsync_WithPolicyEnabledAndChecksConditionsInCorrectOrder_ReturnsFirstFailure( + SutProvider sutProvider, + OrganizationUser organizationUser, + OrganizationUser otherOrgUser, + ProviderUser providerUser, + User user) + { + // Arrange + var policyDetails = new PolicyDetails + { + OrganizationId = organizationUser.OrganizationId, + PolicyType = PolicyType.AutomaticUserConfirmation, + OrganizationUserId = organizationUser.Id + }; + + var request = new AutomaticUserConfirmationPolicyEnforcementRequest( + organizationUser.OrganizationId, + [organizationUser, otherOrgUser], + user); + + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(new AutomaticUserConfirmationPolicyRequirement([policyDetails])); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([providerUser]); + + // Act + var result = await sutProvider.Sut.IsCompliantAsync(request); + + // Assert + Assert.True(result.IsError); + Assert.IsType(result.AsError); + } + + [Theory] + [BitAutoData] + public async Task IsCompliantAsync_WithPolicyIsEnabledNoOtherOrganizationsAndNotAProvider_ReturnsValid( + SutProvider sutProvider, + OrganizationUser organizationUser, + User user) + { + // Arrange + organizationUser.UserId = user.Id; + + var request = new AutomaticUserConfirmationPolicyEnforcementRequest( + organizationUser.OrganizationId, + [organizationUser], + user); + + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(new AutomaticUserConfirmationPolicyRequirement([ + new PolicyDetails + { + OrganizationUserId = organizationUser.Id, + OrganizationId = organizationUser.OrganizationId, + PolicyType = PolicyType.AutomaticUserConfirmation, + } + ])); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([]); + + // Act + var result = await sutProvider.Sut.IsCompliantAsync(request); + + // Assert + Assert.True(result.IsValid); + } + + [Theory] + [BitAutoData] + public async Task IsCompliantAsync_WithPolicyDisabledForCurrentAndOtherOrg_ReturnsValid( + SutProvider sutProvider, + OrganizationUser organizationUser, + OrganizationUser otherOrgUser, + User user) + { + // Arrange + otherOrgUser.UserId = organizationUser.UserId = user.Id; + + var request = new AutomaticUserConfirmationPolicyEnforcementRequest( + organizationUser.OrganizationId, + [organizationUser], + user); + + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(new AutomaticUserConfirmationPolicyRequirement([])); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([]); + + // Act + var result = await sutProvider.Sut.IsCompliantAsync(request); + + // Assert + Assert.True(result.IsValid); + } + + [Theory] + [BitAutoData] + public async Task IsCompliantAsync_WithPolicyDisabledForCurrentAndOtherOrgAndIsProvider_ReturnsValid( + SutProvider sutProvider, + OrganizationUser organizationUser, + OrganizationUser otherOrgUser, + ProviderUser providerUser, + User user) + { + // Arrange + providerUser.UserId = otherOrgUser.UserId = organizationUser.UserId = user.Id; + + var request = new AutomaticUserConfirmationPolicyEnforcementRequest( + organizationUser.OrganizationId, + [organizationUser], + user); + + sutProvider.GetDependency() + .GetAsync(user.Id) + .Returns(new AutomaticUserConfirmationPolicyRequirement([])); + + sutProvider.GetDependency() + .GetManyByUserAsync(user.Id) + .Returns([providerUser]); + + // Act + var result = await sutProvider.Sut.IsCompliantAsync(request); + + // Assert + Assert.True(result.IsValid); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyEventHandlerHandlerFactoryTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyEventHandlerHandlerFactoryTests.cs new file mode 100644 index 0000000000..61d24735b6 --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyEventHandlerHandlerFactoryTests.cs @@ -0,0 +1,124 @@ +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; +using OneOf.Types; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies; + +public class PolicyEventHandlerHandlerFactoryTests +{ + [Fact] + public void GetHandler_ReturnsHandler_WhenHandlerExists() + { + // Arrange + var expectedHandler = new FakeSingleOrgDependencyEvent(); + var factory = new PolicyEventHandlerHandlerFactory([expectedHandler]); + + // Act + var result = factory.GetHandler(PolicyType.SingleOrg); + + // Assert + Assert.True(result.IsT0); + Assert.Equal(expectedHandler, result.AsT0); + } + + [Fact] + public void GetHandler_ReturnsNone_WhenHandlerDoesNotExist() + { + // Arrange + var factory = new PolicyEventHandlerHandlerFactory([new FakeSingleOrgDependencyEvent()]); + + // Act + var result = factory.GetHandler(PolicyType.RequireSso); + + // Assert + Assert.True(result.IsT1); + Assert.IsType(result.AsT1); + } + + [Fact] + public void GetHandler_ReturnsNone_WhenHandlerTypeDoesNotMatch() + { + // Arrange + var factory = new PolicyEventHandlerHandlerFactory([new FakeSingleOrgDependencyEvent()]); + + // Act + var result = factory.GetHandler(PolicyType.SingleOrg); + + // Assert + Assert.True(result.IsT1); + Assert.IsType(result.AsT1); + } + + [Fact] + public void GetHandler_ReturnsCorrectHandler_WhenMultipleHandlerTypesExist() + { + // Arrange + var dependencyEvent = new FakeSingleOrgDependencyEvent(); + var validationEvent = new FakeSingleOrgValidationEvent(); + var factory = new PolicyEventHandlerHandlerFactory([dependencyEvent, validationEvent]); + + // Act + var dependencyResult = factory.GetHandler(PolicyType.SingleOrg); + var validationResult = factory.GetHandler(PolicyType.SingleOrg); + + // Assert + Assert.True(dependencyResult.IsT0); + Assert.Equal(dependencyEvent, dependencyResult.AsT0); + + Assert.True(validationResult.IsT0); + Assert.Equal(validationEvent, validationResult.AsT0); + } + + [Fact] + public void GetHandler_ReturnsCorrectHandler_WhenMultiplePolicyTypesExist() + { + // Arrange + var singleOrgEvent = new FakeSingleOrgDependencyEvent(); + var requireSsoEvent = new FakeRequireSsoDependencyEvent(); + var factory = new PolicyEventHandlerHandlerFactory([singleOrgEvent, requireSsoEvent]); + + // Act + var singleOrgResult = factory.GetHandler(PolicyType.SingleOrg); + var requireSsoResult = factory.GetHandler(PolicyType.RequireSso); + + // Assert + Assert.True(singleOrgResult.IsT0); + Assert.Equal(singleOrgEvent, singleOrgResult.AsT0); + + Assert.True(requireSsoResult.IsT0); + Assert.Equal(requireSsoEvent, requireSsoResult.AsT0); + } + + [Fact] + public void GetHandler_Throws_WhenDuplicateHandlersExist() + { + // Arrange + var factory = new PolicyEventHandlerHandlerFactory([ + new FakeSingleOrgDependencyEvent(), + new FakeSingleOrgDependencyEvent() + ]); + + // Act & Assert + var exception = Assert.Throws(() => + factory.GetHandler(PolicyType.SingleOrg)); + + Assert.Contains("Multiple IPolicyUpdateEvent handlers of type IEnforceDependentPoliciesEvent found for PolicyType SingleOrg", exception.Message); + Assert.Contains("Expected one IEnforceDependentPoliciesEvent handler per PolicyType", exception.Message); + } + + [Fact] + public void GetHandler_ReturnsNone_WhenNoHandlersProvided() + { + // Arrange + var factory = new PolicyEventHandlerHandlerFactory([]); + + // Act + var result = factory.GetHandler(PolicyType.SingleOrg); + + // Assert + Assert.True(result.IsT1); + Assert.IsType(result.AsT1); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs index 8c25f70454..9115ae5ba1 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyRequirementQueryTests.cs @@ -14,10 +14,12 @@ public class PolicyRequirementQueryTests [Theory, BitAutoData] public async Task GetAsync_IgnoresOtherPolicyTypes(Guid userId) { - var thisPolicy = new PolicyDetails { PolicyType = PolicyType.SingleOrg }; - var otherPolicy = new PolicyDetails { PolicyType = PolicyType.RequireSso }; + var thisPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userId }; + var otherPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.RequireSso, UserId = userId }; var policyRepository = Substitute.For(); - policyRepository.GetPolicyDetailsByUserId(userId).Returns([otherPolicy, thisPolicy]); + policyRepository.GetPolicyDetailsByUserIdsAndPolicyType( + Arg.Is>(ids => ids.Contains(userId)), PolicyType.SingleOrg) + .Returns([otherPolicy, thisPolicy]); var factory = new TestPolicyRequirementFactory(_ => true); var sut = new PolicyRequirementQuery(policyRepository, [factory]); @@ -33,9 +35,11 @@ public class PolicyRequirementQueryTests { // Arrange policies var policyRepository = Substitute.For(); - var thisPolicy = new PolicyDetails { PolicyType = PolicyType.SingleOrg }; - var otherPolicy = new PolicyDetails { PolicyType = PolicyType.SingleOrg }; - policyRepository.GetPolicyDetailsByUserId(userId).Returns([thisPolicy, otherPolicy]); + var thisPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userId }; + var otherPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userId }; + policyRepository.GetPolicyDetailsByUserIdsAndPolicyType( + Arg.Is>(ids => ids.Contains(userId)), PolicyType.SingleOrg) + .Returns([thisPolicy, otherPolicy]); // Arrange a substitute Enforce function so that we can inspect the received calls var callback = Substitute.For>(); @@ -70,7 +74,9 @@ public class PolicyRequirementQueryTests public async Task GetAsync_HandlesNoPolicies(Guid userId) { var policyRepository = Substitute.For(); - policyRepository.GetPolicyDetailsByUserId(userId).Returns([]); + policyRepository.GetPolicyDetailsByUserIdsAndPolicyType( + Arg.Is>(ids => ids.Contains(userId)), PolicyType.SingleOrg) + .Returns([]); var factory = new TestPolicyRequirementFactory(x => x.IsProvider); var sut = new PolicyRequirementQuery(policyRepository, [factory]); diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEventFixtures.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEventFixtures.cs new file mode 100644 index 0000000000..4c5b23d6e1 --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEventFixtures.cs @@ -0,0 +1,37 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; +using NSubstitute; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies; + +public class FakeSingleOrgDependencyEvent : IEnforceDependentPoliciesEvent +{ + public PolicyType Type => PolicyType.SingleOrg; + public IEnumerable RequiredPolicies => []; +} + +public class FakeRequireSsoDependencyEvent : IEnforceDependentPoliciesEvent +{ + public PolicyType Type => PolicyType.RequireSso; + public IEnumerable RequiredPolicies => [PolicyType.SingleOrg]; +} + +public class FakeVaultTimeoutDependencyEvent : IEnforceDependentPoliciesEvent +{ + public PolicyType Type => PolicyType.MaximumVaultTimeout; + public IEnumerable RequiredPolicies => [PolicyType.SingleOrg]; +} + +public class FakeSingleOrgValidationEvent : IPolicyValidationEvent +{ + public PolicyType Type => PolicyType.SingleOrg; + + public readonly Func> ValidateAsyncMock = Substitute.For>>(); + + public Task ValidateAsync(SavePolicyModel policyRequest, Policy? currentPolicy) + { + return ValidateAsyncMock(policyRequest, currentPolicy); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/AutomaticUserConfirmationPolicyEventHandlerTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/AutomaticUserConfirmationPolicyEventHandlerTests.cs new file mode 100644 index 0000000000..3c9fd9a9e9 --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/AutomaticUserConfirmationPolicyEventHandlerTests.cs @@ -0,0 +1,406 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; +using Bit.Core.Repositories; +using Bit.Core.Test.AdminConsole.AutoFixture; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; + +[SutProviderCustomize] +public class AutomaticUserConfirmationPolicyEventHandlerTests +{ + [Theory, BitAutoData] + public void RequiredPolicies_IncludesSingleOrg( + SutProvider sutProvider) + { + // Act + var requiredPolicies = sutProvider.Sut.RequiredPolicies; + + // Assert + Assert.Contains(PolicyType.SingleOrg, requiredPolicies); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_EnablingPolicy_UsersNotCompliantWithSingleOrg_ReturnsError( + [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, + Guid nonCompliantUserId, + SutProvider sutProvider) + { + // Arrange + var orgUser = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + OrganizationId = policyUpdate.OrganizationId, + Type = OrganizationUserType.User, + Status = OrganizationUserStatusType.Confirmed, + UserId = nonCompliantUserId, + Email = "user@example.com" + }; + + var otherOrgUser = new OrganizationUser + { + Id = Guid.NewGuid(), + OrganizationId = Guid.NewGuid(), + UserId = nonCompliantUserId, + Status = OrganizationUserStatusType.Confirmed + }; + + sutProvider.GetDependency() + .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) + .Returns([orgUser]); + + sutProvider.GetDependency() + .GetManyByManyUsersAsync(Arg.Any>()) + .Returns([otherOrgUser]); + + // Act + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null); + + // Assert + Assert.Contains("compliant with the Single organization policy", result, StringComparison.OrdinalIgnoreCase); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_EnablingPolicy_UserWithInvitedStatusInOtherOrg_ValidationPasses( + [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, + Guid userId, + SutProvider sutProvider) + { + // Arrange + var orgUser = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + OrganizationId = policyUpdate.OrganizationId, + Type = OrganizationUserType.User, + Status = OrganizationUserStatusType.Confirmed, + UserId = userId, + }; + + var otherOrgUser = new OrganizationUser + { + Id = Guid.NewGuid(), + OrganizationId = Guid.NewGuid(), + UserId = null, // invited users do not have a user id + Status = OrganizationUserStatusType.Invited, + Email = orgUser.Email + }; + + sutProvider.GetDependency() + .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) + .Returns([orgUser]); + + sutProvider.GetDependency() + .GetManyByManyUsersAsync(Arg.Any>()) + .Returns([otherOrgUser]); + + sutProvider.GetDependency() + .GetManyByManyUsersAsync(Arg.Any>()) + .Returns([]); + + // Act + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null); + + // Assert + Assert.True(string.IsNullOrEmpty(result)); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_EnablingPolicy_ProviderUsersExist_ReturnsError( + [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, + Guid userId, + SutProvider sutProvider) + { + // Arrange + var orgUser = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + OrganizationId = policyUpdate.OrganizationId, + Type = OrganizationUserType.User, + Status = OrganizationUserStatusType.Confirmed, + UserId = userId + }; + + var providerUser = new ProviderUser + { + Id = Guid.NewGuid(), + ProviderId = Guid.NewGuid(), + UserId = userId, + Status = ProviderUserStatusType.Confirmed + }; + + sutProvider.GetDependency() + .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) + .Returns([orgUser]); + + sutProvider.GetDependency() + .GetManyByManyUsersAsync(Arg.Any>()) + .Returns([]); + + sutProvider.GetDependency() + .GetManyByManyUsersAsync(Arg.Any>()) + .Returns([providerUser]); + + // Act + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null); + + // Assert + Assert.Contains("Provider user type", result, StringComparison.OrdinalIgnoreCase); + } + + + [Theory, BitAutoData] + public async Task ValidateAsync_EnablingPolicy_AllValidationsPassed_ReturnsEmptyString( + [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, + SutProvider sutProvider) + { + // Arrange + var orgUser = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + OrganizationId = policyUpdate.OrganizationId, + Type = OrganizationUserType.User, + Status = OrganizationUserStatusType.Confirmed, + UserId = Guid.NewGuid() + }; + + sutProvider.GetDependency() + .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) + .Returns([orgUser]); + + sutProvider.GetDependency() + .GetManyByManyUsersAsync(Arg.Any>()) + .Returns([]); + + sutProvider.GetDependency() + .GetManyByManyUsersAsync(Arg.Any>()) + .Returns([]); + + // Act + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null); + + // Assert + Assert.True(string.IsNullOrEmpty(result)); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_PolicyAlreadyEnabled_ReturnsEmptyString( + [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, + [Policy(PolicyType.AutomaticUserConfirmation)] Policy currentPolicy, + SutProvider sutProvider) + { + // Arrange + currentPolicy.OrganizationId = policyUpdate.OrganizationId; + + // Act + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, currentPolicy); + + // Assert + Assert.True(string.IsNullOrEmpty(result)); + + await sutProvider.GetDependency() + .DidNotReceive() + .GetManyDetailsByOrganizationAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_DisablingPolicy_ReturnsEmptyString( + [PolicyUpdate(PolicyType.AutomaticUserConfirmation, false)] PolicyUpdate policyUpdate, + [Policy(PolicyType.AutomaticUserConfirmation)] Policy currentPolicy, + SutProvider sutProvider) + { + // Arrange + currentPolicy.OrganizationId = policyUpdate.OrganizationId; + + // Act + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, currentPolicy); + + // Assert + Assert.True(string.IsNullOrEmpty(result)); + await sutProvider.GetDependency() + .DidNotReceive() + .GetManyDetailsByOrganizationAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_EnablingPolicy_IncludesOwnersAndAdmins_InComplianceCheck( + [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, + Guid nonCompliantOwnerId, + SutProvider sutProvider) + { + // Arrange + var ownerUser = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + OrganizationId = policyUpdate.OrganizationId, + Type = OrganizationUserType.Owner, + Status = OrganizationUserStatusType.Confirmed, + UserId = nonCompliantOwnerId, + }; + + var otherOrgUser = new OrganizationUser + { + Id = Guid.NewGuid(), + OrganizationId = Guid.NewGuid(), + UserId = nonCompliantOwnerId, + Status = OrganizationUserStatusType.Confirmed + }; + + sutProvider.GetDependency() + .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) + .Returns([ownerUser]); + + sutProvider.GetDependency() + .GetManyByManyUsersAsync(Arg.Any>()) + .Returns([otherOrgUser]); + + // Act + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null); + + // Assert + Assert.Contains("compliant with the Single organization policy", result, StringComparison.OrdinalIgnoreCase); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_EnablingPolicy_InvitedUsersExcluded_FromComplianceCheck( + [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, + SutProvider sutProvider) + { + // Arrange + var invitedUser = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + OrganizationId = policyUpdate.OrganizationId, + Type = OrganizationUserType.User, + Status = OrganizationUserStatusType.Invited, + UserId = Guid.NewGuid(), + Email = "invited@example.com" + }; + + sutProvider.GetDependency() + .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) + .Returns([invitedUser]); + + sutProvider.GetDependency() + .GetManyByManyUsersAsync(Arg.Any>()) + .Returns([]); + + // Act + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null); + + // Assert + Assert.True(string.IsNullOrEmpty(result)); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_EnablingPolicy_RevokedUsersIncluded_InComplianceCheck( + [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, + SutProvider sutProvider) + { + // Arrange + var revokedUser = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + OrganizationId = policyUpdate.OrganizationId, + Type = OrganizationUserType.User, + Status = OrganizationUserStatusType.Revoked, + UserId = Guid.NewGuid(), + }; + + var additionalOrgUser = new OrganizationUser + { + Id = Guid.NewGuid(), + OrganizationId = Guid.NewGuid(), + Type = OrganizationUserType.User, + Status = OrganizationUserStatusType.Revoked, + UserId = revokedUser.UserId, + }; + + var orgUserRepository = sutProvider.GetDependency(); + + orgUserRepository + .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) + .Returns([revokedUser]); + + orgUserRepository.GetManyByManyUsersAsync(Arg.Any>()) + .Returns([additionalOrgUser]); + + sutProvider.GetDependency() + .GetManyByManyUsersAsync(Arg.Any>()) + .Returns([]); + + // Act + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null); + + // Assert + Assert.Contains("compliant with the Single organization policy", result, StringComparison.OrdinalIgnoreCase); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_EnablingPolicy_AcceptedUsersIncluded_InComplianceCheck( + [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, + Guid nonCompliantUserId, + SutProvider sutProvider) + { + // Arrange + var acceptedUser = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + OrganizationId = policyUpdate.OrganizationId, + Type = OrganizationUserType.User, + Status = OrganizationUserStatusType.Accepted, + UserId = nonCompliantUserId, + }; + + var otherOrgUser = new OrganizationUser + { + Id = Guid.NewGuid(), + OrganizationId = Guid.NewGuid(), + UserId = nonCompliantUserId, + Status = OrganizationUserStatusType.Confirmed + }; + + sutProvider.GetDependency() + .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) + .Returns([acceptedUser]); + + sutProvider.GetDependency() + .GetManyByManyUsersAsync(Arg.Any>()) + .Returns([otherOrgUser]); + + // Act + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null); + + // Assert + Assert.Contains("compliant with the Single organization policy", result, StringComparison.OrdinalIgnoreCase); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_WithSavePolicyModel_CallsValidateWithPolicyUpdate( + [PolicyUpdate(PolicyType.AutomaticUserConfirmation)] PolicyUpdate policyUpdate, + SutProvider sutProvider) + { + // Arrange + var savePolicyModel = new SavePolicyModel(policyUpdate); + + sutProvider.GetDependency() + .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) + .Returns([]); + + // Act + var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, null); + + // Assert + Assert.True(string.IsNullOrEmpty(result)); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/BlockClaimedDomainAccountCreationPolicyValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/BlockClaimedDomainAccountCreationPolicyValidatorTests.cs new file mode 100644 index 0000000000..e317a5886e --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/BlockClaimedDomainAccountCreationPolicyValidatorTests.cs @@ -0,0 +1,189 @@ +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; + +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; +using Bit.Core.Services; +using Bit.Core.Test.AdminConsole.AutoFixture; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +[SutProviderCustomize] +public class BlockClaimedDomainAccountCreationPolicyValidatorTests +{ + [Theory, BitAutoData] + public async Task ValidateAsync_EnablingPolicy_NoVerifiedDomains_ValidationError( + [PolicyUpdate(PolicyType.BlockClaimedDomainAccountCreation, true)] PolicyUpdate policyUpdate, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainsAsync(policyUpdate.OrganizationId) + .Returns(false); + + // Act + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null); + + // Assert + Assert.Equal("You must claim at least one domain to turn on this policy", result); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_EnablingPolicy_HasVerifiedDomains_Success( + [PolicyUpdate(PolicyType.BlockClaimedDomainAccountCreation, true)] PolicyUpdate policyUpdate, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainsAsync(policyUpdate.OrganizationId) + .Returns(true); + + // Act + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null); + + // Assert + Assert.True(string.IsNullOrEmpty(result)); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_DisablingPolicy_NoValidation( + [PolicyUpdate(PolicyType.BlockClaimedDomainAccountCreation, false)] PolicyUpdate policyUpdate, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + // Act + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null); + + // Assert + Assert.True(string.IsNullOrEmpty(result)); + await sutProvider.GetDependency() + .DidNotReceive() + .HasVerifiedDomainsAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_WithSavePolicyModel_EnablingPolicy_NoVerifiedDomains_ValidationError( + [PolicyUpdate(PolicyType.BlockClaimedDomainAccountCreation, true)] PolicyUpdate policyUpdate, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainsAsync(policyUpdate.OrganizationId) + .Returns(false); + + var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel()); + + // Act + var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, null); + + // Assert + Assert.Equal("You must claim at least one domain to turn on this policy", result); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_WithSavePolicyModel_EnablingPolicy_HasVerifiedDomains_Success( + [PolicyUpdate(PolicyType.BlockClaimedDomainAccountCreation, true)] PolicyUpdate policyUpdate, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainsAsync(policyUpdate.OrganizationId) + .Returns(true); + + var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel()); + + // Act + var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, null); + + // Assert + Assert.True(string.IsNullOrEmpty(result)); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_WithSavePolicyModel_DisablingPolicy_NoValidation( + [PolicyUpdate(PolicyType.BlockClaimedDomainAccountCreation, false)] PolicyUpdate policyUpdate, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel()); + + // Act + var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, null); + + // Assert + Assert.True(string.IsNullOrEmpty(result)); + await sutProvider.GetDependency() + .DidNotReceive() + .HasVerifiedDomainsAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_FeatureFlagDisabled_ReturnsError( + [PolicyUpdate(PolicyType.BlockClaimedDomainAccountCreation, true)] PolicyUpdate policyUpdate, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(false); + + // Act + var result = await sutProvider.Sut.ValidateAsync(policyUpdate, null); + + // Assert + Assert.Equal("This feature is not enabled", result); + await sutProvider.GetDependency() + .DidNotReceive() + .HasVerifiedDomainsAsync(Arg.Any()); + } + + [Fact] + public void Type_ReturnsBlockClaimedDomainAccountCreation() + { + // Arrange + var validator = new BlockClaimedDomainAccountCreationPolicyValidator(null, null); + + // Act & Assert + Assert.Equal(PolicyType.BlockClaimedDomainAccountCreation, validator.Type); + } + + [Fact] + public void RequiredPolicies_ReturnsEmpty() + { + // Arrange + var validator = new BlockClaimedDomainAccountCreationPolicyValidator(null, null); + + // Act + var requiredPolicies = validator.RequiredPolicies.ToList(); + + // Assert + Assert.Empty(requiredPolicies); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/FreeFamiliesForEnterprisePolicyValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/FreeFamiliesForEnterprisePolicyValidatorTests.cs index 0aa670297b..525169a1fb 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/FreeFamiliesForEnterprisePolicyValidatorTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/FreeFamiliesForEnterprisePolicyValidatorTests.cs @@ -72,4 +72,65 @@ public class FreeFamiliesForEnterprisePolicyValidatorTests organizationSponsorships[0].SponsoredOrganizationId.ToString(), organization.Name); } + + [Theory, BitAutoData] + public async Task ExecutePreUpsertSideEffectAsync_DoesNotNotifyUserWhenPolicyDisabled( + Organization organization, + List organizationSponsorships, + [PolicyUpdate(PolicyType.FreeFamiliesSponsorshipPolicy)] PolicyUpdate policyUpdate, + [Policy(PolicyType.FreeFamiliesSponsorshipPolicy, true)] Policy policy, + SutProvider sutProvider) + { + policy.Enabled = true; + policyUpdate.Enabled = false; + + sutProvider.GetDependency() + .GetByIdAsync(policyUpdate.OrganizationId) + .Returns(organization); + + sutProvider.GetDependency() + .GetManyBySponsoringOrganizationAsync(policyUpdate.OrganizationId) + .Returns(organizationSponsorships); + + var savePolicyModel = new SavePolicyModel(policyUpdate); + + await sutProvider.Sut.ExecutePreUpsertSideEffectAsync(savePolicyModel, policy); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SendFamiliesForEnterpriseRemoveSponsorshipsEmailAsync(default, default, default, default); + } + + [Theory, BitAutoData] + public async Task ExecutePreUpsertSideEffectAsync_DoesNotifyUserWhenPolicyEnabled( + Organization organization, + List organizationSponsorships, + [PolicyUpdate(PolicyType.FreeFamiliesSponsorshipPolicy)] PolicyUpdate policyUpdate, + [Policy(PolicyType.FreeFamiliesSponsorshipPolicy, false)] Policy policy, + SutProvider sutProvider) + { + policy.Enabled = false; + policyUpdate.Enabled = true; + + sutProvider.GetDependency() + .GetByIdAsync(policyUpdate.OrganizationId) + .Returns(organization); + + sutProvider.GetDependency() + .GetManyBySponsoringOrganizationAsync(policyUpdate.OrganizationId) + .Returns(organizationSponsorships); + + var savePolicyModel = new SavePolicyModel(policyUpdate); + + await sutProvider.Sut.ExecutePreUpsertSideEffectAsync(savePolicyModel, policy); + + var offerAcceptanceDate = organizationSponsorships[0].ValidUntil!.Value.AddDays(-7).ToString("MM/dd/yyyy"); + await sutProvider.GetDependency() + .Received(1) + .SendFamiliesForEnterpriseRemoveSponsorshipsEmailAsync( + organizationSponsorships[0].FriendlyName, + offerAcceptanceDate, + organizationSponsorships[0].SponsoredOrganizationId.ToString(), + organization.Name); + } } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidatorTests.cs index a39382382b..e6677c8a23 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidatorTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidatorTests.cs @@ -32,7 +32,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests .IsEnabled(FeatureFlagKeys.CreateDefaultLocation) .Returns(false); - var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName)); + var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName)); // Act await sutProvider.Sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState); @@ -58,7 +58,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests .IsEnabled(FeatureFlagKeys.CreateDefaultLocation) .Returns(true); - var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName)); + var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName)); // Act await sutProvider.Sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState); @@ -84,7 +84,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests .IsEnabled(FeatureFlagKeys.CreateDefaultLocation) .Returns(true); - var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName)); + var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName)); // Act await sutProvider.Sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState); @@ -110,7 +110,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests var collectionRepository = Substitute.For(); var sut = ArrangeSut(factory, policyRepository, collectionRepository); - var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName)); + var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName)); // Act await sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState); @@ -199,7 +199,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests var collectionRepository = Substitute.For(); var sut = ArrangeSut(factory, policyRepository, collectionRepository); - var policyRequest = new SavePolicyModel(policyUpdate, null, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName)); + var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName)); // Act await sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState); @@ -238,7 +238,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests .IsEnabled(FeatureFlagKeys.CreateDefaultLocation) .Returns(true); - var policyRequest = new SavePolicyModel(policyUpdate, null, metadata); + var policyRequest = new SavePolicyModel(policyUpdate, metadata); // Act await sutProvider.Sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState); @@ -274,4 +274,176 @@ public class OrganizationDataOwnershipPolicyValidatorTests return sut; } + [Theory, BitAutoData] + public async Task ExecutePostUpsertSideEffectAsync_FeatureFlagDisabled_DoesNothing( + [PolicyUpdate(PolicyType.OrganizationDataOwnership, false)] PolicyUpdate policyUpdate, + [Policy(PolicyType.OrganizationDataOwnership, false)] Policy postUpdatedPolicy, + [Policy(PolicyType.OrganizationDataOwnership, false)] Policy previousPolicyState, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.CreateDefaultLocation) + .Returns(false); + + var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName)); + + // Act + await sutProvider.Sut.ExecutePostUpsertSideEffectAsync(policyRequest, postUpdatedPolicy, previousPolicyState); + + // Assert + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertDefaultCollectionsAsync(default, default, default); + } + + [Theory, BitAutoData] + public async Task ExecutePostUpsertSideEffectAsync_PolicyAlreadyEnabled_DoesNothing( + [PolicyUpdate(PolicyType.OrganizationDataOwnership, true)] PolicyUpdate policyUpdate, + [Policy(PolicyType.OrganizationDataOwnership, true)] Policy postUpdatedPolicy, + [Policy(PolicyType.OrganizationDataOwnership, true)] Policy previousPolicyState, + SutProvider sutProvider) + { + // Arrange + postUpdatedPolicy.OrganizationId = policyUpdate.OrganizationId; + previousPolicyState.OrganizationId = policyUpdate.OrganizationId; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.CreateDefaultLocation) + .Returns(true); + + var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName)); + + // Act + await sutProvider.Sut.ExecutePostUpsertSideEffectAsync(policyRequest, postUpdatedPolicy, previousPolicyState); + + // Assert + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertDefaultCollectionsAsync(default, default, default); + } + + [Theory, BitAutoData] + public async Task ExecutePostUpsertSideEffectAsync_PolicyBeingDisabled_DoesNothing( + [PolicyUpdate(PolicyType.OrganizationDataOwnership, false)] PolicyUpdate policyUpdate, + [Policy(PolicyType.OrganizationDataOwnership, false)] Policy postUpdatedPolicy, + [Policy(PolicyType.OrganizationDataOwnership)] Policy previousPolicyState, + SutProvider sutProvider) + { + // Arrange + previousPolicyState.OrganizationId = policyUpdate.OrganizationId; + postUpdatedPolicy.OrganizationId = policyUpdate.OrganizationId; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.CreateDefaultLocation) + .Returns(true); + + var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName)); + + // Act + await sutProvider.Sut.ExecutePostUpsertSideEffectAsync(policyRequest, postUpdatedPolicy, previousPolicyState); + + // Assert + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertDefaultCollectionsAsync(default, default, default); + } + + [Theory, BitAutoData] + public async Task ExecutePostUpsertSideEffectAsync_WhenNoUsersExist_DoNothing( + [PolicyUpdate(PolicyType.OrganizationDataOwnership, true)] PolicyUpdate policyUpdate, + [Policy(PolicyType.OrganizationDataOwnership, true)] Policy postUpdatedPolicy, + [Policy(PolicyType.OrganizationDataOwnership, false)] Policy previousPolicyState, + OrganizationDataOwnershipPolicyRequirementFactory factory) + { + // Arrange + postUpdatedPolicy.OrganizationId = policyUpdate.OrganizationId; + previousPolicyState.OrganizationId = policyUpdate.OrganizationId; + + var policyRepository = ArrangePolicyRepository([]); + var collectionRepository = Substitute.For(); + + var sut = ArrangeSut(factory, policyRepository, collectionRepository); + var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName)); + + // Act + await sut.ExecutePostUpsertSideEffectAsync(policyRequest, postUpdatedPolicy, previousPolicyState); + + // Assert + await collectionRepository + .DidNotReceiveWithAnyArgs() + .UpsertDefaultCollectionsAsync( + default, + default, + default); + + await policyRepository + .Received(1) + .GetPolicyDetailsByOrganizationIdAsync( + policyUpdate.OrganizationId, + PolicyType.OrganizationDataOwnership); + } + + [Theory] + [BitMemberAutoData(nameof(ShouldUpsertDefaultCollectionsTestCases))] + public async Task ExecutePostUpsertSideEffectAsync_WithRequirements_ShouldUpsertDefaultCollections( + Policy postUpdatedPolicy, + Policy? previousPolicyState, + [PolicyUpdate(PolicyType.OrganizationDataOwnership)] PolicyUpdate policyUpdate, + [OrganizationPolicyDetails(PolicyType.OrganizationDataOwnership)] IEnumerable orgPolicyDetails, + OrganizationDataOwnershipPolicyRequirementFactory factory) + { + // Arrange + var orgPolicyDetailsList = orgPolicyDetails.ToList(); + foreach (var policyDetail in orgPolicyDetailsList) + { + policyDetail.OrganizationId = policyUpdate.OrganizationId; + } + + var policyRepository = ArrangePolicyRepository(orgPolicyDetailsList); + var collectionRepository = Substitute.For(); + + var sut = ArrangeSut(factory, policyRepository, collectionRepository); + var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName)); + + // Act + await sut.ExecutePostUpsertSideEffectAsync(policyRequest, postUpdatedPolicy, previousPolicyState); + + // Assert + await collectionRepository + .Received(1) + .UpsertDefaultCollectionsAsync( + policyUpdate.OrganizationId, + Arg.Is>(ids => ids.Count() == 3), + _defaultUserCollectionName); + } + + [Theory] + [BitMemberAutoData(nameof(WhenDefaultCollectionsDoesNotExistTestCases))] + public async Task ExecutePostUpsertSideEffectAsync_WhenDefaultCollectionNameIsInvalid_DoesNothing( + IPolicyMetadataModel metadata, + [PolicyUpdate(PolicyType.OrganizationDataOwnership)] PolicyUpdate policyUpdate, + [Policy(PolicyType.OrganizationDataOwnership, true)] Policy postUpdatedPolicy, + [Policy(PolicyType.OrganizationDataOwnership, false)] Policy previousPolicyState, + SutProvider sutProvider) + { + // Arrange + postUpdatedPolicy.OrganizationId = policyUpdate.OrganizationId; + previousPolicyState.OrganizationId = policyUpdate.OrganizationId; + policyUpdate.Enabled = true; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.CreateDefaultLocation) + .Returns(true); + + var policyRequest = new SavePolicyModel(policyUpdate, metadata); + + // Act + await sutProvider.Sut.ExecutePostUpsertSideEffectAsync(policyRequest, postUpdatedPolicy, previousPolicyState); + + // Assert + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertDefaultCollectionsAsync(default, default, default); + } } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/RequireSsoPolicyValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/RequireSsoPolicyValidatorTests.cs index d3af765f79..6fc6b85668 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/RequireSsoPolicyValidatorTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/RequireSsoPolicyValidatorTests.cs @@ -72,4 +72,66 @@ public class RequireSsoPolicyValidatorTests var result = await sutProvider.Sut.ValidateAsync(policyUpdate, policy); Assert.True(string.IsNullOrEmpty(result)); } + + [Theory, BitAutoData] + public async Task ValidateAsync_WithSavePolicyModel_DisablingPolicy_KeyConnectorEnabled_ValidationError( + [PolicyUpdate(PolicyType.RequireSso, false)] PolicyUpdate policyUpdate, + [Policy(PolicyType.RequireSso)] Policy policy, + SutProvider sutProvider) + { + policy.OrganizationId = policyUpdate.OrganizationId; + + var ssoConfig = new SsoConfig { Enabled = true }; + ssoConfig.SetData(new SsoConfigurationData { MemberDecryptionType = MemberDecryptionType.KeyConnector }); + + sutProvider.GetDependency() + .GetByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns(ssoConfig); + + var savePolicyModel = new SavePolicyModel(policyUpdate); + + var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, policy); + Assert.Contains("Key Connector is enabled", result, StringComparison.OrdinalIgnoreCase); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_WithSavePolicyModel_DisablingPolicy_TdeEnabled_ValidationError( + [PolicyUpdate(PolicyType.RequireSso, false)] PolicyUpdate policyUpdate, + [Policy(PolicyType.RequireSso)] Policy policy, + SutProvider sutProvider) + { + policy.OrganizationId = policyUpdate.OrganizationId; + + var ssoConfig = new SsoConfig { Enabled = true }; + ssoConfig.SetData(new SsoConfigurationData { MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption }); + + sutProvider.GetDependency() + .GetByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns(ssoConfig); + + var savePolicyModel = new SavePolicyModel(policyUpdate); + + var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, policy); + Assert.Contains("Trusted device encryption is on", result, StringComparison.OrdinalIgnoreCase); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_WithSavePolicyModel_DisablingPolicy_DecryptionOptionsNotEnabled_Success( + [PolicyUpdate(PolicyType.RequireSso, false)] PolicyUpdate policyUpdate, + [Policy(PolicyType.RequireSso)] Policy policy, + SutProvider sutProvider) + { + policy.OrganizationId = policyUpdate.OrganizationId; + + var ssoConfig = new SsoConfig { Enabled = false }; + + sutProvider.GetDependency() + .GetByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns(ssoConfig); + + var savePolicyModel = new SavePolicyModel(policyUpdate); + + var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, policy); + Assert.True(string.IsNullOrEmpty(result)); + } } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/ResetPasswordPolicyValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/ResetPasswordPolicyValidatorTests.cs index 83939406b5..b3d328c5ab 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/ResetPasswordPolicyValidatorTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/ResetPasswordPolicyValidatorTests.cs @@ -68,4 +68,59 @@ public class ResetPasswordPolicyValidatorTests var result = await sutProvider.Sut.ValidateAsync(policyUpdate, policy); Assert.True(string.IsNullOrEmpty(result)); } + + [Theory] + [BitAutoData(true, false)] + [BitAutoData(false, true)] + [BitAutoData(false, false)] + public async Task ValidateAsync_WithSavePolicyModel_DisablingPolicy_TdeEnabled_ValidationError( + bool policyEnabled, + bool autoEnrollEnabled, + [PolicyUpdate(PolicyType.ResetPassword)] PolicyUpdate policyUpdate, + [Policy(PolicyType.ResetPassword)] Policy policy, + SutProvider sutProvider) + { + policyUpdate.Enabled = policyEnabled; + policyUpdate.SetDataModel(new ResetPasswordDataModel + { + AutoEnrollEnabled = autoEnrollEnabled + }); + policy.OrganizationId = policyUpdate.OrganizationId; + + var ssoConfig = new SsoConfig { Enabled = true }; + ssoConfig.SetData(new SsoConfigurationData { MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption }); + + sutProvider.GetDependency() + .GetByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns(ssoConfig); + + var savePolicyModel = new SavePolicyModel(policyUpdate); + + var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, policy); + Assert.Contains("Trusted device encryption is on and requires this policy.", result, StringComparison.OrdinalIgnoreCase); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_WithSavePolicyModel_DisablingPolicy_TdeNotEnabled_Success( + [PolicyUpdate(PolicyType.ResetPassword, false)] PolicyUpdate policyUpdate, + [Policy(PolicyType.ResetPassword)] Policy policy, + SutProvider sutProvider) + { + policyUpdate.SetDataModel(new ResetPasswordDataModel + { + AutoEnrollEnabled = false + }); + policy.OrganizationId = policyUpdate.OrganizationId; + + var ssoConfig = new SsoConfig { Enabled = false }; + + sutProvider.GetDependency() + .GetByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns(ssoConfig); + + var savePolicyModel = new SavePolicyModel(policyUpdate); + + var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, policy); + Assert.True(string.IsNullOrEmpty(result)); + } } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidatorTests.cs index e982a67e46..7c58d46636 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidatorTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidatorTests.cs @@ -1,5 +1,6 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; @@ -145,4 +146,135 @@ public class SingleOrgPolicyValidatorTests .Received(1) .SendOrganizationUserRevokedForPolicySingleOrgEmailAsync(organization.DisplayName(), nonCompliantUser.Email); } + + [Theory, BitAutoData] + public async Task ValidateAsync_WithSavePolicyModel_DisablingPolicy_KeyConnectorEnabled_ValidationError( + [PolicyUpdate(PolicyType.SingleOrg, false)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg)] Policy policy, + SutProvider sutProvider) + { + policy.OrganizationId = policyUpdate.OrganizationId; + + var ssoConfig = new SsoConfig { Enabled = true }; + ssoConfig.SetData(new SsoConfigurationData { MemberDecryptionType = MemberDecryptionType.KeyConnector }); + + sutProvider.GetDependency() + .GetByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns(ssoConfig); + + var savePolicyModel = new SavePolicyModel(policyUpdate); + + var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, policy); + Assert.Contains("Key Connector is enabled", result, StringComparison.OrdinalIgnoreCase); + } + + [Theory, BitAutoData] + public async Task ValidateAsync_WithSavePolicyModel_DisablingPolicy_KeyConnectorNotEnabled_Success( + [PolicyUpdate(PolicyType.SingleOrg, false)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg)] Policy policy, + SutProvider sutProvider) + { + policy.OrganizationId = policyUpdate.OrganizationId; + + var ssoConfig = new SsoConfig { Enabled = false }; + + sutProvider.GetDependency() + .GetByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns(ssoConfig); + + sutProvider.GetDependency() + .HasVerifiedDomainsAsync(policyUpdate.OrganizationId) + .Returns(false); + + var savePolicyModel = new SavePolicyModel(policyUpdate); + + var result = await sutProvider.Sut.ValidateAsync(savePolicyModel, policy); + Assert.True(string.IsNullOrEmpty(result)); + } + + [Theory, BitAutoData] + public async Task ExecutePreUpsertSideEffectAsync_RevokesNonCompliantUsers( + [PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg, false)] Policy policy, + Guid savingUserId, + Guid nonCompliantUserId, + Organization organization, + SutProvider sutProvider) + { + policy.OrganizationId = organization.Id = policyUpdate.OrganizationId; + + var compliantUser1 = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + OrganizationId = organization.Id, + Type = OrganizationUserType.User, + Status = OrganizationUserStatusType.Confirmed, + UserId = new Guid(), + Email = "user1@example.com" + }; + + var compliantUser2 = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + OrganizationId = organization.Id, + Type = OrganizationUserType.User, + Status = OrganizationUserStatusType.Confirmed, + UserId = new Guid(), + Email = "user2@example.com" + }; + + var nonCompliantUser = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + OrganizationId = organization.Id, + Type = OrganizationUserType.User, + Status = OrganizationUserStatusType.Confirmed, + UserId = nonCompliantUserId, + Email = "user3@example.com" + }; + + sutProvider.GetDependency() + .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) + .Returns([compliantUser1, compliantUser2, nonCompliantUser]); + + var otherOrganizationUser = new OrganizationUser + { + Id = Guid.NewGuid(), + OrganizationId = new Guid(), + UserId = nonCompliantUserId, + Status = OrganizationUserStatusType.Confirmed + }; + + sutProvider.GetDependency() + .GetManyByManyUsersAsync(Arg.Is>(ids => ids.Contains(nonCompliantUserId))) + .Returns([otherOrganizationUser]); + + sutProvider.GetDependency().UserId.Returns(savingUserId); + sutProvider.GetDependency().GetByIdAsync(policyUpdate.OrganizationId).Returns(organization); + + sutProvider.GetDependency() + .RevokeNonCompliantOrganizationUsersAsync(Arg.Any()) + .Returns(new CommandResult()); + + var savePolicyModel = new SavePolicyModel(policyUpdate); + + await sutProvider.Sut.ExecutePreUpsertSideEffectAsync(savePolicyModel, policy); + + await sutProvider.GetDependency() + .Received(1) + .RevokeNonCompliantOrganizationUsersAsync( + Arg.Is(r => + r.OrganizationId == organization.Id && + r.OrganizationUsers.Count() == 1 && + r.OrganizationUsers.First().Id == nonCompliantUser.Id)); + await sutProvider.GetDependency() + .DidNotReceive() + .SendOrganizationUserRevokedForPolicySingleOrgEmailAsync(organization.DisplayName(), compliantUser1.Email); + await sutProvider.GetDependency() + .DidNotReceive() + .SendOrganizationUserRevokedForPolicySingleOrgEmailAsync(organization.DisplayName(), compliantUser2.Email); + await sutProvider.GetDependency() + .Received(1) + .SendOrganizationUserRevokedForPolicySingleOrgEmailAsync(organization.DisplayName(), nonCompliantUser.Email); + } } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidatorTests.cs index 7b344d3b29..7d5aaf8d21 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidatorTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidatorTests.cs @@ -136,4 +136,124 @@ public class TwoFactorAuthenticationPolicyValidatorTests .SendOrganizationUserRevokedForTwoFactorPolicyEmailAsync(organization.DisplayName(), compliantUser.Email); } + + [Theory, BitAutoData] + public async Task ExecutePreUpsertSideEffectAsync_GivenNonCompliantUsersWithoutMasterPassword_Throws( + Organization organization, + [PolicyUpdate(PolicyType.TwoFactorAuthentication)] PolicyUpdate policyUpdate, + [Policy(PolicyType.TwoFactorAuthentication, false)] Policy policy, + SutProvider sutProvider) + { + policy.OrganizationId = organization.Id = policyUpdate.OrganizationId; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + + var orgUserDetailUserWithout2Fa = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + Status = OrganizationUserStatusType.Confirmed, + Type = OrganizationUserType.User, + Email = "user3@test.com", + Name = "TEST", + UserId = Guid.NewGuid(), + HasMasterPassword = false + }; + + sutProvider.GetDependency() + .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) + .Returns([orgUserDetailUserWithout2Fa]); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns(new List<(OrganizationUserUserDetails user, bool hasTwoFactor)>() + { + (orgUserDetailUserWithout2Fa, false), + }); + + var savePolicyModel = new SavePolicyModel(policyUpdate); + + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.ExecutePreUpsertSideEffectAsync(savePolicyModel, policy)); + + Assert.Equal(TwoFactorAuthenticationPolicyValidator.NonCompliantMembersWillLoseAccessMessage, exception.Message); + } + + [Theory, BitAutoData] + public async Task ExecutePreUpsertSideEffectAsync_RevokesOnlyNonCompliantUsers( + Organization organization, + [PolicyUpdate(PolicyType.TwoFactorAuthentication)] PolicyUpdate policyUpdate, + [Policy(PolicyType.TwoFactorAuthentication, false)] Policy policy, + SutProvider sutProvider) + { + // Arrange + policy.OrganizationId = policyUpdate.OrganizationId; + organization.Id = policyUpdate.OrganizationId; + + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + + var nonCompliantUser = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + Status = OrganizationUserStatusType.Confirmed, + Type = OrganizationUserType.User, + Email = "user3@test.com", + Name = "TEST", + UserId = Guid.NewGuid(), + HasMasterPassword = true + }; + + var compliantUser = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + Status = OrganizationUserStatusType.Confirmed, + Type = OrganizationUserType.User, + Email = "user4@test.com", + Name = "TEST", + UserId = Guid.NewGuid(), + HasMasterPassword = true + }; + + sutProvider.GetDependency() + .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) + .Returns([nonCompliantUser, compliantUser]); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns(new List<(OrganizationUserUserDetails user, bool hasTwoFactor)>() + { + (nonCompliantUser, false), + (compliantUser, true) + }); + + sutProvider.GetDependency() + .RevokeNonCompliantOrganizationUsersAsync(Arg.Any()) + .Returns(new CommandResult()); + + var savePolicyModel = new SavePolicyModel(policyUpdate); + + // Act + await sutProvider.Sut.ExecutePreUpsertSideEffectAsync(savePolicyModel, policy); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .RevokeNonCompliantOrganizationUsersAsync(Arg.Any()); + + await sutProvider.GetDependency() + .Received(1) + .RevokeNonCompliantOrganizationUsersAsync(Arg.Is(req => + req.OrganizationId == policyUpdate.OrganizationId && + req.OrganizationUsers.SequenceEqual(new[] { nonCompliantUser }) + )); + + await sutProvider.GetDependency() + .Received(1) + .SendOrganizationUserRevokedForTwoFactorPolicyEmailAsync(organization.DisplayName(), + nonCompliantUser.Email); + + // Did not send out an email for compliantUser + await sutProvider.GetDependency() + .Received(0) + .SendOrganizationUserRevokedForTwoFactorPolicyEmailAsync(organization.DisplayName(), + compliantUser.Email); + } } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/UriMatchDefaultPolicyValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/UriMatchDefaultPolicyValidatorTests.cs new file mode 100644 index 0000000000..7059305ac8 --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/UriMatchDefaultPolicyValidatorTests.cs @@ -0,0 +1,28 @@ +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; + +public class UriMatchDefaultPolicyValidatorTests +{ + private readonly UriMatchDefaultPolicyValidator _validator = new(); + + [Fact] + // Test that the Type property returns the correct PolicyType for this validator + public void Type_ReturnsUriMatchDefaults() + { + Assert.Equal(PolicyType.UriMatchDefaults, _validator.Type); + } + + [Fact] + // Test that the RequiredPolicies property returns exactly one policy (SingleOrg) as a prerequisite + // for enabling the UriMatchDefaults policy, ensuring proper policy dependency enforcement + public void RequiredPolicies_ReturnsSingleOrgPolicy() + { + var requiredPolicies = _validator.RequiredPolicies.ToList(); + + Assert.Single(requiredPolicies); + Assert.Contains(PolicyType.SingleOrg, requiredPolicies); + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/SavePolicyCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/SavePolicyCommandTests.cs index 6b85760794..275466a9bd 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/SavePolicyCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/SavePolicyCommandTests.cs @@ -6,8 +6,11 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Enums; using Bit.Core.Exceptions; +using Bit.Core.Models; using Bit.Core.Models.Data.Organizations; +using Bit.Core.Platform.Push; using Bit.Core.Services; using Bit.Core.Test.AdminConsole.AutoFixture; using Bit.Test.Common.AutoFixture; @@ -95,7 +98,8 @@ public class SavePolicyCommandTests Substitute.For(), [new FakeSingleOrgPolicyValidator(), new FakeSingleOrgPolicyValidator()], Substitute.For(), - Substitute.For())); + Substitute.For(), + Substitute.For())); Assert.Contains("Duplicate PolicyValidator for SingleOrg policy", exception.Message); } @@ -288,7 +292,7 @@ public class SavePolicyCommandTests { // Arrange var sutProvider = SutProviderFactory(); - var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel()); + var savePolicyModel = new SavePolicyModel(policyUpdate); currentPolicy.OrganizationId = policyUpdate.OrganizationId; sutProvider.GetDependency() @@ -332,7 +336,7 @@ public class SavePolicyCommandTests var sutProvider = SutProviderFactory(); - var savePolicyModel = new SavePolicyModel(policyUpdate, null, new EmptyMetadataModel()); + var savePolicyModel = new SavePolicyModel(policyUpdate); sutProvider.GetDependency() .GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, policyUpdate.Type) @@ -360,6 +364,103 @@ public class SavePolicyCommandTests .ExecuteSideEffectsAsync(default!, default!, default!); } + [Theory, BitAutoData] + public async Task VNextSaveAsync_SendsPushNotification( + [PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg, false)] Policy currentPolicy) + { + // Arrange + var fakePolicyValidator = new FakeSingleOrgPolicyValidator(); + fakePolicyValidator.ValidateAsyncMock(policyUpdate, null).Returns(""); + var sutProvider = SutProviderFactory([fakePolicyValidator]); + var savePolicyModel = new SavePolicyModel(policyUpdate); + + currentPolicy.OrganizationId = policyUpdate.OrganizationId; + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, policyUpdate.Type) + .Returns(currentPolicy); + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns([currentPolicy]); + + // Act + var result = await sutProvider.Sut.VNextSaveAsync(savePolicyModel); + + // Assert + await sutProvider.GetDependency().Received(1) + .PushAsync(Arg.Is>(p => + p.Type == PushType.PolicyChanged && + p.Target == NotificationTarget.Organization && + p.TargetId == policyUpdate.OrganizationId && + p.ExcludeCurrentContext == false && + p.Payload.OrganizationId == policyUpdate.OrganizationId && + p.Payload.Policy.Id == result.Id && + p.Payload.Policy.Type == policyUpdate.Type && + p.Payload.Policy.Enabled == policyUpdate.Enabled && + p.Payload.Policy.Data == policyUpdate.Data)); + } + + [Theory, BitAutoData] + public async Task SaveAsync_SendsPushNotification([PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate) + { + var fakePolicyValidator = new FakeSingleOrgPolicyValidator(); + fakePolicyValidator.ValidateAsyncMock(policyUpdate, null).Returns(""); + var sutProvider = SutProviderFactory([fakePolicyValidator]); + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency().GetManyByOrganizationIdAsync(policyUpdate.OrganizationId).Returns([]); + + var result = await sutProvider.Sut.SaveAsync(policyUpdate); + + await sutProvider.GetDependency().Received(1) + .PushAsync(Arg.Is>(p => + p.Type == PushType.PolicyChanged && + p.Target == NotificationTarget.Organization && + p.TargetId == policyUpdate.OrganizationId && + p.ExcludeCurrentContext == false && + p.Payload.OrganizationId == policyUpdate.OrganizationId && + p.Payload.Policy.Id == result.Id && + p.Payload.Policy.Type == policyUpdate.Type && + p.Payload.Policy.Enabled == policyUpdate.Enabled && + p.Payload.Policy.Data == policyUpdate.Data)); + } + + [Theory, BitAutoData] + public async Task SaveAsync_ExistingPolicy_SendsPushNotificationWithUpdatedPolicy( + [PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg, false)] Policy currentPolicy) + { + var fakePolicyValidator = new FakeSingleOrgPolicyValidator(); + fakePolicyValidator.ValidateAsyncMock(policyUpdate, null).Returns(""); + var sutProvider = SutProviderFactory([fakePolicyValidator]); + + currentPolicy.OrganizationId = policyUpdate.OrganizationId; + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, policyUpdate.Type) + .Returns(currentPolicy); + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns([currentPolicy]); + + var result = await sutProvider.Sut.SaveAsync(policyUpdate); + + await sutProvider.GetDependency().Received(1) + .PushAsync(Arg.Is>(p => + p.Type == PushType.PolicyChanged && + p.Target == NotificationTarget.Organization && + p.TargetId == policyUpdate.OrganizationId && + p.ExcludeCurrentContext == false && + p.Payload.OrganizationId == policyUpdate.OrganizationId && + p.Payload.Policy.Id == result.Id && + p.Payload.Policy.Type == policyUpdate.Type && + p.Payload.Policy.Enabled == policyUpdate.Enabled && + p.Payload.Policy.Data == policyUpdate.Data)); + } + /// /// Returns a new SutProvider with the PolicyValidators registered in the Sut. /// diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/VNextSavePolicyCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/VNextSavePolicyCommandTests.cs new file mode 100644 index 0000000000..a7dc0402a2 --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/VNextSavePolicyCommandTests.cs @@ -0,0 +1,457 @@ +#nullable enable + +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Exceptions; +using Bit.Core.Models.Data.Organizations; +using Bit.Core.Services; +using Bit.Core.Test.AdminConsole.AutoFixture; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.Extensions.Time.Testing; +using NSubstitute; +using OneOf.Types; +using Xunit; +using EventType = Bit.Core.Enums.EventType; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.Policies; + +public class VNextSavePolicyCommandTests +{ + [Theory, BitAutoData] + public async Task SaveAsync_NewPolicy_Success([PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate) + { + // Arrange + var fakePolicyValidationEvent = new FakeSingleOrgValidationEvent(); + fakePolicyValidationEvent.ValidateAsyncMock(Arg.Any(), Arg.Any()).Returns(""); + var sutProvider = SutProviderFactory([ + new FakeSingleOrgDependencyEvent(), + fakePolicyValidationEvent + ]); + + var savePolicyModel = new SavePolicyModel(policyUpdate); + + var newPolicy = new Policy + { + Type = policyUpdate.Type, + OrganizationId = policyUpdate.OrganizationId, + Enabled = false + }; + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency().GetManyByOrganizationIdAsync(policyUpdate.OrganizationId).Returns([newPolicy]); + + var creationDate = sutProvider.GetDependency().Start; + + // Act + await sutProvider.Sut.SaveAsync(savePolicyModel); + + // Assert + await fakePolicyValidationEvent.ValidateAsyncMock + .Received(1) + .Invoke(Arg.Any(), Arg.Any()); + + await AssertPolicySavedAsync(sutProvider, policyUpdate); + + await sutProvider.GetDependency() + .Received(1) + .UpsertAsync(Arg.Is(p => + p.CreationDate == creationDate && + p.RevisionDate == creationDate)); + } + + [Theory, BitAutoData] + public async Task SaveAsync_ExistingPolicy_Success( + [PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg, false)] Policy currentPolicy) + { + // Arrange + var fakePolicyValidationEvent = new FakeSingleOrgValidationEvent(); + fakePolicyValidationEvent.ValidateAsyncMock(Arg.Any(), Arg.Any()).Returns(""); + var sutProvider = SutProviderFactory([ + new FakeSingleOrgDependencyEvent(), + fakePolicyValidationEvent + ]); + + var savePolicyModel = new SavePolicyModel(policyUpdate); + + currentPolicy.OrganizationId = policyUpdate.OrganizationId; + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(policyUpdate.OrganizationId, policyUpdate.Type) + .Returns(currentPolicy); + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns([currentPolicy]); + + // Act + await sutProvider.Sut.SaveAsync(savePolicyModel); + + // Assert + await fakePolicyValidationEvent.ValidateAsyncMock + .Received(1) + .Invoke(Arg.Any(), currentPolicy); + + await AssertPolicySavedAsync(sutProvider, policyUpdate); + + + var revisionDate = sutProvider.GetDependency().Start; + + await sutProvider.GetDependency() + .Received(1) + .UpsertAsync(Arg.Is(p => + p.Id == currentPolicy.Id && + p.OrganizationId == currentPolicy.OrganizationId && + p.Type == currentPolicy.Type && + p.CreationDate == currentPolicy.CreationDate && + p.RevisionDate == revisionDate)); + } + + [Theory, BitAutoData] + public async Task SaveAsync_OrganizationDoesNotExist_ThrowsBadRequest([PolicyUpdate(PolicyType.ActivateAutofill)] PolicyUpdate policyUpdate) + { + // Arrange + var sutProvider = SutProviderFactory(); + var savePolicyModel = new SavePolicyModel(policyUpdate); + + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(policyUpdate.OrganizationId) + .Returns(Task.FromResult(null)); + + // Act + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(savePolicyModel)); + + // Assert + Assert.Contains("Organization not found", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + await AssertPolicyNotSavedAsync(sutProvider); + } + + [Theory, BitAutoData] + public async Task SaveAsync_OrganizationCannotUsePolicies_ThrowsBadRequest([PolicyUpdate(PolicyType.ActivateAutofill)] PolicyUpdate policyUpdate) + { + // Arrange + var sutProvider = SutProviderFactory(); + var savePolicyModel = new SavePolicyModel(policyUpdate); + + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(policyUpdate.OrganizationId) + .Returns(new OrganizationAbility + { + Id = policyUpdate.OrganizationId, + UsePolicies = false + }); + + // Act + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(savePolicyModel)); + + // Assert + Assert.Contains("cannot use policies", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + await AssertPolicyNotSavedAsync(sutProvider); + } + + [Theory, BitAutoData] + public async Task SaveAsync_RequiredPolicyIsNull_Throws( + [PolicyUpdate(PolicyType.RequireSso)] PolicyUpdate policyUpdate) + { + // Arrange + var sutProvider = SutProviderFactory( + [ + new FakeRequireSsoDependencyEvent(), + new FakeSingleOrgDependencyEvent() + ]); + + var savePolicyModel = new SavePolicyModel(policyUpdate); + + var requireSsoPolicy = new Policy + { + Type = PolicyType.RequireSso, + OrganizationId = policyUpdate.OrganizationId, + Enabled = false + }; + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns([requireSsoPolicy]); + + // Act + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(savePolicyModel)); + + // Assert + Assert.Contains("Turn on the Single organization policy because it is required for the Require single sign-on authentication policy", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + await AssertPolicyNotSavedAsync(sutProvider); + } + + [Theory, BitAutoData] + public async Task SaveAsync_RequiredPolicyNotEnabled_Throws( + [PolicyUpdate(PolicyType.RequireSso)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg, false)] Policy singleOrgPolicy) + { + // Arrange + var sutProvider = SutProviderFactory( + [ + new FakeRequireSsoDependencyEvent(), + new FakeSingleOrgDependencyEvent() + ]); + + var savePolicyModel = new SavePolicyModel(policyUpdate); + + var requireSsoPolicy = new Policy + { + Type = PolicyType.RequireSso, + OrganizationId = policyUpdate.OrganizationId, + Enabled = false + }; + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns([singleOrgPolicy, requireSsoPolicy]); + + // Act + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(savePolicyModel)); + + // Assert + Assert.Contains("Turn on the Single organization policy because it is required for the Require single sign-on authentication policy", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + await AssertPolicyNotSavedAsync(sutProvider); + } + + [Theory, BitAutoData] + public async Task SaveAsync_RequiredPolicyEnabled_Success( + [PolicyUpdate(PolicyType.RequireSso)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg)] Policy singleOrgPolicy) + { + // Arrange + var sutProvider = SutProviderFactory( + [ + new FakeRequireSsoDependencyEvent(), + new FakeSingleOrgDependencyEvent() + ]); + + var savePolicyModel = new SavePolicyModel(policyUpdate); + + var requireSsoPolicy = new Policy + { + Type = PolicyType.RequireSso, + OrganizationId = policyUpdate.OrganizationId, + Enabled = false + }; + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns([singleOrgPolicy, requireSsoPolicy]); + + // Act + await sutProvider.Sut.SaveAsync(savePolicyModel); + + // Assert + await AssertPolicySavedAsync(sutProvider, policyUpdate); + } + + [Theory, BitAutoData] + public async Task SaveAsync_DependentPolicyIsEnabled_Throws( + [PolicyUpdate(PolicyType.SingleOrg, false)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg)] Policy currentPolicy, + [Policy(PolicyType.RequireSso)] Policy requireSsoPolicy) + { + // Arrange + var sutProvider = SutProviderFactory( + [ + new FakeRequireSsoDependencyEvent(), + new FakeSingleOrgDependencyEvent() + ]); + + var savePolicyModel = new SavePolicyModel(policyUpdate); + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns([currentPolicy, requireSsoPolicy]); + + // Act + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(savePolicyModel)); + + // Assert + Assert.Contains("Turn off the Require single sign-on authentication policy because it requires the Single organization policy", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + await AssertPolicyNotSavedAsync(sutProvider); + } + + [Theory, BitAutoData] + public async Task SaveAsync_MultipleDependentPoliciesAreEnabled_Throws( + [PolicyUpdate(PolicyType.SingleOrg, false)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg)] Policy currentPolicy, + [Policy(PolicyType.RequireSso)] Policy requireSsoPolicy, + [Policy(PolicyType.MaximumVaultTimeout)] Policy vaultTimeoutPolicy) + { + // Arrange + var sutProvider = SutProviderFactory( + [ + new FakeRequireSsoDependencyEvent(), + new FakeSingleOrgDependencyEvent(), + new FakeVaultTimeoutDependencyEvent() + ]); + + var savePolicyModel = new SavePolicyModel(policyUpdate); + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns([currentPolicy, requireSsoPolicy, vaultTimeoutPolicy]); + + // Act + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(savePolicyModel)); + + // Assert + Assert.Contains("Turn off all of the policies that require the Single organization policy", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + await AssertPolicyNotSavedAsync(sutProvider); + } + + [Theory, BitAutoData] + public async Task SaveAsync_DependentPolicyNotEnabled_Success( + [PolicyUpdate(PolicyType.SingleOrg, false)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg)] Policy currentPolicy, + [Policy(PolicyType.RequireSso, false)] Policy requireSsoPolicy) + { + // Arrange + var sutProvider = SutProviderFactory( + [ + new FakeRequireSsoDependencyEvent(), + new FakeSingleOrgDependencyEvent() + ]); + + var savePolicyModel = new SavePolicyModel(policyUpdate); + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(policyUpdate.OrganizationId) + .Returns([currentPolicy, requireSsoPolicy]); + + // Act + await sutProvider.Sut.SaveAsync(savePolicyModel); + + // Assert + await AssertPolicySavedAsync(sutProvider, policyUpdate); + } + + [Theory, BitAutoData] + public async Task SaveAsync_ThrowsOnValidationError([PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate) + { + // Arrange + var fakePolicyValidationEvent = new FakeSingleOrgValidationEvent(); + fakePolicyValidationEvent.ValidateAsyncMock(Arg.Any(), Arg.Any()).Returns("Validation error!"); + var sutProvider = SutProviderFactory([ + new FakeSingleOrgDependencyEvent(), + fakePolicyValidationEvent + ]); + + var savePolicyModel = new SavePolicyModel(policyUpdate); + + var singleOrgPolicy = new Policy + { + Type = PolicyType.SingleOrg, + OrganizationId = policyUpdate.OrganizationId, + Enabled = false + }; + + ArrangeOrganization(sutProvider, policyUpdate); + sutProvider.GetDependency().GetManyByOrganizationIdAsync(policyUpdate.OrganizationId).Returns([singleOrgPolicy]); + + // Act + var badRequestException = await Assert.ThrowsAsync( + () => sutProvider.Sut.SaveAsync(savePolicyModel)); + + // Assert + Assert.Contains("Validation error!", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + await AssertPolicyNotSavedAsync(sutProvider); + } + + /// + /// Returns a new SutProvider with the PolicyUpdateEvents registered in the Sut. + /// + private static SutProvider SutProviderFactory( + IEnumerable? policyUpdateEvents = null) + { + var policyEventHandlerFactory = Substitute.For(); + var handlers = policyUpdateEvents ?? []; + + // Setup factory to return handlers based on type + policyEventHandlerFactory.GetHandler(Arg.Any()) + .Returns(callInfo => + { + var policyType = callInfo.Arg(); + var handler = handlers.OfType().FirstOrDefault(e => e.Type == policyType); + return handler != null ? OneOf.OneOf.FromT0(handler) : OneOf.OneOf.FromT1(new None()); + }); + + policyEventHandlerFactory.GetHandler(Arg.Any()) + .Returns(callInfo => + { + var policyType = callInfo.Arg(); + var handler = handlers.OfType().FirstOrDefault(e => e.Type == policyType); + return handler != null ? OneOf.OneOf.FromT0(handler) : OneOf.OneOf.FromT1(new None()); + }); + + policyEventHandlerFactory.GetHandler(Arg.Any()) + .Returns(new None()); + + policyEventHandlerFactory.GetHandler(Arg.Any()) + .Returns(new None()); + + return new SutProvider() + .WithFakeTimeProvider() + .SetDependency(handlers) + .SetDependency(policyEventHandlerFactory) + .Create(); + } + + private static void ArrangeOrganization(SutProvider sutProvider, PolicyUpdate policyUpdate) + { + sutProvider.GetDependency() + .GetOrganizationAbilityAsync(policyUpdate.OrganizationId) + .Returns(new OrganizationAbility + { + Id = policyUpdate.OrganizationId, + UsePolicies = true + }); + } + + private static async Task AssertPolicyNotSavedAsync(SutProvider sutProvider) + { + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .UpsertAsync(default!); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .LogPolicyEventAsync(default, default); + } + + private static async Task AssertPolicySavedAsync(SutProvider sutProvider, PolicyUpdate policyUpdate) + { + await sutProvider.GetDependency().Received(1).UpsertAsync(ExpectedPolicy()); + + await sutProvider.GetDependency().Received(1) + .LogPolicyEventAsync(ExpectedPolicy(), EventType.Policy_Updated); + + return; + + Policy ExpectedPolicy() => Arg.Is( + p => + p.Type == policyUpdate.Type + && p.OrganizationId == policyUpdate.OrganizationId + && p.Enabled == policyUpdate.Enabled + && p.Data == policyUpdate.Data); + } +} diff --git a/test/Core.Test/AdminConsole/Services/EventIntegrationHandlerTests.cs b/test/Core.Test/AdminConsole/Services/EventIntegrationHandlerTests.cs deleted file mode 100644 index f038fe28ef..0000000000 --- a/test/Core.Test/AdminConsole/Services/EventIntegrationHandlerTests.cs +++ /dev/null @@ -1,312 +0,0 @@ -using System.Text.Json; -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; -using Bit.Core.Entities; -using Bit.Core.Enums; -using Bit.Core.Models.Data; -using Bit.Core.Models.Data.Organizations; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using Bit.Test.Common.Helpers; -using Microsoft.Extensions.Logging; -using NSubstitute; -using Xunit; - -namespace Bit.Core.Test.Services; - -[SutProviderCustomize] -public class EventIntegrationHandlerTests -{ - private const string _templateBase = "Date: #Date#, Type: #Type#, UserId: #UserId#"; - private const string _templateWithOrganization = "Org: #OrganizationName#"; - private const string _templateWithUser = "#UserName#, #UserEmail#"; - private const string _templateWithActingUser = "#ActingUserName#, #ActingUserEmail#"; - private static readonly Uri _uri = new Uri("https://localhost"); - private static readonly Uri _uri2 = new Uri("https://example.com"); - private readonly IEventIntegrationPublisher _eventIntegrationPublisher = Substitute.For(); - private readonly ILogger> _logger = - Substitute.For>>(); - - private SutProvider> GetSutProvider( - List configurations) - { - var configurationCache = Substitute.For(); - configurationCache.GetConfigurationDetails(Arg.Any(), - IntegrationType.Webhook, Arg.Any()).Returns(configurations); - - return new SutProvider>() - .SetDependency(configurationCache) - .SetDependency(_eventIntegrationPublisher) - .SetDependency(IntegrationType.Webhook) - .SetDependency(_logger) - .Create(); - } - - private static IntegrationMessage expectedMessage(string template) - { - return new IntegrationMessage() - { - IntegrationType = IntegrationType.Webhook, - MessageId = "TestMessageId", - Configuration = new WebhookIntegrationConfigurationDetails(_uri), - RenderedTemplate = template, - RetryCount = 0, - DelayUntilDate = null - }; - } - - private static List NoConfigurations() - { - return []; - } - - private static List OneConfiguration(string template) - { - var config = Substitute.For(); - config.Configuration = null; - config.IntegrationConfiguration = JsonSerializer.Serialize(new { Uri = _uri }); - config.Template = template; - - return [config]; - } - - private static List TwoConfigurations(string template) - { - var config = Substitute.For(); - config.Configuration = null; - config.IntegrationConfiguration = JsonSerializer.Serialize(new { Uri = _uri }); - config.Template = template; - var config2 = Substitute.For(); - config2.Configuration = null; - config2.IntegrationConfiguration = JsonSerializer.Serialize(new { Uri = _uri2 }); - config2.Template = template; - - return [config, config2]; - } - - private static List InvalidFilterConfiguration() - { - var config = Substitute.For(); - config.Configuration = null; - config.IntegrationConfiguration = JsonSerializer.Serialize(new { Uri = _uri }); - config.Template = _templateBase; - config.Filters = "Invalid Configuration!"; - - return [config]; - } - - private static List ValidFilterConfiguration() - { - var config = Substitute.For(); - config.Configuration = null; - config.IntegrationConfiguration = JsonSerializer.Serialize(new { Uri = _uri }); - config.Template = _templateBase; - config.Filters = JsonSerializer.Serialize(new IntegrationFilterGroup() { }); - - return [config]; - } - - - [Theory, BitAutoData] - public async Task HandleEventAsync_BaseTemplateNoConfigurations_DoesNothing(EventMessage eventMessage) - { - var sutProvider = GetSutProvider(NoConfigurations()); - - await sutProvider.Sut.HandleEventAsync(eventMessage); - Assert.Empty(_eventIntegrationPublisher.ReceivedCalls()); - } - - [Theory, BitAutoData] - public async Task HandleEventAsync_BaseTemplateOneConfiguration_PublishesIntegrationMessage(EventMessage eventMessage) - { - var sutProvider = GetSutProvider(OneConfiguration(_templateBase)); - - await sutProvider.Sut.HandleEventAsync(eventMessage); - - var expectedMessage = EventIntegrationHandlerTests.expectedMessage( - $"Date: {eventMessage.Date}, Type: {eventMessage.Type}, UserId: {eventMessage.UserId}" - ); - - Assert.Single(_eventIntegrationPublisher.ReceivedCalls()); - await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( - AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" }))); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetByIdAsync(Arg.Any()); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetByIdAsync(Arg.Any()); - } - - [Theory, BitAutoData] - public async Task HandleEventAsync_BaseTemplateTwoConfigurations_PublishesIntegrationMessages(EventMessage eventMessage) - { - var sutProvider = GetSutProvider(TwoConfigurations(_templateBase)); - - await sutProvider.Sut.HandleEventAsync(eventMessage); - - var expectedMessage = EventIntegrationHandlerTests.expectedMessage( - $"Date: {eventMessage.Date}, Type: {eventMessage.Type}, UserId: {eventMessage.UserId}" - ); - await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( - AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" }))); - - expectedMessage.Configuration = new WebhookIntegrationConfigurationDetails(_uri2); - await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( - AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" }))); - - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetByIdAsync(Arg.Any()); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetByIdAsync(Arg.Any()); - } - - [Theory, BitAutoData] - public async Task HandleEventAsync_ActingUserTemplate_LoadsUserFromRepository(EventMessage eventMessage) - { - var sutProvider = GetSutProvider(OneConfiguration(_templateWithActingUser)); - var user = Substitute.For(); - user.Email = "test@example.com"; - user.Name = "Test"; - - sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(user); - await sutProvider.Sut.HandleEventAsync(eventMessage); - - var expectedMessage = EventIntegrationHandlerTests.expectedMessage($"{user.Name}, {user.Email}"); - - Assert.Single(_eventIntegrationPublisher.ReceivedCalls()); - await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( - AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" }))); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetByIdAsync(Arg.Any()); - await sutProvider.GetDependency().Received(1).GetByIdAsync(eventMessage.ActingUserId ?? Guid.Empty); - } - - [Theory, BitAutoData] - public async Task HandleEventAsync_OrganizationTemplate_LoadsOrganizationFromRepository(EventMessage eventMessage) - { - var sutProvider = GetSutProvider(OneConfiguration(_templateWithOrganization)); - var organization = Substitute.For(); - organization.Name = "Test"; - - sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(organization); - await sutProvider.Sut.HandleEventAsync(eventMessage); - - Assert.Single(_eventIntegrationPublisher.ReceivedCalls()); - - var expectedMessage = EventIntegrationHandlerTests.expectedMessage($"Org: {organization.Name}"); - - Assert.Single(_eventIntegrationPublisher.ReceivedCalls()); - await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( - AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" }))); - await sutProvider.GetDependency().Received(1).GetByIdAsync(eventMessage.OrganizationId ?? Guid.Empty); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetByIdAsync(Arg.Any()); - } - - [Theory, BitAutoData] - public async Task HandleEventAsync_UserTemplate_LoadsUserFromRepository(EventMessage eventMessage) - { - var sutProvider = GetSutProvider(OneConfiguration(_templateWithUser)); - var user = Substitute.For(); - user.Email = "test@example.com"; - user.Name = "Test"; - - sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(user); - await sutProvider.Sut.HandleEventAsync(eventMessage); - - var expectedMessage = EventIntegrationHandlerTests.expectedMessage($"{user.Name}, {user.Email}"); - - Assert.Single(_eventIntegrationPublisher.ReceivedCalls()); - await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( - AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" }))); - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetByIdAsync(Arg.Any()); - await sutProvider.GetDependency().Received(1).GetByIdAsync(eventMessage.UserId ?? Guid.Empty); - } - - [Theory, BitAutoData] - public async Task HandleEventAsync_FilterReturnsFalse_DoesNothing(EventMessage eventMessage) - { - var sutProvider = GetSutProvider(ValidFilterConfiguration()); - sutProvider.GetDependency().EvaluateFilterGroup( - Arg.Any(), Arg.Any()).Returns(false); - - await sutProvider.Sut.HandleEventAsync(eventMessage); - Assert.Empty(_eventIntegrationPublisher.ReceivedCalls()); - } - - [Theory, BitAutoData] - public async Task HandleEventAsync_FilterReturnsTrue_PublishesIntegrationMessage(EventMessage eventMessage) - { - var sutProvider = GetSutProvider(ValidFilterConfiguration()); - sutProvider.GetDependency().EvaluateFilterGroup( - Arg.Any(), Arg.Any()).Returns(true); - - await sutProvider.Sut.HandleEventAsync(eventMessage); - - var expectedMessage = EventIntegrationHandlerTests.expectedMessage( - $"Date: {eventMessage.Date}, Type: {eventMessage.Type}, UserId: {eventMessage.UserId}" - ); - - Assert.Single(_eventIntegrationPublisher.ReceivedCalls()); - await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( - AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" }))); - } - - [Theory, BitAutoData] - public async Task HandleEventAsync_InvalidFilter_LogsErrorDoesNothing(EventMessage eventMessage) - { - var sutProvider = GetSutProvider(InvalidFilterConfiguration()); - - await sutProvider.Sut.HandleEventAsync(eventMessage); - Assert.Empty(_eventIntegrationPublisher.ReceivedCalls()); - _logger.Received(1).Log( - LogLevel.Error, - Arg.Any(), - Arg.Any(), - Arg.Any(), - Arg.Any>()); - } - - [Theory, BitAutoData] - public async Task HandleManyEventsAsync_BaseTemplateNoConfigurations_DoesNothing(List eventMessages) - { - var sutProvider = GetSutProvider(NoConfigurations()); - - await sutProvider.Sut.HandleManyEventsAsync(eventMessages); - Assert.Empty(_eventIntegrationPublisher.ReceivedCalls()); - } - - [Theory, BitAutoData] - public async Task HandleManyEventsAsync_BaseTemplateOneConfiguration_PublishesIntegrationMessages(List eventMessages) - { - var sutProvider = GetSutProvider(OneConfiguration(_templateBase)); - - await sutProvider.Sut.HandleManyEventsAsync(eventMessages); - - foreach (var eventMessage in eventMessages) - { - var expectedMessage = EventIntegrationHandlerTests.expectedMessage( - $"Date: {eventMessage.Date}, Type: {eventMessage.Type}, UserId: {eventMessage.UserId}" - ); - await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( - AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" }))); - } - } - - [Theory, BitAutoData] - public async Task HandleManyEventsAsync_BaseTemplateTwoConfigurations_PublishesIntegrationMessages( - List eventMessages) - { - var sutProvider = GetSutProvider(TwoConfigurations(_templateBase)); - - await sutProvider.Sut.HandleManyEventsAsync(eventMessages); - - foreach (var eventMessage in eventMessages) - { - var expectedMessage = EventIntegrationHandlerTests.expectedMessage( - $"Date: {eventMessage.Date}, Type: {eventMessage.Type}, UserId: {eventMessage.UserId}" - ); - await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( - AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" }))); - - expectedMessage.Configuration = new WebhookIntegrationConfigurationDetails(_uri2); - await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( - AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" }))); - } - } -} diff --git a/test/Core.Test/AdminConsole/Services/EventRouteServiceTests.cs b/test/Core.Test/AdminConsole/Services/EventRouteServiceTests.cs deleted file mode 100644 index 1a42d846f2..0000000000 --- a/test/Core.Test/AdminConsole/Services/EventRouteServiceTests.cs +++ /dev/null @@ -1,65 +0,0 @@ -using Bit.Core.Models.Data; -using Bit.Core.Services; -using Bit.Test.Common.AutoFixture.Attributes; -using NSubstitute; -using Xunit; - -namespace Bit.Core.Test.Services; - -[SutProviderCustomize] -public class EventRouteServiceTests -{ - private readonly IEventWriteService _broadcastEventWriteService = Substitute.For(); - private readonly IEventWriteService _storageEventWriteService = Substitute.For(); - private readonly IFeatureService _featureService = Substitute.For(); - private readonly EventRouteService Subject; - - public EventRouteServiceTests() - { - Subject = new EventRouteService(_broadcastEventWriteService, _storageEventWriteService, _featureService); - } - - [Theory, BitAutoData] - public async Task CreateAsync_FlagDisabled_EventSentToStorageService(EventMessage eventMessage) - { - _featureService.IsEnabled(FeatureFlagKeys.EventBasedOrganizationIntegrations).Returns(false); - - await Subject.CreateAsync(eventMessage); - - await _broadcastEventWriteService.DidNotReceiveWithAnyArgs().CreateAsync(Arg.Any()); - await _storageEventWriteService.Received(1).CreateAsync(eventMessage); - } - - [Theory, BitAutoData] - public async Task CreateAsync_FlagEnabled_EventSentToBroadcastService(EventMessage eventMessage) - { - _featureService.IsEnabled(FeatureFlagKeys.EventBasedOrganizationIntegrations).Returns(true); - - await Subject.CreateAsync(eventMessage); - - await _broadcastEventWriteService.Received(1).CreateAsync(eventMessage); - await _storageEventWriteService.DidNotReceiveWithAnyArgs().CreateAsync(Arg.Any()); - } - - [Theory, BitAutoData] - public async Task CreateManyAsync_FlagDisabled_EventsSentToStorageService(IEnumerable eventMessages) - { - _featureService.IsEnabled(FeatureFlagKeys.EventBasedOrganizationIntegrations).Returns(false); - - await Subject.CreateManyAsync(eventMessages); - - await _broadcastEventWriteService.DidNotReceiveWithAnyArgs().CreateManyAsync(Arg.Any>()); - await _storageEventWriteService.Received(1).CreateManyAsync(eventMessages); - } - - [Theory, BitAutoData] - public async Task CreateManyAsync_FlagEnabled_EventsSentToBroadcastService(IEnumerable eventMessages) - { - _featureService.IsEnabled(FeatureFlagKeys.EventBasedOrganizationIntegrations).Returns(true); - - await Subject.CreateManyAsync(eventMessages); - - await _broadcastEventWriteService.Received(1).CreateManyAsync(eventMessages); - await _storageEventWriteService.DidNotReceiveWithAnyArgs().CreateManyAsync(Arg.Any>()); - } -} diff --git a/test/Core.Test/AdminConsole/Services/IntegrationConfigurationDetailsCacheServiceTests.cs b/test/Core.Test/AdminConsole/Services/IntegrationConfigurationDetailsCacheServiceTests.cs deleted file mode 100644 index 4e87d13caf..0000000000 --- a/test/Core.Test/AdminConsole/Services/IntegrationConfigurationDetailsCacheServiceTests.cs +++ /dev/null @@ -1,173 +0,0 @@ -#nullable enable - -using System.Text.Json; -using Bit.Core.Enums; -using Bit.Core.Models.Data.Organizations; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using Microsoft.Extensions.Logging; -using NSubstitute; -using NSubstitute.ExceptionExtensions; -using Xunit; - -namespace Bit.Core.Test.Services; - -[SutProviderCustomize] -public class IntegrationConfigurationDetailsCacheServiceTests -{ - private SutProvider GetSutProvider( - List configurations) - { - var configurationRepository = Substitute.For(); - configurationRepository.GetAllConfigurationDetailsAsync().Returns(configurations); - - return new SutProvider() - .SetDependency(configurationRepository) - .Create(); - } - - [Theory, BitAutoData] - public async Task GetConfigurationDetails_SpecificKeyExists_ReturnsExpectedList(OrganizationIntegrationConfigurationDetails config) - { - config.EventType = EventType.Cipher_Created; - var sutProvider = GetSutProvider([config]); - await sutProvider.Sut.RefreshAsync(); - var result = sutProvider.Sut.GetConfigurationDetails( - config.OrganizationId, - config.IntegrationType, - EventType.Cipher_Created); - Assert.Single(result); - Assert.Same(config, result[0]); - } - - [Theory, BitAutoData] - public async Task GetConfigurationDetails_AllEventsKeyExists_ReturnsExpectedList(OrganizationIntegrationConfigurationDetails config) - { - config.EventType = null; - var sutProvider = GetSutProvider([config]); - await sutProvider.Sut.RefreshAsync(); - var result = sutProvider.Sut.GetConfigurationDetails( - config.OrganizationId, - config.IntegrationType, - EventType.Cipher_Created); - Assert.Single(result); - Assert.Same(config, result[0]); - } - - [Theory, BitAutoData] - public async Task GetConfigurationDetails_BothSpecificAndAllEventsKeyExists_ReturnsExpectedList( - OrganizationIntegrationConfigurationDetails specificConfig, - OrganizationIntegrationConfigurationDetails allKeysConfig - ) - { - specificConfig.EventType = EventType.Cipher_Created; - allKeysConfig.EventType = null; - allKeysConfig.OrganizationId = specificConfig.OrganizationId; - allKeysConfig.IntegrationType = specificConfig.IntegrationType; - - var sutProvider = GetSutProvider([specificConfig, allKeysConfig]); - await sutProvider.Sut.RefreshAsync(); - var result = sutProvider.Sut.GetConfigurationDetails( - specificConfig.OrganizationId, - specificConfig.IntegrationType, - EventType.Cipher_Created); - Assert.Equal(2, result.Count); - Assert.Contains(result, r => r.Template == specificConfig.Template); - Assert.Contains(result, r => r.Template == allKeysConfig.Template); - } - - [Theory, BitAutoData] - public async Task GetConfigurationDetails_KeyMissing_ReturnsEmptyList(OrganizationIntegrationConfigurationDetails config) - { - var sutProvider = GetSutProvider([config]); - await sutProvider.Sut.RefreshAsync(); - var result = sutProvider.Sut.GetConfigurationDetails( - Guid.NewGuid(), - config.IntegrationType, - config.EventType ?? EventType.Cipher_Created); - Assert.Empty(result); - } - - - - [Theory, BitAutoData] - public async Task GetConfigurationDetails_ReturnsCachedValue_EvenIfRepositoryChanges(OrganizationIntegrationConfigurationDetails config) - { - var sutProvider = GetSutProvider([config]); - await sutProvider.Sut.RefreshAsync(); - - var newConfig = JsonSerializer.Deserialize(JsonSerializer.Serialize(config)); - Assert.NotNull(newConfig); - newConfig.Template = "Changed"; - sutProvider.GetDependency().GetAllConfigurationDetailsAsync() - .Returns([newConfig]); - - var result = sutProvider.Sut.GetConfigurationDetails( - config.OrganizationId, - config.IntegrationType, - config.EventType ?? EventType.Cipher_Created); - Assert.Single(result); - Assert.NotEqual("Changed", result[0].Template); // should not yet pick up change from repository - - await sutProvider.Sut.RefreshAsync(); // Pick up changes - - result = sutProvider.Sut.GetConfigurationDetails( - config.OrganizationId, - config.IntegrationType, - config.EventType ?? EventType.Cipher_Created); - Assert.Single(result); - Assert.Equal("Changed", result[0].Template); // Should have the new value - } - - [Theory, BitAutoData] - public async Task RefreshAsync_GroupsByCompositeKey(OrganizationIntegrationConfigurationDetails config1) - { - var config2 = JsonSerializer.Deserialize( - JsonSerializer.Serialize(config1))!; - config2.Template = "Another"; - - var sutProvider = GetSutProvider([config1, config2]); - await sutProvider.Sut.RefreshAsync(); - - var results = sutProvider.Sut.GetConfigurationDetails( - config1.OrganizationId, - config1.IntegrationType, - config1.EventType ?? EventType.Cipher_Created); - - Assert.Equal(2, results.Count); - Assert.Contains(results, r => r.Template == config1.Template); - Assert.Contains(results, r => r.Template == config2.Template); - } - - [Theory, BitAutoData] - public async Task RefreshAsync_LogsInformationOnSuccess(OrganizationIntegrationConfigurationDetails config) - { - var sutProvider = GetSutProvider([config]); - await sutProvider.Sut.RefreshAsync(); - - sutProvider.GetDependency>().Received().Log( - LogLevel.Information, - Arg.Any(), - Arg.Is(o => o.ToString()!.Contains("Refreshed successfully")), - null, - Arg.Any>()); - } - - [Fact] - public async Task RefreshAsync_OnException_LogsError() - { - var sutProvider = GetSutProvider([]); - sutProvider.GetDependency().GetAllConfigurationDetailsAsync() - .Throws(new Exception("Database failure")); - await sutProvider.Sut.RefreshAsync(); - - sutProvider.GetDependency>().Received(1).Log( - LogLevel.Error, - Arg.Any(), - Arg.Is(o => o.ToString()!.Contains("Refresh failed")), - Arg.Any(), - Arg.Any>()); - } -} diff --git a/test/Core.Test/AdminConsole/Services/IntegrationHandlerTests.cs b/test/Core.Test/AdminConsole/Services/IntegrationHandlerTests.cs deleted file mode 100644 index aa93567538..0000000000 --- a/test/Core.Test/AdminConsole/Services/IntegrationHandlerTests.cs +++ /dev/null @@ -1,42 +0,0 @@ -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; -using Bit.Core.Enums; -using Bit.Core.Services; -using Xunit; - -namespace Bit.Core.Test.Services; - -public class IntegrationHandlerTests -{ - - [Fact] - public async Task HandleAsync_ConvertsJsonToTypedIntegrationMessage() - { - var sut = new TestIntegrationHandler(); - var expected = new IntegrationMessage() - { - Configuration = new WebhookIntegrationConfigurationDetails(new Uri("https://localhost"), "Bearer", "AUTH-TOKEN"), - MessageId = "TestMessageId", - IntegrationType = IntegrationType.Webhook, - RenderedTemplate = "Template", - DelayUntilDate = null, - RetryCount = 0 - }; - - var result = await sut.HandleAsync(expected.ToJson()); - var typedResult = Assert.IsType>(result.Message); - - Assert.Equal(expected.Configuration, typedResult.Configuration); - Assert.Equal(expected.RenderedTemplate, typedResult.RenderedTemplate); - Assert.Equal(expected.IntegrationType, typedResult.IntegrationType); - } - - private class TestIntegrationHandler : IntegrationHandlerBase - { - public override Task HandleAsync( - IntegrationMessage message) - { - var result = new IntegrationHandlerResult(success: true, message: message); - return Task.FromResult(result); - } - } -} diff --git a/test/Core.Test/AdminConsole/Services/IntegrationTypeTests.cs b/test/Core.Test/AdminConsole/Services/IntegrationTypeTests.cs index 98cf974df8..134aa17129 100644 --- a/test/Core.Test/AdminConsole/Services/IntegrationTypeTests.cs +++ b/test/Core.Test/AdminConsole/Services/IntegrationTypeTests.cs @@ -1,21 +1,10 @@ -using Bit.Core.Enums; +using Bit.Core.Dirt.Enums; using Xunit; namespace Bit.Core.Test.Services; public class IntegrationTypeTests { - [Fact] - public void ToRoutingKey_Slack_Succeeds() - { - Assert.Equal("slack", IntegrationType.Slack.ToRoutingKey()); - } - [Fact] - public void ToRoutingKey_Webhook_Succeeds() - { - Assert.Equal("webhook", IntegrationType.Webhook.ToRoutingKey()); - } - [Fact] public void ToRoutingKey_CloudBillingSync_ThrowsException() { @@ -27,4 +16,34 @@ public class IntegrationTypeTests { Assert.Throws(() => IntegrationType.Scim.ToRoutingKey()); } + + [Fact] + public void ToRoutingKey_Slack_Succeeds() + { + Assert.Equal("slack", IntegrationType.Slack.ToRoutingKey()); + } + + [Fact] + public void ToRoutingKey_Webhook_Succeeds() + { + Assert.Equal("webhook", IntegrationType.Webhook.ToRoutingKey()); + } + + [Fact] + public void ToRoutingKey_Hec_Succeeds() + { + Assert.Equal("hec", IntegrationType.Hec.ToRoutingKey()); + } + + [Fact] + public void ToRoutingKey_Datadog_Succeeds() + { + Assert.Equal("datadog", IntegrationType.Datadog.ToRoutingKey()); + } + + [Fact] + public void ToRoutingKey_Teams_Succeeds() + { + Assert.Equal("teams", IntegrationType.Teams.ToRoutingKey()); + } } diff --git a/test/Core.Test/AdminConsole/Services/OrganizationServiceTests.cs b/test/Core.Test/AdminConsole/Services/OrganizationServiceTests.cs index 33f2e78799..43a33cda31 100644 --- a/test/Core.Test/AdminConsole/Services/OrganizationServiceTests.cs +++ b/test/Core.Test/AdminConsole/Services/OrganizationServiceTests.cs @@ -9,6 +9,7 @@ using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -21,8 +22,8 @@ using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Test.AutoFixture.OrganizationFixtures; using Bit.Core.Test.AutoFixture.OrganizationUserFixtures; +using Bit.Core.Test.Billing.Mocks; using Bit.Core.Tokens; -using Bit.Core.Utilities; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.Fakes; @@ -618,7 +619,7 @@ public class OrganizationServiceTests SetupOrgUserRepositoryCreateManyAsyncMock(organizationUserRepository); SetupOrgUserRepositoryCreateAsyncMock(organizationUserRepository); - sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(StaticStore.GetPlan(organization.PlanType)); + sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType)); await sutProvider.Sut.InviteUsersAsync(organization.Id, savingUser.Id, systemUser: null, invites); @@ -666,7 +667,7 @@ public class OrganizationServiceTests .SendInvitesAsync(Arg.Any()).ThrowsAsync(); sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); await Assert.ThrowsAsync(async () => await sutProvider.Sut.InviteUsersAsync(organization.Id, savingUser.Id, systemUser: null, invites)); @@ -732,7 +733,7 @@ public class OrganizationServiceTests sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscription(organization.Id, seatAdjustment, maxAutoscaleSeats)); @@ -757,7 +758,7 @@ public class OrganizationServiceTests organization.SmSeats = 100; sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); sutProvider.GetDependency() .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id).Returns(new OrganizationSeatCounts { @@ -837,7 +838,7 @@ public class OrganizationServiceTests [BitAutoData(PlanType.EnterpriseMonthly)] public void ValidateSecretsManagerPlan_ThrowsException_WhenNoSecretsManagerSeats(PlanType planType, SutProvider sutProvider) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); var signup = new OrganizationUpgrade { UseSecretsManager = true, @@ -854,7 +855,7 @@ public class OrganizationServiceTests [BitAutoData(PlanType.Free)] public void ValidateSecretsManagerPlan_ThrowsException_WhenSubtractingSeats(PlanType planType, SutProvider sutProvider) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); var signup = new OrganizationUpgrade { UseSecretsManager = true, @@ -871,7 +872,7 @@ public class OrganizationServiceTests PlanType planType, SutProvider sutProvider) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); var signup = new OrganizationUpgrade { UseSecretsManager = true, @@ -890,7 +891,7 @@ public class OrganizationServiceTests [BitAutoData(PlanType.EnterpriseMonthly)] public void ValidateSecretsManagerPlan_ThrowsException_WhenMoreSeatsThanPasswordManagerSeats(PlanType planType, SutProvider sutProvider) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); var signup = new OrganizationUpgrade { UseSecretsManager = true, @@ -912,7 +913,7 @@ public class OrganizationServiceTests PlanType planType, SutProvider sutProvider) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); var signup = new OrganizationUpgrade { UseSecretsManager = true, @@ -930,7 +931,7 @@ public class OrganizationServiceTests PlanType planType, SutProvider sutProvider) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); var signup = new OrganizationUpgrade { UseSecretsManager = true, @@ -952,7 +953,7 @@ public class OrganizationServiceTests PlanType planType, SutProvider sutProvider) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); var signup = new OrganizationUpgrade { UseSecretsManager = true, @@ -1142,7 +1143,7 @@ public class OrganizationServiceTests .GetByIdentifierAsync(Arg.Is(id => id == organization.Identifier)); await stripeAdapter .Received(1) - .CustomerUpdateAsync( + .UpdateCustomerAsync( Arg.Is(id => id == organization.GatewayCustomerId), Arg.Is(options => options.Email == requestOptionsReturned.Email && options.Description == requestOptionsReturned.Description @@ -1182,7 +1183,7 @@ public class OrganizationServiceTests .GetByIdentifierAsync(Arg.Is(id => id == organization.Identifier)); await stripeAdapter .DidNotReceiveWithAnyArgs() - .CustomerUpdateAsync(Arg.Any(), Arg.Any()); + .UpdateCustomerAsync(Arg.Any(), Arg.Any()); await organizationRepository .Received(1) .ReplaceAsync(Arg.Is(org => org == organization)); diff --git a/test/Core.Test/AdminConsole/Services/SlackIntegrationHandlerTests.cs b/test/Core.Test/AdminConsole/Services/SlackIntegrationHandlerTests.cs deleted file mode 100644 index dab6c41b61..0000000000 --- a/test/Core.Test/AdminConsole/Services/SlackIntegrationHandlerTests.cs +++ /dev/null @@ -1,42 +0,0 @@ -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; -using Bit.Core.Services; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using Bit.Test.Common.Helpers; -using NSubstitute; -using Xunit; - -namespace Bit.Core.Test.Services; - -[SutProviderCustomize] -public class SlackIntegrationHandlerTests -{ - private readonly ISlackService _slackService = Substitute.For(); - private readonly string _channelId = "C12345"; - private readonly string _token = "xoxb-test-token"; - - private SutProvider GetSutProvider() - { - return new SutProvider() - .SetDependency(_slackService) - .Create(); - } - - [Theory, BitAutoData] - public async Task HandleAsync_SuccessfulRequest_ReturnsSuccess(IntegrationMessage message) - { - var sutProvider = GetSutProvider(); - message.Configuration = new SlackIntegrationConfigurationDetails(_channelId, _token); - - var result = await sutProvider.Sut.HandleAsync(message); - - Assert.True(result.Success); - Assert.Equal(result.Message, message); - - await sutProvider.GetDependency().Received(1).SendSlackMessageByChannelIdAsync( - Arg.Is(AssertHelper.AssertPropertyEqual(_token)), - Arg.Is(AssertHelper.AssertPropertyEqual(message.RenderedTemplate)), - Arg.Is(AssertHelper.AssertPropertyEqual(_channelId)) - ); - } -} diff --git a/test/Core.Test/AdminConsole/Utilities/IntegrationTemplateProcessorTests.cs b/test/Core.Test/AdminConsole/Utilities/IntegrationTemplateProcessorTests.cs index d9df9486b6..aee4af346c 100644 --- a/test/Core.Test/AdminConsole/Utilities/IntegrationTemplateProcessorTests.cs +++ b/test/Core.Test/AdminConsole/Utilities/IntegrationTemplateProcessorTests.cs @@ -83,6 +83,7 @@ public class IntegrationTemplateProcessorTests [Theory] [InlineData("User name is #UserName#")] [InlineData("Email: #UserEmail#")] + [InlineData("User type = #UserType#")] public void TemplateRequiresUser_ContainingKeys_ReturnsTrue(string template) { var result = IntegrationTemplateProcessor.TemplateRequiresUser(template); @@ -102,6 +103,7 @@ public class IntegrationTemplateProcessorTests [Theory] [InlineData("Acting user is #ActingUserName#")] [InlineData("Acting user's email is #ActingUserEmail#")] + [InlineData("Acting user's type is #ActingUserType#")] public void TemplateRequiresActingUser_ContainingKeys_ReturnsTrue(string template) { var result = IntegrationTemplateProcessor.TemplateRequiresActingUser(template); @@ -118,6 +120,25 @@ public class IntegrationTemplateProcessorTests Assert.False(result); } + [Theory] + [InlineData("Group name is #GroupName#!")] + [InlineData("Group: #GroupName#")] + public void TemplateRequiresGroup_ContainingKeys_ReturnsTrue(string template) + { + var result = IntegrationTemplateProcessor.TemplateRequiresGroup(template); + Assert.True(result); + } + + [Theory] + [InlineData("#GroupId#")] // This is on the base class, not fetched, so should be false + [InlineData("No Group Tokens")] + [InlineData("")] + public void TemplateRequiresGroup_EmptyInputOrNoMatchingKeys_ReturnsFalse(string template) + { + var result = IntegrationTemplateProcessor.TemplateRequiresGroup(template); + Assert.False(result); + } + [Theory] [InlineData("Organization: #OrganizationName#")] [InlineData("Welcome to #OrganizationName#")] diff --git a/test/Core.Test/AdminConsole/Utilities/PolicyDataValidatorTests.cs b/test/Core.Test/AdminConsole/Utilities/PolicyDataValidatorTests.cs new file mode 100644 index 0000000000..43725d23e0 --- /dev/null +++ b/test/Core.Test/AdminConsole/Utilities/PolicyDataValidatorTests.cs @@ -0,0 +1,59 @@ +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.Utilities; +using Bit.Core.Exceptions; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.Utilities; + +public class PolicyDataValidatorTests +{ + [Fact] + public void ValidateAndSerialize_NullData_ReturnsNull() + { + var result = PolicyDataValidator.ValidateAndSerialize(null, PolicyType.MasterPassword); + + Assert.Null(result); + } + + [Fact] + public void ValidateAndSerialize_ValidData_ReturnsSerializedJson() + { + var data = new Dictionary { { "minLength", 12 } }; + + var result = PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword); + + Assert.NotNull(result); + Assert.Contains("\"minLength\":12", result); + } + + [Fact] + public void ValidateAndSerialize_InvalidDataType_ThrowsBadRequestException() + { + var data = new Dictionary { { "minLength", "not a number" } }; + + var exception = Assert.Throws(() => + PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword)); + + Assert.Contains("Invalid data for MasterPassword policy", exception.Message); + Assert.Contains("minLength", exception.Message); + } + + [Fact] + public void ValidateAndDeserializeMetadata_NullMetadata_ReturnsEmptyMetadataModel() + { + var result = PolicyDataValidator.ValidateAndDeserializeMetadata(null, PolicyType.SingleOrg); + + Assert.IsType(result); + } + + [Fact] + public void ValidateAndDeserializeMetadata_ValidMetadata_ReturnsModel() + { + var metadata = new Dictionary { { "defaultUserCollectionName", "collection name" } }; + + var result = PolicyDataValidator.ValidateAndDeserializeMetadata(metadata, PolicyType.OrganizationDataOwnership); + + Assert.IsType(result); + } +} diff --git a/test/Core.Test/Auth/Attributes/MarketingInitiativeValidationAttributeTests.cs b/test/Core.Test/Auth/Attributes/MarketingInitiativeValidationAttributeTests.cs new file mode 100644 index 0000000000..2b9b5cf194 --- /dev/null +++ b/test/Core.Test/Auth/Attributes/MarketingInitiativeValidationAttributeTests.cs @@ -0,0 +1,70 @@ +using Bit.Core.Auth.Attributes; +using Bit.Core.Auth.Models.Api.Request.Accounts; +using Xunit; + +namespace Bit.Core.Test.Auth.Attributes; + +public class MarketingInitiativeValidationAttributeTests +{ + [Fact] + public void IsValid_NullValue_ReturnsTrue() + { + var sut = new MarketingInitiativeValidationAttribute(); + + var actual = sut.IsValid(null); + + Assert.True(actual); + } + + [Theory] + [InlineData(MarketingInitiativeConstants.Premium)] + public void IsValid_AcceptedValue_ReturnsTrue(string value) + { + var sut = new MarketingInitiativeValidationAttribute(); + + var actual = sut.IsValid(value); + + Assert.True(actual); + } + + [Theory] + [InlineData("invalid")] + [InlineData("")] + [InlineData("Premium")] // case sensitive - capitalized + [InlineData("PREMIUM")] // case sensitive - uppercase + [InlineData("premium ")] // trailing space + [InlineData(" premium")] // leading space + public void IsValid_InvalidStringValue_ReturnsFalse(string value) + { + var sut = new MarketingInitiativeValidationAttribute(); + + var actual = sut.IsValid(value); + + Assert.False(actual); + } + + [Theory] + [InlineData(123)] // integer + [InlineData(true)] // boolean + [InlineData(45.67)] // double + public void IsValid_NonStringValue_ReturnsFalse(object value) + { + var sut = new MarketingInitiativeValidationAttribute(); + + var actual = sut.IsValid(value); + + Assert.False(actual); + } + + [Fact] + public void ErrorMessage_ContainsAcceptedValues() + { + var sut = new MarketingInitiativeValidationAttribute(); + + var errorMessage = sut.ErrorMessage; + + Assert.NotNull(errorMessage); + Assert.Contains("premium", errorMessage); + Assert.Contains("Marketing initiative type must be one of:", errorMessage); + } +} diff --git a/test/Core.Test/Auth/Entities/AuthRequestTests.cs b/test/Core.Test/Auth/Entities/AuthRequestTests.cs new file mode 100644 index 0000000000..9efeb1ded1 --- /dev/null +++ b/test/Core.Test/Auth/Entities/AuthRequestTests.cs @@ -0,0 +1,224 @@ +using Bit.Core.Auth.Entities; +using Bit.Core.Auth.Enums; +using Xunit; + +namespace Bit.Core.Test.Auth.Entities; + +public class AuthRequestTests +{ + [Fact] + public void IsValidForAuthentication_WithValidRequest_ReturnsTrue() + { + // Arrange + var userId = Guid.NewGuid(); + var accessCode = "test-access-code"; + var authRequest = new AuthRequest + { + UserId = userId, + Type = AuthRequestType.AuthenticateAndUnlock, + ResponseDate = DateTime.UtcNow, + Approved = true, + CreationDate = DateTime.UtcNow, + AuthenticationDate = null, + AccessCode = accessCode + }; + + // Act + var result = authRequest.IsValidForAuthentication(userId, accessCode); + + // Assert + Assert.True(result); + } + + [Fact] + public void IsValidForAuthentication_WithWrongUserId_ReturnsFalse() + { + // Arrange + var userId = Guid.NewGuid(); + var differentUserId = Guid.NewGuid(); + var accessCode = "test-access-code"; + var authRequest = new AuthRequest + { + UserId = userId, + Type = AuthRequestType.AuthenticateAndUnlock, + ResponseDate = DateTime.UtcNow, + Approved = true, + CreationDate = DateTime.UtcNow, + AuthenticationDate = null, + AccessCode = accessCode + }; + + // Act + var result = authRequest.IsValidForAuthentication(differentUserId, accessCode); + + // Assert + Assert.False(result, "Auth request should not validate for a different user"); + } + + [Fact] + public void IsValidForAuthentication_WithWrongAccessCode_ReturnsFalse() + { + // Arrange + var userId = Guid.NewGuid(); + var authRequest = new AuthRequest + { + UserId = userId, + Type = AuthRequestType.AuthenticateAndUnlock, + ResponseDate = DateTime.UtcNow, + Approved = true, + CreationDate = DateTime.UtcNow, + AuthenticationDate = null, + AccessCode = "correct-code" + }; + + // Act + var result = authRequest.IsValidForAuthentication(userId, "wrong-code"); + + // Assert + Assert.False(result); + } + + [Fact] + public void IsValidForAuthentication_WithoutResponseDate_ReturnsFalse() + { + // Arrange + var userId = Guid.NewGuid(); + var accessCode = "test-access-code"; + var authRequest = new AuthRequest + { + UserId = userId, + Type = AuthRequestType.AuthenticateAndUnlock, + ResponseDate = null, // Not responded to + Approved = true, + CreationDate = DateTime.UtcNow, + AuthenticationDate = null, + AccessCode = accessCode + }; + + // Act + var result = authRequest.IsValidForAuthentication(userId, accessCode); + + // Assert + Assert.False(result, "Unanswered auth requests should not be valid"); + } + + [Fact] + public void IsValidForAuthentication_WithApprovedFalse_ReturnsFalse() + { + // Arrange + var userId = Guid.NewGuid(); + var accessCode = "test-access-code"; + var authRequest = new AuthRequest + { + UserId = userId, + Type = AuthRequestType.AuthenticateAndUnlock, + ResponseDate = DateTime.UtcNow, + Approved = false, // Denied + CreationDate = DateTime.UtcNow, + AuthenticationDate = null, + AccessCode = accessCode + }; + + // Act + var result = authRequest.IsValidForAuthentication(userId, accessCode); + + // Assert + Assert.False(result, "Denied auth requests should not be valid"); + } + + [Fact] + public void IsValidForAuthentication_WithApprovedNull_ReturnsFalse() + { + // Arrange + var userId = Guid.NewGuid(); + var accessCode = "test-access-code"; + var authRequest = new AuthRequest + { + UserId = userId, + Type = AuthRequestType.AuthenticateAndUnlock, + ResponseDate = DateTime.UtcNow, + Approved = null, // Pending + CreationDate = DateTime.UtcNow, + AuthenticationDate = null, + AccessCode = accessCode + }; + + // Act + var result = authRequest.IsValidForAuthentication(userId, accessCode); + + // Assert + Assert.False(result, "Pending auth requests should not be valid"); + } + + [Fact] + public void IsValidForAuthentication_WithExpiredRequest_ReturnsFalse() + { + // Arrange + var userId = Guid.NewGuid(); + var accessCode = "test-access-code"; + var authRequest = new AuthRequest + { + UserId = userId, + Type = AuthRequestType.AuthenticateAndUnlock, + ResponseDate = DateTime.UtcNow, + Approved = true, + CreationDate = DateTime.UtcNow.AddMinutes(-20), // Expired (15 min timeout) + AuthenticationDate = null, + AccessCode = accessCode + }; + + // Act + var result = authRequest.IsValidForAuthentication(userId, accessCode); + + // Assert + Assert.False(result, "Expired auth requests should not be valid"); + } + + [Fact] + public void IsValidForAuthentication_WithWrongType_ReturnsFalse() + { + // Arrange + var userId = Guid.NewGuid(); + var accessCode = "test-access-code"; + var authRequest = new AuthRequest + { + UserId = userId, + Type = AuthRequestType.Unlock, // Wrong type + ResponseDate = DateTime.UtcNow, + Approved = true, + CreationDate = DateTime.UtcNow, + AuthenticationDate = null, + AccessCode = accessCode + }; + + // Act + var result = authRequest.IsValidForAuthentication(userId, accessCode); + + // Assert + Assert.False(result, "Only AuthenticateAndUnlock type should be valid"); + } + + [Fact] + public void IsValidForAuthentication_WithAlreadyUsed_ReturnsFalse() + { + // Arrange + var userId = Guid.NewGuid(); + var accessCode = "test-access-code"; + var authRequest = new AuthRequest + { + UserId = userId, + Type = AuthRequestType.AuthenticateAndUnlock, + ResponseDate = DateTime.UtcNow, + Approved = true, + CreationDate = DateTime.UtcNow, + AuthenticationDate = DateTime.UtcNow, // Already used + AccessCode = accessCode + }; + + // Act + var result = authRequest.IsValidForAuthentication(userId, accessCode); + + // Assert + Assert.False(result, "Auth requests should only be valid for one-time use"); + } +} diff --git a/test/Core.Test/Auth/Models/Api/Request/Accounts/MarketingInitiativeConstantsSnapshotTests.cs b/test/Core.Test/Auth/Models/Api/Request/Accounts/MarketingInitiativeConstantsSnapshotTests.cs new file mode 100644 index 0000000000..b78e96e91e --- /dev/null +++ b/test/Core.Test/Auth/Models/Api/Request/Accounts/MarketingInitiativeConstantsSnapshotTests.cs @@ -0,0 +1,18 @@ +using Bit.Core.Auth.Models.Api.Request.Accounts; +using Xunit; + +namespace Bit.Core.Test.Auth.Models.Api.Request.Accounts; + +/// +/// Snapshot tests to ensure the string constants in do not change unintentionally. +/// If you intentionally change any of these values, please update the tests to reflect the new expected values. +/// +public class MarketingInitiativeConstantsSnapshotTests +{ + [Fact] + public void MarketingInitiativeConstants_HaveCorrectValues() + { + // Assert + Assert.Equal("premium", MarketingInitiativeConstants.Premium); + } +} diff --git a/test/Core.Test/Auth/Services/AuthRequestServiceTests.cs b/test/Core.Test/Auth/Services/AuthRequestServiceTests.cs index 5da0e78422..9c95930c18 100644 --- a/test/Core.Test/Auth/Services/AuthRequestServiceTests.cs +++ b/test/Core.Test/Auth/Services/AuthRequestServiceTests.cs @@ -467,10 +467,9 @@ public class AuthRequestServiceTests Arg.Any(), Arg.Any()); - var expectedLogMessage = "There are no admin emails to send to."; sutProvider.GetDependency>() .Received(1) - .LogWarning(expectedLogMessage); + .LogWarning("There are no admin emails to send to."); } /// diff --git a/test/Core.Test/Auth/Services/SsoConfigServiceTests.cs b/test/Core.Test/Auth/Services/SsoConfigServiceTests.cs index 7beb772b95..2f4d00a7fa 100644 --- a/test/Core.Test/Auth/Services/SsoConfigServiceTests.cs +++ b/test/Core.Test/Auth/Services/SsoConfigServiceTests.cs @@ -1,8 +1,9 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; -using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.Entities; using Bit.Core.Auth.Enums; @@ -339,29 +340,75 @@ public class SsoConfigServiceTests await sutProvider.Sut.SaveAsync(ssoConfig, organization); - await sutProvider.GetDependency().Received(1) + await sutProvider.GetDependency().Received(1) .SaveAsync( - Arg.Is(t => t.Type == PolicyType.SingleOrg && - t.OrganizationId == organization.Id && - t.Enabled) + Arg.Is(t => t.PolicyUpdate.Type == PolicyType.SingleOrg && + t.PolicyUpdate.OrganizationId == organization.Id && + t.PolicyUpdate.Enabled) ); - await sutProvider.GetDependency().Received(1) + await sutProvider.GetDependency().Received(1) .SaveAsync( - Arg.Is(t => t.Type == PolicyType.ResetPassword && - t.GetDataModel().AutoEnrollEnabled && - t.OrganizationId == organization.Id && - t.Enabled) + Arg.Is(t => t.PolicyUpdate.Type == PolicyType.ResetPassword && + t.PolicyUpdate.GetDataModel().AutoEnrollEnabled && + t.PolicyUpdate.OrganizationId == organization.Id && + t.PolicyUpdate.Enabled) ); - await sutProvider.GetDependency().Received(1) + await sutProvider.GetDependency().Received(1) .SaveAsync( - Arg.Is(t => t.Type == PolicyType.RequireSso && - t.OrganizationId == organization.Id && - t.Enabled) + Arg.Is(t => t.PolicyUpdate.Type == PolicyType.RequireSso && + t.PolicyUpdate.OrganizationId == organization.Id && + t.PolicyUpdate.Enabled) ); await sutProvider.GetDependency().ReceivedWithAnyArgs() .UpsertAsync(default); } + + [Theory, BitAutoData] + public async Task SaveAsync_Tde_UsesVNextSavePolicyCommand( + SutProvider sutProvider, Organization organization) + { + var ssoConfig = new SsoConfig + { + Id = default, + Data = new SsoConfigurationData + { + MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption, + }.Serialize(), + Enabled = true, + OrganizationId = organization.Id, + }; + + await sutProvider.Sut.SaveAsync(ssoConfig, organization); + + await sutProvider.GetDependency() + .Received(1) + .SaveAsync(Arg.Is(m => + m.PolicyUpdate.Type == PolicyType.SingleOrg && + m.PolicyUpdate.OrganizationId == organization.Id && + m.PolicyUpdate.Enabled && + m.PerformedBy is SystemUser)); + + await sutProvider.GetDependency() + .Received(1) + .SaveAsync(Arg.Is(m => + m.PolicyUpdate.Type == PolicyType.ResetPassword && + m.PolicyUpdate.GetDataModel().AutoEnrollEnabled && + m.PolicyUpdate.OrganizationId == organization.Id && + m.PolicyUpdate.Enabled && + m.PerformedBy is SystemUser)); + + await sutProvider.GetDependency() + .Received(1) + .SaveAsync(Arg.Is(m => + m.PolicyUpdate.Type == PolicyType.RequireSso && + m.PolicyUpdate.OrganizationId == organization.Id && + m.PolicyUpdate.Enabled && + m.PerformedBy is SystemUser)); + + await sutProvider.GetDependency().ReceivedWithAnyArgs() + .UpsertAsync(default); + } } diff --git a/test/Core.Test/Auth/UserFeatures/Registration/RegisterUserCommandTests.cs b/test/Core.Test/Auth/UserFeatures/Registration/RegisterUserCommandTests.cs index b19ae47cfc..ae669398c5 100644 --- a/test/Core.Test/Auth/UserFeatures/Registration/RegisterUserCommandTests.cs +++ b/test/Core.Test/Auth/UserFeatures/Registration/RegisterUserCommandTests.cs @@ -7,6 +7,7 @@ using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models; using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Auth.UserFeatures.Registration.Implementations; +using Bit.Core.Billing.Enums; using Bit.Core.Entities; using Bit.Core.Exceptions; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; @@ -37,6 +38,12 @@ public class RegisterUserCommandTests public async Task RegisterUser_Succeeds(SutProvider sutProvider, User user) { // Arrange + user.Email = $"test+{Guid.NewGuid()}@example.com"; + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + sutProvider.GetDependency() .CreateUserAsync(user) .Returns(IdentityResult.Success); @@ -61,6 +68,12 @@ public class RegisterUserCommandTests public async Task RegisterUser_WhenCreateUserFails_ReturnsIdentityResultFailed(SutProvider sutProvider, User user) { // Arrange + user.Email = $"test+{Guid.NewGuid()}@example.com"; + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + sutProvider.GetDependency() .CreateUserAsync(user) .Returns(IdentityResult.Failed()); @@ -80,6 +93,120 @@ public class RegisterUserCommandTests .SendWelcomeEmailAsync(Arg.Any()); } + // ----------------------------------------------------------------------------------------------- + // RegisterSSOAutoProvisionedUserAsync tests + // ----------------------------------------------------------------------------------------------- + [Theory, BitAutoData] + public async Task RegisterSSOAutoProvisionedUserAsync_Success( + User user, + Organization organization, + SutProvider sutProvider) + { + // Arrange + user.Id = Guid.NewGuid(); + organization.Id = Guid.NewGuid(); + organization.Name = "Test Organization"; + + sutProvider.GetDependency() + .CreateUserAsync(user) + .Returns(IdentityResult.Success); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates) + .Returns(true); + + // Act + var result = await sutProvider.Sut.RegisterSSOAutoProvisionedUserAsync(user, organization); + + // Assert + Assert.True(result.Succeeded); + await sutProvider.GetDependency() + .Received(1) + .CreateUserAsync(user); + } + + [Theory, BitAutoData] + public async Task RegisterSSOAutoProvisionedUserAsync_UserRegistrationFails_ReturnsFailedResult( + User user, + Organization organization, + SutProvider sutProvider) + { + // Arrange + var expectedError = new IdentityError(); + sutProvider.GetDependency() + .CreateUserAsync(user) + .Returns(IdentityResult.Failed(expectedError)); + + // Act + var result = await sutProvider.Sut.RegisterSSOAutoProvisionedUserAsync(user, organization); + + // Assert + Assert.False(result.Succeeded); + Assert.Contains(expectedError, result.Errors); + await sutProvider.GetDependency() + .DidNotReceive() + .SendOrganizationUserWelcomeEmailAsync(Arg.Any(), Arg.Any()); + } + + [Theory] + [BitAutoData(PlanType.EnterpriseAnnually)] + [BitAutoData(PlanType.EnterpriseMonthly)] + [BitAutoData(PlanType.TeamsAnnually)] + public async Task RegisterSSOAutoProvisionedUserAsync_EnterpriseOrg_SendsOrganizationWelcomeEmail( + PlanType planType, + User user, + Organization organization, + SutProvider sutProvider) + { + // Arrange + organization.PlanType = planType; + organization.Name = "Enterprise Org"; + + sutProvider.GetDependency() + .CreateUserAsync(user) + .Returns(IdentityResult.Success); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates) + .Returns(true); + + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns((OrganizationUser)null); + + // Act + await sutProvider.Sut.RegisterSSOAutoProvisionedUserAsync(user, organization); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .SendOrganizationUserWelcomeEmailAsync(user, organization.Name); + } + + [Theory, BitAutoData] + public async Task RegisterSSOAutoProvisionedUserAsync_FeatureFlagDisabled_SendsLegacyWelcomeEmail( + User user, + Organization organization, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .CreateUserAsync(user) + .Returns(IdentityResult.Success); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates) + .Returns(false); + + // Act + await sutProvider.Sut.RegisterSSOAutoProvisionedUserAsync(user, organization); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .SendWelcomeEmailAsync(user); + } + // ----------------------------------------------------------------------------------------------- // RegisterUserWithOrganizationInviteToken tests // ----------------------------------------------------------------------------------------------- @@ -301,6 +428,138 @@ public class RegisterUserCommandTests Assert.Equal(expectedErrorMessage, exception.Message); } + [Theory] + [BitAutoData] + public async Task RegisterUserViaOrganizationInviteToken_BlockedDomainFromDifferentOrg_ThrowsBadRequestException( + SutProvider sutProvider, User user, string masterPasswordHash, OrganizationUser orgUser, string orgInviteToken, Guid orgUserId) + { + // Arrange + user.Email = "user@blocked-domain.com"; + orgUser.Email = user.Email; + orgUser.Id = orgUserId; + var blockingOrganizationId = Guid.NewGuid(); // Different org that has the domain blocked + orgUser.OrganizationId = Guid.NewGuid(); // The org they're trying to join + + var orgInviteTokenable = new OrgUserInviteTokenable(orgUser); + + sutProvider.GetDependency>() + .TryUnprotect(orgInviteToken, out Arg.Any()) + .Returns(callInfo => + { + callInfo[1] = orgInviteTokenable; + return true; + }); + + sutProvider.GetDependency() + .GetByIdAsync(orgUserId) + .Returns(orgUser); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + // Mock the new overload that excludes the organization - it should return true (domain IS blocked by another org) + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("blocked-domain.com", orgUser.OrganizationId) + .Returns(true); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RegisterUserViaOrganizationInviteToken(user, masterPasswordHash, orgInviteToken, orgUserId)); + Assert.Equal("This email address is claimed by an organization using Bitwarden.", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task RegisterUserViaOrganizationInviteToken_BlockedDomainFromSameOrg_Succeeds( + SutProvider sutProvider, User user, string masterPasswordHash, OrganizationUser orgUser, string orgInviteToken, Guid orgUserId) + { + // Arrange + user.Email = "user@company-domain.com"; + user.ReferenceData = null; + orgUser.Email = user.Email; + orgUser.Id = orgUserId; + // The organization owns the domain and is trying to invite the user + orgUser.OrganizationId = Guid.NewGuid(); + + var orgInviteTokenable = new OrgUserInviteTokenable(orgUser); + + sutProvider.GetDependency>() + .TryUnprotect(orgInviteToken, out Arg.Any()) + .Returns(callInfo => + { + callInfo[1] = orgInviteTokenable; + return true; + }); + + sutProvider.GetDependency() + .GetByIdAsync(orgUserId) + .Returns(orgUser); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + // Mock the new overload - it should return false (domain is NOT blocked by OTHER orgs) + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("company-domain.com", orgUser.OrganizationId) + .Returns(false); + + sutProvider.GetDependency() + .CreateUserAsync(user, masterPasswordHash) + .Returns(IdentityResult.Success); + + // Act + var result = await sutProvider.Sut.RegisterUserViaOrganizationInviteToken(user, masterPasswordHash, orgInviteToken, orgUserId); + + // Assert + Assert.True(result.Succeeded); + await sutProvider.GetDependency() + .Received(1) + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("company-domain.com", orgUser.OrganizationId); + } + + [Theory] + [BitAutoData] + public async Task RegisterUserViaOrganizationInviteToken_WithValidTokenButNullOrgUser_ThrowsBadRequestException( + SutProvider sutProvider, User user, string masterPasswordHash, OrganizationUser orgUser, string orgInviteToken, Guid orgUserId) + { + // Arrange + user.Email = "user@example.com"; + orgUser.Email = user.Email; + orgUser.Id = orgUserId; + + var orgInviteTokenable = new OrgUserInviteTokenable(orgUser); + + sutProvider.GetDependency>() + .TryUnprotect(orgInviteToken, out Arg.Any()) + .Returns(callInfo => + { + callInfo[1] = orgInviteTokenable; + return true; + }); + + // Mock GetByIdAsync to return null - simulating a deleted or non-existent organization user + sutProvider.GetDependency() + .GetByIdAsync(orgUserId) + .Returns((OrganizationUser)null); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RegisterUserViaOrganizationInviteToken(user, masterPasswordHash, orgInviteToken, orgUserId)); + Assert.Equal("Invalid organization user invitation.", exception.Message); + + // Verify that GetByIdAsync was called + await sutProvider.GetDependency() + .Received(1) + .GetByIdAsync(orgUserId); + + // Verify that user creation was never attempted + await sutProvider.GetDependency() + .DidNotReceive() + .CreateUserAsync(Arg.Any(), Arg.Any()); + } + // ----------------------------------------------------------------------------------------------- // RegisterUserViaEmailVerificationToken tests // ----------------------------------------------------------------------------------------------- @@ -310,6 +569,12 @@ public class RegisterUserCommandTests public async Task RegisterUserViaEmailVerificationToken_Succeeds(SutProvider sutProvider, User user, string masterPasswordHash, string emailVerificationToken, bool receiveMarketingMaterials) { // Arrange + user.Email = $"test+{Guid.NewGuid()}@example.com"; + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + sutProvider.GetDependency>() .TryUnprotect(emailVerificationToken, out Arg.Any()) .Returns(callInfo => @@ -342,6 +607,12 @@ public class RegisterUserCommandTests public async Task RegisterUserViaEmailVerificationToken_InvalidToken_ThrowsBadRequestException(SutProvider sutProvider, User user, string masterPasswordHash, string emailVerificationToken, bool receiveMarketingMaterials) { // Arrange + user.Email = $"test+{Guid.NewGuid()}@example.com"; + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + sutProvider.GetDependency>() .TryUnprotect(emailVerificationToken, out Arg.Any()) .Returns(callInfo => @@ -380,6 +651,12 @@ public class RegisterUserCommandTests string orgSponsoredFreeFamilyPlanInviteToken) { // Arrange + user.Email = $"test+{Guid.NewGuid()}@example.com"; + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + sutProvider.GetDependency() .ValidateRedemptionTokenAsync(orgSponsoredFreeFamilyPlanInviteToken, user.Email) .Returns((true, new OrganizationSponsorship())); @@ -409,6 +686,12 @@ public class RegisterUserCommandTests string masterPasswordHash, string orgSponsoredFreeFamilyPlanInviteToken) { // Arrange + user.Email = $"test+{Guid.NewGuid()}@example.com"; + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + sutProvider.GetDependency() .ValidateRedemptionTokenAsync(orgSponsoredFreeFamilyPlanInviteToken, user.Email) .Returns((false, new OrganizationSponsorship())); @@ -446,9 +729,14 @@ public class RegisterUserCommandTests EmergencyAccess emergencyAccess, string acceptEmergencyAccessInviteToken, Guid acceptEmergencyAccessId) { // Arrange + user.Email = $"test+{Guid.NewGuid()}@example.com"; emergencyAccess.Email = user.Email; emergencyAccess.Id = acceptEmergencyAccessId; + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + sutProvider.GetDependency>() .TryUnprotect(acceptEmergencyAccessInviteToken, out Arg.Any()) .Returns(callInfo => @@ -482,9 +770,14 @@ public class RegisterUserCommandTests string masterPasswordHash, EmergencyAccess emergencyAccess, string acceptEmergencyAccessInviteToken, Guid acceptEmergencyAccessId) { // Arrange + user.Email = $"test+{Guid.NewGuid()}@example.com"; emergencyAccess.Email = "wrong@email.com"; emergencyAccess.Id = acceptEmergencyAccessId; + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + sutProvider.GetDependency>() .TryUnprotect(acceptEmergencyAccessInviteToken, out Arg.Any()) .Returns(callInfo => @@ -525,6 +818,8 @@ public class RegisterUserCommandTests User user, string masterPasswordHash, Guid providerUserId) { // Arrange + user.Email = $"test+{Guid.NewGuid()}@example.com"; + // Start with plaintext var nowMillis = CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow); var decryptedProviderInviteToken = $"ProviderUserInvite {providerUserId} {user.Email} {nowMillis}"; @@ -547,6 +842,10 @@ public class RegisterUserCommandTests sutProvider.GetDependency() .OrganizationInviteExpirationHours.Returns(120); // 5 days + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + sutProvider.GetDependency() .CreateUserAsync(user, masterPasswordHash) .Returns(IdentityResult.Success); @@ -576,6 +875,8 @@ public class RegisterUserCommandTests User user, string masterPasswordHash, Guid providerUserId) { // Arrange + user.Email = $"test+{Guid.NewGuid()}@example.com"; + // Start with plaintext var nowMillis = CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow); var decryptedProviderInviteToken = $"ProviderUserInvite {providerUserId} {user.Email} {nowMillis}"; @@ -598,6 +899,10 @@ public class RegisterUserCommandTests sutProvider.GetDependency() .OrganizationInviteExpirationHours.Returns(120); // 5 days + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + // Using sutProvider in the parameters of the function means that the constructor has already run for the // command so we have to recreate it in order for our mock overrides to be used. sutProvider.Create(); @@ -646,5 +951,521 @@ public class RegisterUserCommandTests Assert.Equal("Open registration has been disabled by the system administrator.", result.Message); } + // ----------------------------------------------------------------------------------------------- + // Domain blocking tests (BlockClaimedDomainAccountCreation policy) + // ----------------------------------------------------------------------------------------------- + [Theory] + [BitAutoData] + public async Task RegisterUser_BlockedDomain_ThrowsBadRequestException( + SutProvider sutProvider, User user) + { + // Arrange + user.Email = "user@blocked-domain.com"; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("blocked-domain.com") + .Returns(true); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RegisterUser(user)); + Assert.Equal("This email address is claimed by an organization using Bitwarden.", exception.Message); + + // Verify user creation was never attempted + await sutProvider.GetDependency() + .DidNotReceive() + .CreateUserAsync(Arg.Any()); + } + + [Theory] + [BitAutoData] + public async Task RegisterUser_AllowedDomain_Succeeds( + SutProvider sutProvider, User user) + { + // Arrange + user.Email = "user@allowed-domain.com"; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("allowed-domain.com") + .Returns(false); + + sutProvider.GetDependency() + .CreateUserAsync(user) + .Returns(IdentityResult.Success); + + // Act + var result = await sutProvider.Sut.RegisterUser(user); + + // Assert + Assert.True(result.Succeeded); + await sutProvider.GetDependency() + .Received(1) + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("allowed-domain.com"); + } + + // SendWelcomeEmail tests + // ----------------------------------------------------------------------------------------------- + [Theory] + [BitAutoData(PlanType.FamiliesAnnually)] + [BitAutoData(PlanType.FamiliesAnnually2019)] + [BitAutoData(PlanType.FamiliesAnnually2025)] + [BitAutoData(PlanType.Free)] + public async Task SendWelcomeEmail_FamilyOrg_SendsFamilyWelcomeEmail( + PlanType planType, + User user, + Organization organization, + SutProvider sutProvider) + { + // Arrange + organization.PlanType = planType; + organization.Name = "Family Org"; + + sutProvider.GetDependency() + .CreateUserAsync(user) + .Returns(IdentityResult.Success); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates) + .Returns(true); + + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns((OrganizationUser)null); + + // Act + await sutProvider.Sut.RegisterSSOAutoProvisionedUserAsync(user, organization); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .SendFreeOrgOrFamilyOrgUserWelcomeEmailAsync(user, organization.Name); + } + + [Theory] + [BitAutoData] + public async Task RegisterUserViaEmailVerificationToken_BlockedDomain_ThrowsBadRequestException( + SutProvider sutProvider, User user, string masterPasswordHash, + string emailVerificationToken, bool receiveMarketingMaterials) + { + // Arrange + user.Email = "user@blocked-domain.com"; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("blocked-domain.com") + .Returns(true); + + sutProvider.GetDependency>() + .TryUnprotect(emailVerificationToken, out Arg.Any()) + .Returns(callInfo => + { + callInfo[1] = new RegistrationEmailVerificationTokenable(user.Email, user.Name, receiveMarketingMaterials); + return true; + }); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RegisterUserViaEmailVerificationToken(user, masterPasswordHash, emailVerificationToken)); + Assert.Equal("This email address is claimed by an organization using Bitwarden.", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task RegisterUserViaOrganizationSponsoredFreeFamilyPlanInviteToken_BlockedDomain_ThrowsBadRequestException( + SutProvider sutProvider, User user, string masterPasswordHash, + string orgSponsoredFreeFamilyPlanInviteToken) + { + // Arrange + user.Email = "user@blocked-domain.com"; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("blocked-domain.com") + .Returns(true); + + sutProvider.GetDependency() + .ValidateRedemptionTokenAsync(orgSponsoredFreeFamilyPlanInviteToken, user.Email) + .Returns((true, new OrganizationSponsorship())); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RegisterUserViaOrganizationSponsoredFreeFamilyPlanInviteToken(user, masterPasswordHash, orgSponsoredFreeFamilyPlanInviteToken)); + Assert.Equal("This email address is claimed by an organization using Bitwarden.", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task RegisterUserViaAcceptEmergencyAccessInviteToken_BlockedDomain_ThrowsBadRequestException( + SutProvider sutProvider, User user, string masterPasswordHash, + EmergencyAccess emergencyAccess, string acceptEmergencyAccessInviteToken, Guid acceptEmergencyAccessId) + { + // Arrange + user.Email = "user@blocked-domain.com"; + emergencyAccess.Email = user.Email; + emergencyAccess.Id = acceptEmergencyAccessId; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("blocked-domain.com") + .Returns(true); + + sutProvider.GetDependency>() + .TryUnprotect(acceptEmergencyAccessInviteToken, out Arg.Any()) + .Returns(callInfo => + { + callInfo[1] = new EmergencyAccessInviteTokenable(emergencyAccess, 10); + return true; + }); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RegisterUserViaAcceptEmergencyAccessInviteToken(user, masterPasswordHash, acceptEmergencyAccessInviteToken, acceptEmergencyAccessId)); + Assert.Equal("This email address is claimed by an organization using Bitwarden.", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task RegisterUserViaProviderInviteToken_BlockedDomain_ThrowsBadRequestException( + SutProvider sutProvider, User user, string masterPasswordHash, Guid providerUserId) + { + // Arrange + user.Email = "user@blocked-domain.com"; + + // Start with plaintext + var nowMillis = CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow); + var decryptedProviderInviteToken = $"ProviderUserInvite {providerUserId} {user.Email} {nowMillis}"; + + // Get the byte array of the plaintext + var decryptedProviderInviteTokenByteArray = Encoding.UTF8.GetBytes(decryptedProviderInviteToken); + + // Base64 encode the byte array (this is passed to protector.protect(bytes)) + var base64EncodedProviderInvToken = WebEncoders.Base64UrlEncode(decryptedProviderInviteTokenByteArray); + + var mockDataProtector = Substitute.For(); + + // Given any byte array, just return the decryptedProviderInviteTokenByteArray (sidestepping any actual encryption) + mockDataProtector.Unprotect(Arg.Any()).Returns(decryptedProviderInviteTokenByteArray); + + sutProvider.GetDependency() + .CreateProtector("ProviderServiceDataProtector") + .Returns(mockDataProtector); + + sutProvider.GetDependency() + .OrganizationInviteExpirationHours.Returns(120); // 5 days + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("blocked-domain.com") + .Returns(true); + + // Using sutProvider in the parameters of the function means that the constructor has already run for the + // command so we have to recreate it in order for our mock overrides to be used. + sutProvider.Create(); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RegisterUserViaProviderInviteToken(user, masterPasswordHash, base64EncodedProviderInvToken, providerUserId)); + Assert.Equal("This email address is claimed by an organization using Bitwarden.", exception.Message); + } + + // ----------------------------------------------------------------------------------------------- + // Invalid email format tests + // ----------------------------------------------------------------------------------------------- + + [Theory] + [BitAutoData] + public async Task RegisterUser_InvalidEmailFormat_ThrowsBadRequestException( + SutProvider sutProvider, User user) + { + // Arrange + user.Email = "invalid-email-format"; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RegisterUser(user)); + Assert.Equal("Invalid email address format.", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task RegisterUserViaEmailVerificationToken_InvalidEmailFormat_ThrowsBadRequestException( + SutProvider sutProvider, User user, string masterPasswordHash, + string emailVerificationToken, bool receiveMarketingMaterials) + { + // Arrange + user.Email = "invalid-email-format"; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency>() + .TryUnprotect(emailVerificationToken, out Arg.Any()) + .Returns(callInfo => + { + callInfo[1] = new RegistrationEmailVerificationTokenable(user.Email, user.Name, receiveMarketingMaterials); + return true; + }); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RegisterUserViaEmailVerificationToken(user, masterPasswordHash, emailVerificationToken)); + Assert.Equal("Invalid email address format.", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task SendWelcomeEmail_OrganizationNull_SendsIndividualWelcomeEmail( + User user, + OrganizationUser orgUser, + string orgInviteToken, + string masterPasswordHash, + SutProvider sutProvider) + { + // Arrange + user.ReferenceData = null; + orgUser.Email = user.Email; + + sutProvider.GetDependency() + .CreateUserAsync(user, masterPasswordHash) + .Returns(IdentityResult.Success); + + sutProvider.GetDependency() + .GetByIdAsync(orgUser.Id) + .Returns(orgUser); + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(Arg.Any(), PolicyType.TwoFactorAuthentication) + .Returns((Policy)null); + + sutProvider.GetDependency() + .GetByIdAsync(orgUser.OrganizationId) + .Returns((Organization)null); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates) + .Returns(true); + + var orgInviteTokenable = new OrgUserInviteTokenable(orgUser); + + sutProvider.GetDependency>() + .TryUnprotect(orgInviteToken, out Arg.Any()) + .Returns(callInfo => + { + callInfo[1] = orgInviteTokenable; + return true; + }); + + // Act + var result = await sutProvider.Sut.RegisterUserViaOrganizationInviteToken(user, masterPasswordHash, orgInviteToken, orgUser.Id); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .SendIndividualUserWelcomeEmailAsync(user); + } + + [Theory] + [BitAutoData] + public async Task SendWelcomeEmail_OrganizationDisplayNameNull_SendsIndividualWelcomeEmail( + User user, + SutProvider sutProvider) + { + // Arrange + Organization organization = new Organization + { + Name = null + }; + + sutProvider.GetDependency() + .CreateUserAsync(user) + .Returns(IdentityResult.Success); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates) + .Returns(true); + + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns((OrganizationUser)null); + + // Act + await sutProvider.Sut.RegisterSSOAutoProvisionedUserAsync(user, organization); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .SendIndividualUserWelcomeEmailAsync(user); + } + + [Theory] + [BitAutoData] + public async Task GetOrganizationWelcomeEmailDetailsAsync_HappyPath_ReturnsOrganizationWelcomeEmailDetails( + Organization organization, + User user, + OrganizationUser orgUser, + string masterPasswordHash, + string orgInviteToken, + SutProvider sutProvider) + { + // Arrange + user.ReferenceData = null; + orgUser.Email = user.Email; + organization.PlanType = PlanType.EnterpriseAnnually; + + sutProvider.GetDependency() + .CreateUserAsync(user, masterPasswordHash) + .Returns(IdentityResult.Success); + + sutProvider.GetDependency() + .GetByIdAsync(orgUser.Id) + .Returns(orgUser); + + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(Arg.Any(), PolicyType.TwoFactorAuthentication) + .Returns((Policy)null); + + sutProvider.GetDependency() + .GetByIdAsync(orgUser.OrganizationId) + .Returns(organization); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.MjmlWelcomeEmailTemplates) + .Returns(true); + + var orgInviteTokenable = new OrgUserInviteTokenable(orgUser); + + sutProvider.GetDependency>() + .TryUnprotect(orgInviteToken, out Arg.Any()) + .Returns(callInfo => + { + callInfo[1] = orgInviteTokenable; + return true; + }); + + // Act + var result = await sutProvider.Sut.RegisterUserViaOrganizationInviteToken(user, masterPasswordHash, orgInviteToken, orgUser.Id); + + // Assert + Assert.True(result.Succeeded); + + await sutProvider.GetDependency() + .Received(1) + .GetByIdAsync(orgUser.OrganizationId); + + await sutProvider.GetDependency() + .Received(1) + .SendOrganizationUserWelcomeEmailAsync(user, organization.DisplayName()); + } + + [Theory, BitAutoData] + public async Task RegisterSSOAutoProvisionedUserAsync_WithBlockedDomain_ThrowsException( + User user, + Organization organization, + SutProvider sutProvider) + { + // Arrange + user.Email = "user@blocked-domain.com"; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("blocked-domain.com", organization.Id) + .Returns(true); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.RegisterSSOAutoProvisionedUserAsync(user, organization)); + Assert.Equal("This email address is claimed by an organization using Bitwarden.", exception.Message); + } + + [Theory, BitAutoData] + public async Task RegisterSSOAutoProvisionedUserAsync_WithOwnClaimedDomain_Succeeds( + User user, + Organization organization, + SutProvider sutProvider) + { + // Arrange + user.Email = "user@company-domain.com"; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + // Domain is claimed by THIS organization, so it should be allowed + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("company-domain.com", organization.Id) + .Returns(false); // Not blocked because organization.Id is excluded + + sutProvider.GetDependency() + .CreateUserAsync(user) + .Returns(IdentityResult.Success); + + // Act + var result = await sutProvider.Sut.RegisterSSOAutoProvisionedUserAsync(user, organization); + + // Assert + Assert.True(result.Succeeded); + await sutProvider.GetDependency() + .Received(1) + .CreateUserAsync(user); + } + + [Theory, BitAutoData] + public async Task RegisterSSOAutoProvisionedUserAsync_WithNonClaimedDomain_Succeeds( + User user, + Organization organization, + SutProvider sutProvider) + { + // Arrange + user.Email = "user@unclaimed-domain.com"; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("unclaimed-domain.com", organization.Id) + .Returns(false); // Domain is not claimed by any org + + sutProvider.GetDependency() + .CreateUserAsync(user) + .Returns(IdentityResult.Success); + + // Act + var result = await sutProvider.Sut.RegisterSSOAutoProvisionedUserAsync(user, organization); + + // Assert + Assert.True(result.Succeeded); + await sutProvider.GetDependency() + .Received(1) + .CreateUserAsync(user); + } } diff --git a/test/Core.Test/Auth/UserFeatures/Registration/SendVerificationEmailForRegistrationCommandTests.cs b/test/Core.Test/Auth/UserFeatures/Registration/SendVerificationEmailForRegistrationCommandTests.cs index f4f620f8a9..91e8351d2c 100644 --- a/test/Core.Test/Auth/UserFeatures/Registration/SendVerificationEmailForRegistrationCommandTests.cs +++ b/test/Core.Test/Auth/UserFeatures/Registration/SendVerificationEmailForRegistrationCommandTests.cs @@ -1,4 +1,5 @@ -using Bit.Core.Auth.Models.Business.Tokenables; +using Bit.Core.Auth.Models.Api.Request.Accounts; +using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Auth.UserFeatures.Registration.Implementations; using Bit.Core.Entities; using Bit.Core.Exceptions; @@ -21,6 +22,43 @@ public class SendVerificationEmailForRegistrationCommandTests [Theory] [BitAutoData] public async Task SendVerificationEmailForRegistrationCommand_WhenIsNewUserAndEnableEmailVerificationTrue_SendsEmailAndReturnsNull(SutProvider sutProvider, + string name, bool receiveMarketingEmails) + { + // Arrange + var email = $"test+{Guid.NewGuid()}@example.com"; + + sutProvider.GetDependency() + .GetByEmailAsync(email) + .ReturnsNull(); + + sutProvider.GetDependency() + .EnableEmailVerification = true; + + sutProvider.GetDependency() + .DisableUserRegistration = false; + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + + var mockedToken = "token"; + sutProvider.GetDependency>() + .Protect(Arg.Any()) + .Returns(mockedToken); + + // Act + var result = await sutProvider.Sut.Run(email, name, receiveMarketingEmails, null); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .SendRegistrationVerificationEmailAsync(email, mockedToken, null); + Assert.Null(result); + } + + [Theory] + [BitAutoData] + public async Task SendVerificationEmailForRegistrationCommand_WhenFromMarketingIsPremium_SendsEmailWithMarketingParameterAndReturnsNull(SutProvider sutProvider, string email, string name, bool receiveMarketingEmails) { // Arrange @@ -34,31 +72,35 @@ public class SendVerificationEmailForRegistrationCommandTests sutProvider.GetDependency() .DisableUserRegistration = false; - sutProvider.GetDependency() - .SendRegistrationVerificationEmailAsync(email, Arg.Any()) - .Returns(Task.CompletedTask); + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); var mockedToken = "token"; sutProvider.GetDependency>() .Protect(Arg.Any()) .Returns(mockedToken); + var fromMarketing = MarketingInitiativeConstants.Premium; + // Act - var result = await sutProvider.Sut.Run(email, name, receiveMarketingEmails); + var result = await sutProvider.Sut.Run(email, name, receiveMarketingEmails, fromMarketing); // Assert await sutProvider.GetDependency() .Received(1) - .SendRegistrationVerificationEmailAsync(email, mockedToken); + .SendRegistrationVerificationEmailAsync(email, mockedToken, fromMarketing); Assert.Null(result); } [Theory] [BitAutoData] public async Task SendVerificationEmailForRegistrationCommand_WhenIsExistingUserAndEnableEmailVerificationTrue_ReturnsNull(SutProvider sutProvider, - string email, string name, bool receiveMarketingEmails) + string name, bool receiveMarketingEmails) { // Arrange + var email = $"test+{Guid.NewGuid()}@example.com"; + sutProvider.GetDependency() .GetByEmailAsync(email) .Returns(new User()); @@ -69,27 +111,33 @@ public class SendVerificationEmailForRegistrationCommandTests sutProvider.GetDependency() .DisableUserRegistration = false; + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + var mockedToken = "token"; sutProvider.GetDependency>() .Protect(Arg.Any()) .Returns(mockedToken); // Act - var result = await sutProvider.Sut.Run(email, name, receiveMarketingEmails); + var result = await sutProvider.Sut.Run(email, name, receiveMarketingEmails, null); // Assert await sutProvider.GetDependency() .DidNotReceive() - .SendRegistrationVerificationEmailAsync(email, mockedToken); + .SendRegistrationVerificationEmailAsync(email, mockedToken, null); Assert.Null(result); } [Theory] [BitAutoData] public async Task SendVerificationEmailForRegistrationCommand_WhenIsNewUserAndEnableEmailVerificationFalse_ReturnsToken(SutProvider sutProvider, - string email, string name, bool receiveMarketingEmails) + string name, bool receiveMarketingEmails) { // Arrange + var email = $"test+{Guid.NewGuid()}@example.com"; + sutProvider.GetDependency() .GetByEmailAsync(email) .ReturnsNull(); @@ -100,13 +148,17 @@ public class SendVerificationEmailForRegistrationCommandTests sutProvider.GetDependency() .DisableUserRegistration = false; + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + var mockedToken = "token"; sutProvider.GetDependency>() .Protect(Arg.Any()) .Returns(mockedToken); // Act - var result = await sutProvider.Sut.Run(email, name, receiveMarketingEmails); + var result = await sutProvider.Sut.Run(email, name, receiveMarketingEmails, null); // Assert Assert.Equal(mockedToken, result); @@ -122,15 +174,17 @@ public class SendVerificationEmailForRegistrationCommandTests .DisableUserRegistration = true; // Act & Assert - await Assert.ThrowsAsync(() => sutProvider.Sut.Run(email, name, receiveMarketingEmails)); + await Assert.ThrowsAsync(() => sutProvider.Sut.Run(email, name, receiveMarketingEmails, null)); } [Theory] [BitAutoData] public async Task SendVerificationEmailForRegistrationCommand_WhenIsExistingUserAndEnableEmailVerificationFalse_ThrowsBadRequestException(SutProvider sutProvider, - string email, string name, bool receiveMarketingEmails) + string name, bool receiveMarketingEmails) { // Arrange + var email = $"test+{Guid.NewGuid()}@example.com"; + sutProvider.GetDependency() .GetByEmailAsync(email) .Returns(new User()); @@ -138,8 +192,15 @@ public class SendVerificationEmailForRegistrationCommandTests sutProvider.GetDependency() .EnableEmailVerification = false; + sutProvider.GetDependency() + .DisableUserRegistration = false; + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync(Arg.Any()) + .Returns(false); + // Act & Assert - await Assert.ThrowsAsync(() => sutProvider.Sut.Run(email, name, receiveMarketingEmails)); + await Assert.ThrowsAsync(() => sutProvider.Sut.Run(email, name, receiveMarketingEmails, null)); } [Theory] @@ -150,7 +211,7 @@ public class SendVerificationEmailForRegistrationCommandTests sutProvider.GetDependency() .DisableUserRegistration = false; - await Assert.ThrowsAsync(async () => await sutProvider.Sut.Run(null, name, receiveMarketingEmails)); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.Run(null, name, receiveMarketingEmails, null)); } [Theory] @@ -160,6 +221,90 @@ public class SendVerificationEmailForRegistrationCommandTests { sutProvider.GetDependency() .DisableUserRegistration = false; - await Assert.ThrowsAsync(async () => await sutProvider.Sut.Run("", name, receiveMarketingEmails)); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.Run("", name, receiveMarketingEmails, null)); + } + + [Theory] + [BitAutoData] + public async Task SendVerificationEmailForRegistrationCommand_WhenBlockedDomain_ThrowsBadRequestException(SutProvider sutProvider, + string name, bool receiveMarketingEmails) + { + // Arrange + var email = $"test+{Guid.NewGuid()}@blockedcompany.com"; + + sutProvider.GetDependency() + .DisableUserRegistration = false; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("blockedcompany.com") + .Returns(true); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.Run(email, name, receiveMarketingEmails, null)); + Assert.Equal("This email address is claimed by an organization using Bitwarden.", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task SendVerificationEmailForRegistrationCommand_WhenAllowedDomain_Succeeds(SutProvider sutProvider, + string name, bool receiveMarketingEmails) + { + // Arrange + var email = $"test+{Guid.NewGuid()}@allowedcompany.com"; + + sutProvider.GetDependency() + .GetByEmailAsync(email) + .ReturnsNull(); + + sutProvider.GetDependency() + .EnableEmailVerification = false; + + sutProvider.GetDependency() + .DisableUserRegistration = false; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + sutProvider.GetDependency() + .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("allowedcompany.com") + .Returns(false); + + var mockedToken = "token"; + sutProvider.GetDependency>() + .Protect(Arg.Any()) + .Returns(mockedToken); + + // Act + var result = await sutProvider.Sut.Run(email, name, receiveMarketingEmails, null); + + // Assert + Assert.Equal(mockedToken, result); + } + + [Theory] + [BitAutoData] + public async Task SendVerificationEmailForRegistrationCommand_InvalidEmailFormat_ThrowsBadRequestException( + SutProvider sutProvider, + string name, bool receiveMarketingEmails) + { + // Arrange + var email = "invalid-email-format"; + + sutProvider.GetDependency() + .DisableUserRegistration = false; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.BlockClaimedDomainAccountCreation) + .Returns(true); + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + sutProvider.Sut.Run(email, name, receiveMarketingEmails, null)); + Assert.Equal("Invalid email address format.", exception.Message); } } diff --git a/test/Core.Test/Auth/UserFeatures/Sso/UserSsoOrganizationIdentifierQueryTests.cs b/test/Core.Test/Auth/UserFeatures/Sso/UserSsoOrganizationIdentifierQueryTests.cs new file mode 100644 index 0000000000..2b448ba79f --- /dev/null +++ b/test/Core.Test/Auth/UserFeatures/Sso/UserSsoOrganizationIdentifierQueryTests.cs @@ -0,0 +1,275 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Auth.Sso; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Repositories; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.Auth.UserFeatures.Sso; + +[SutProviderCustomize] +public class UserSsoOrganizationIdentifierQueryTests +{ + [Theory, BitAutoData] + public async Task GetSsoOrganizationIdentifierAsync_UserHasSingleConfirmedOrganization_ReturnsIdentifier( + SutProvider sutProvider, + Guid userId, + Organization organization, + OrganizationUser organizationUser) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.OrganizationId = organization.Id; + organizationUser.Status = OrganizationUserStatusType.Confirmed; + organization.Identifier = "test-org-identifier"; + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns([organizationUser]); + + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + + // Act + var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId); + + // Assert + Assert.Equal("test-org-identifier", result); + await sutProvider.GetDependency() + .Received(1) + .GetManyByUserAsync(userId); + await sutProvider.GetDependency() + .Received(1) + .GetByIdAsync(organization.Id); + } + + [Theory, BitAutoData] + public async Task GetSsoOrganizationIdentifierAsync_UserHasNoOrganizations_ReturnsNull( + SutProvider sutProvider, + Guid userId) + { + // Arrange + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns(Array.Empty()); + + // Act + var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId); + + // Assert + Assert.Null(result); + await sutProvider.GetDependency() + .Received(1) + .GetManyByUserAsync(userId); + await sutProvider.GetDependency() + .DidNotReceive() + .GetByIdAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task GetSsoOrganizationIdentifierAsync_UserHasMultipleConfirmedOrganizations_ReturnsNull( + SutProvider sutProvider, + Guid userId, + OrganizationUser organizationUser1, + OrganizationUser organizationUser2) + { + // Arrange + organizationUser1.UserId = userId; + organizationUser1.Status = OrganizationUserStatusType.Confirmed; + organizationUser2.UserId = userId; + organizationUser2.Status = OrganizationUserStatusType.Confirmed; + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns([organizationUser1, organizationUser2]); + + // Act + var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId); + + // Assert + Assert.Null(result); + await sutProvider.GetDependency() + .Received(1) + .GetManyByUserAsync(userId); + await sutProvider.GetDependency() + .DidNotReceive() + .GetByIdAsync(Arg.Any()); + } + + [Theory] + [BitAutoData(OrganizationUserStatusType.Invited)] + [BitAutoData(OrganizationUserStatusType.Accepted)] + [BitAutoData(OrganizationUserStatusType.Revoked)] + public async Task GetSsoOrganizationIdentifierAsync_UserHasOnlyInvitedOrganization_ReturnsNull( + OrganizationUserStatusType status, + SutProvider sutProvider, + Guid userId, + OrganizationUser organizationUser) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.Status = status; + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns([organizationUser]); + + // Act + var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId); + + // Assert + Assert.Null(result); + await sutProvider.GetDependency() + .Received(1) + .GetManyByUserAsync(userId); + await sutProvider.GetDependency() + .DidNotReceive() + .GetByIdAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task GetSsoOrganizationIdentifierAsync_UserHasMixedStatusOrganizations_OnlyOneConfirmed_ReturnsIdentifier( + SutProvider sutProvider, + Guid userId, + Organization organization, + OrganizationUser confirmedOrgUser, + OrganizationUser invitedOrgUser, + OrganizationUser revokedOrgUser) + { + // Arrange + confirmedOrgUser.UserId = userId; + confirmedOrgUser.OrganizationId = organization.Id; + confirmedOrgUser.Status = OrganizationUserStatusType.Confirmed; + + invitedOrgUser.UserId = userId; + invitedOrgUser.Status = OrganizationUserStatusType.Invited; + + revokedOrgUser.UserId = userId; + revokedOrgUser.Status = OrganizationUserStatusType.Revoked; + + organization.Identifier = "mixed-status-org"; + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns(new[] { confirmedOrgUser, invitedOrgUser, revokedOrgUser }); + + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + + // Act + var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId); + + // Assert + Assert.Equal("mixed-status-org", result); + await sutProvider.GetDependency() + .Received(1) + .GetManyByUserAsync(userId); + await sutProvider.GetDependency() + .Received(1) + .GetByIdAsync(organization.Id); + } + + [Theory, BitAutoData] + public async Task GetSsoOrganizationIdentifierAsync_OrganizationNotFound_ReturnsNull( + SutProvider sutProvider, + Guid userId, + OrganizationUser organizationUser) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.Status = OrganizationUserStatusType.Confirmed; + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns([organizationUser]); + + sutProvider.GetDependency() + .GetByIdAsync(organizationUser.OrganizationId) + .Returns((Organization)null); + + // Act + var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId); + + // Assert + Assert.Null(result); + await sutProvider.GetDependency() + .Received(1) + .GetManyByUserAsync(userId); + await sutProvider.GetDependency() + .Received(1) + .GetByIdAsync(organizationUser.OrganizationId); + } + + [Theory, BitAutoData] + public async Task GetSsoOrganizationIdentifierAsync_OrganizationIdentifierIsNull_ReturnsNull( + SutProvider sutProvider, + Guid userId, + Organization organization, + OrganizationUser organizationUser) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.OrganizationId = organization.Id; + organizationUser.Status = OrganizationUserStatusType.Confirmed; + organization.Identifier = null; + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns(new[] { organizationUser }); + + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + + // Act + var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId); + + // Assert + Assert.Null(result); + await sutProvider.GetDependency() + .Received(1) + .GetManyByUserAsync(userId); + await sutProvider.GetDependency() + .Received(1) + .GetByIdAsync(organization.Id); + } + + [Theory, BitAutoData] + public async Task GetSsoOrganizationIdentifierAsync_OrganizationIdentifierIsEmpty_ReturnsEmpty( + SutProvider sutProvider, + Guid userId, + Organization organization, + OrganizationUser organizationUser) + { + // Arrange + organizationUser.UserId = userId; + organizationUser.OrganizationId = organization.Id; + organizationUser.Status = OrganizationUserStatusType.Confirmed; + organization.Identifier = string.Empty; + + sutProvider.GetDependency() + .GetManyByUserAsync(userId) + .Returns(new[] { organizationUser }); + + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + + // Act + var result = await sutProvider.Sut.GetSsoOrganizationIdentifierAsync(userId); + + // Assert + Assert.Equal(string.Empty, result); + await sutProvider.GetDependency() + .Received(1) + .GetManyByUserAsync(userId); + await sutProvider.GetDependency() + .Received(1) + .GetByIdAsync(organization.Id); + } +} diff --git a/test/Core.Test/Auth/UserFeatures/TwoFactorAuth/TwoFactorIsEnabledQueryTests.cs b/test/Core.Test/Auth/UserFeatures/TwoFactorAuth/TwoFactorIsEnabledQueryTests.cs index adeac45d06..3a98fb44fb 100644 --- a/test/Core.Test/Auth/UserFeatures/TwoFactorAuth/TwoFactorIsEnabledQueryTests.cs +++ b/test/Core.Test/Auth/UserFeatures/TwoFactorAuth/TwoFactorIsEnabledQueryTests.cs @@ -1,10 +1,13 @@ using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models; using Bit.Core.Auth.UserFeatures.TwoFactorAuth; +using Bit.Core.Billing.Premium.Queries; using Bit.Core.Entities; +using Bit.Core.Exceptions; using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; +using Bit.Core.Services; using Bit.Core.Utilities; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; @@ -404,6 +407,277 @@ public class TwoFactorIsEnabledQueryTests .GetCalculatedPremiumAsync(default); } + [Theory] + [BitAutoData((IEnumerable)null)] + [BitAutoData([])] + public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithNoUserIds_ReturnsEmpty( + IEnumerable userIds, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PremiumAccessQuery) + .Returns(true); + + // Act + var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(userIds); + + // Assert + Assert.Empty(result); + } + + [Theory] + [BitAutoData(TwoFactorProviderType.Duo)] + [BitAutoData(TwoFactorProviderType.YubiKey)] + public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithMixedScenarios_ReturnsCorrectResults( + TwoFactorProviderType premiumProviderType, + SutProvider sutProvider, + User user1, + User user2, + User user3) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PremiumAccessQuery) + .Returns(true); + + var users = new List { user1, user2, user3 }; + var userIds = users.Select(u => u.Id).ToList(); + + // User 1: Non-premium provider → 2FA enabled + user1.SetTwoFactorProviders(new Dictionary + { + { TwoFactorProviderType.Authenticator, new TwoFactorProvider { Enabled = true } } + }); + + // User 2: Premium provider + has premium → 2FA enabled + user2.SetTwoFactorProviders(new Dictionary + { + { premiumProviderType, new TwoFactorProvider { Enabled = true } } + }); + + // User 3: Premium provider + no premium → 2FA disabled + user3.SetTwoFactorProviders(new Dictionary + { + { premiumProviderType, new TwoFactorProvider { Enabled = true } } + }); + + var premiumStatus = new Dictionary + { + { user2.Id, true }, + { user3.Id, false } + }; + + sutProvider.GetDependency() + .GetManyAsync(Arg.Is>(ids => ids.SequenceEqual(userIds))) + .Returns(users); + + sutProvider.GetDependency() + .HasPremiumAccessAsync(Arg.Is>(ids => + ids.Count() == 2 && ids.Contains(user2.Id) && ids.Contains(user3.Id))) + .Returns(premiumStatus); + + // Act + var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(userIds); + + // Assert + Assert.Contains(result, res => res.userId == user1.Id && res.twoFactorIsEnabled == true); // Non-premium provider + Assert.Contains(result, res => res.userId == user2.Id && res.twoFactorIsEnabled == true); // Premium + has premium + Assert.Contains(result, res => res.userId == user3.Id && res.twoFactorIsEnabled == false); // Premium + no premium + } + + [Theory] + [BitAutoData(TwoFactorProviderType.Duo)] + [BitAutoData(TwoFactorProviderType.YubiKey)] + public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_OnlyChecksPremiumAccessForUsersWhoNeedIt( + TwoFactorProviderType premiumProviderType, + SutProvider sutProvider, + User user1, + User user2, + User user3) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PremiumAccessQuery) + .Returns(true); + + var users = new List { user1, user2, user3 }; + var userIds = users.Select(u => u.Id).ToList(); + + // User 1: Has non-premium provider - should NOT trigger premium check + user1.SetTwoFactorProviders(new Dictionary + { + { TwoFactorProviderType.Authenticator, new TwoFactorProvider { Enabled = true } } + }); + + // User 2 & 3: Have only premium providers - SHOULD trigger premium check + user2.SetTwoFactorProviders(new Dictionary + { + { premiumProviderType, new TwoFactorProvider { Enabled = true } } + }); + user3.SetTwoFactorProviders(new Dictionary + { + { premiumProviderType, new TwoFactorProvider { Enabled = true } } + }); + + var premiumStatus = new Dictionary + { + { user2.Id, true }, + { user3.Id, false } + }; + + sutProvider.GetDependency() + .GetManyAsync(Arg.Is>(ids => ids.SequenceEqual(userIds))) + .Returns(users); + + sutProvider.GetDependency() + .HasPremiumAccessAsync(Arg.Is>(ids => + ids.Count() == 2 && ids.Contains(user2.Id) && ids.Contains(user3.Id))) + .Returns(premiumStatus); + + // Act + var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(userIds); + + // Assert - Verify optimization: premium checked ONLY for users 2 and 3 (not user 1) + await sutProvider.GetDependency() + .Received(1) + .HasPremiumAccessAsync(Arg.Is>(ids => + ids.Count() == 2 && ids.Contains(user2.Id) && ids.Contains(user3.Id))); + } + + [Theory] + [BitAutoData] + public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithNoUserIds_ReturnsAllTwoFactorDisabled( + SutProvider sutProvider, + List users) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PremiumAccessQuery) + .Returns(true); + + foreach (var user in users) + { + user.UserId = null; + } + + // Act + var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(users); + + // Assert + foreach (var user in users) + { + Assert.Contains(result, res => res.user.Equals(user) && res.twoFactorIsEnabled == false); + } + + // No UserIds were supplied so no calls to the UserRepository should have been made + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .GetManyAsync(default); + } + + [Theory] + [BitAutoData(TwoFactorProviderType.Authenticator, true)] // Non-premium provider + [BitAutoData(TwoFactorProviderType.Duo, true)] // Premium provider with premium access + [BitAutoData(TwoFactorProviderType.YubiKey, false)] // Premium provider without premium access + public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_SingleUser_VariousScenarios( + TwoFactorProviderType providerType, + bool hasPremiumAccess, + SutProvider sutProvider, + User user) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PremiumAccessQuery) + .Returns(true); + + user.SetTwoFactorProviders(new Dictionary + { + { providerType, new TwoFactorProvider { Enabled = true } } + }); + + sutProvider.GetDependency() + .HasPremiumAccessAsync(user.Id) + .Returns(hasPremiumAccess); + + // Act + var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(user); + + // Assert + var requiresPremium = TwoFactorProvider.RequiresPremium(providerType); + var expectedResult = !requiresPremium || hasPremiumAccess; + Assert.Equal(expectedResult, result); + } + + [Theory] + [BitAutoData] + public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithNoEnabledProviders_ReturnsFalse( + SutProvider sutProvider, + User user) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PremiumAccessQuery) + .Returns(true); + + user.SetTwoFactorProviders(new Dictionary + { + { TwoFactorProviderType.Email, new TwoFactorProvider { Enabled = false } } + }); + + // Act + var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(user); + + // Assert + Assert.False(result); + } + + [Theory] + [BitAutoData] + public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_WithNullProviders_ReturnsFalse( + SutProvider sutProvider, + User user) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PremiumAccessQuery) + .Returns(true); + + user.TwoFactorProviders = null; + + // Act + var result = await sutProvider.Sut.TwoFactorIsEnabledAsync(user); + + // Assert + Assert.False(result); + } + + [Theory] + [BitAutoData] + public async Task TwoFactorIsEnabledAsync_WhenPremiumAccessQueryEnabled_UserNotFound_ThrowsNotFoundException( + SutProvider sutProvider, + Guid userId) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PremiumAccessQuery) + .Returns(true); + + var testUser = new TestTwoFactorProviderUser + { + Id = userId, + TwoFactorProviders = null + }; + + sutProvider.GetDependency() + .GetByIdAsync(userId) + .Returns((User)null); + + // Act & Assert + await Assert.ThrowsAsync( + async () => await sutProvider.Sut.TwoFactorIsEnabledAsync(testUser)); + } + private class TestTwoFactorProviderUser : ITwoFactorProvidersUser { public Guid? Id { get; set; } @@ -418,10 +692,5 @@ public class TwoFactorIsEnabledQueryTests { return Id; } - - public bool GetPremium() - { - return Premium; - } } } diff --git a/test/Core.Test/Billing/Extensions/InvoiceExtensionsTests.cs b/test/Core.Test/Billing/Extensions/InvoiceExtensionsTests.cs index a30e5e896c..1a4f92a224 100644 --- a/test/Core.Test/Billing/Extensions/InvoiceExtensionsTests.cs +++ b/test/Core.Test/Billing/Extensions/InvoiceExtensionsTests.cs @@ -1,4 +1,5 @@ -using Bit.Core.Billing.Extensions; +using System.Globalization; +using Bit.Core.Billing.Extensions; using Stripe; using Xunit; @@ -294,7 +295,8 @@ public class InvoiceExtensionsTests Amount = 600 } ); - invoice.Tax = 120; // $1.20 in cents + + invoice.TotalTaxes = [new InvoiceTotalTax { Amount = 120 }]; // $1.20 in cents var subscription = new Subscription(); // Act @@ -318,7 +320,7 @@ public class InvoiceExtensionsTests Amount = 600 } ); - invoice.Tax = null; + invoice.TotalTaxes = []; var subscription = new Subscription(); // Act @@ -341,7 +343,7 @@ public class InvoiceExtensionsTests Amount = 600 } ); - invoice.Tax = 0; + invoice.TotalTaxes = [new InvoiceTotalTax { Amount = 0 }]; var subscription = new Subscription(); // Act @@ -355,9 +357,18 @@ public class InvoiceExtensionsTests [Fact] public void FormatForProvider_ComplexScenario_HandlesAllLineTypes() { - // Arrange - var lineItems = new StripeList(); - lineItems.Data = new List + // Set culture to en-US to ensure consistent decimal formatting in tests + // This ensures tests pass on all machines regardless of system locale + var originalCulture = Thread.CurrentThread.CurrentCulture; + var originalUICulture = Thread.CurrentThread.CurrentUICulture; + try + { + Thread.CurrentThread.CurrentCulture = new CultureInfo("en-US"); + Thread.CurrentThread.CurrentUICulture = new CultureInfo("en-US"); + + // Arrange + var lineItems = new StripeList(); + lineItems.Data = new List { new InvoiceLineItem { @@ -371,23 +382,29 @@ public class InvoiceExtensionsTests new InvoiceLineItem { Description = "Custom Service", Quantity = 2, Amount = 2000 } }; - var invoice = new Invoice + var invoice = new Invoice + { + Lines = lineItems, + TotalTaxes = [new InvoiceTotalTax { Amount = 200 }] // Additional $2.00 tax + }; + var subscription = new Subscription(); + + // Act + var result = invoice.FormatForProvider(subscription); + + // Assert + Assert.Equal(5, result.Count); + Assert.Equal("5 × Manage service provider (at $6.00 / month)", result[0]); + Assert.Equal("10 × Manage service provider (at $4.00 / month)", result[1]); + Assert.Equal("1 × Tax (at $8.00 / month)", result[2]); + Assert.Equal("Custom Service", result[3]); + Assert.Equal("1 × Tax (at $2.00 / month)", result[4]); + } + finally { - Lines = lineItems, - Tax = 200 // Additional $2.00 tax - }; - var subscription = new Subscription(); - - // Act - var result = invoice.FormatForProvider(subscription); - - // Assert - Assert.Equal(5, result.Count); - Assert.Equal("5 × Manage service provider (at $6.00 / month)", result[0]); - Assert.Equal("10 × Manage service provider (at $4.00 / month)", result[1]); - Assert.Equal("1 × Tax (at $8.00 / month)", result[2]); - Assert.Equal("Custom Service", result[3]); - Assert.Equal("1 × Tax (at $2.00 / month)", result[4]); + Thread.CurrentThread.CurrentCulture = originalCulture; + Thread.CurrentThread.CurrentUICulture = originalUICulture; + } } #endregion diff --git a/test/Core.Test/Billing/Mocks/MockPlans.cs b/test/Core.Test/Billing/Mocks/MockPlans.cs new file mode 100644 index 0000000000..b4737434fb --- /dev/null +++ b/test/Core.Test/Billing/Mocks/MockPlans.cs @@ -0,0 +1,37 @@ +using Bit.Core.Billing.Enums; +using Bit.Core.Models.StaticStore; +using Bit.Core.Test.Billing.Mocks.Plans; + +namespace Bit.Core.Test.Billing.Mocks; + +public class MockPlans +{ + public static List Plans => + [ + new CustomPlan(), + new Enterprise2019Plan(false), + new Enterprise2019Plan(true), + new Enterprise2020Plan(false), + new Enterprise2020Plan(true), + new Enterprise2023Plan(false), + new Enterprise2023Plan(true), + new EnterprisePlan(false), + new EnterprisePlan(true), + new Families2019Plan(), + new Families2025Plan(), + new FamiliesPlan(), + new FreePlan(), + new Teams2019Plan(false), + new Teams2019Plan(true), + new Teams2020Plan(false), + new Teams2020Plan(true), + new Teams2023Plan(false), + new Teams2023Plan(true), + new TeamsPlan(false), + new TeamsPlan(true), + new TeamsStarterPlan(), + new TeamsStarterPlan2023() + ]; + + public static Plan Get(PlanType planType) => Plans.SingleOrDefault(p => p.Type == planType)!; +} diff --git a/src/Core/Billing/Models/StaticStore/Plans/CustomPlan.cs b/test/Core.Test/Billing/Mocks/Plans/CustomPlan.cs similarity index 89% rename from src/Core/Billing/Models/StaticStore/Plans/CustomPlan.cs rename to test/Core.Test/Billing/Mocks/Plans/CustomPlan.cs index ce55cb422e..0105b7d07f 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/CustomPlan.cs +++ b/test/Core.Test/Billing/Mocks/Plans/CustomPlan.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record CustomPlan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/Enterprise2019Plan.cs b/test/Core.Test/Billing/Mocks/Plans/Enterprise2019Plan.cs similarity index 98% rename from src/Core/Billing/Models/StaticStore/Plans/Enterprise2019Plan.cs rename to test/Core.Test/Billing/Mocks/Plans/Enterprise2019Plan.cs index b584647a26..27f3710b96 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/Enterprise2019Plan.cs +++ b/test/Core.Test/Billing/Mocks/Plans/Enterprise2019Plan.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record Enterprise2019Plan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/Enterprise2020Plan.cs b/test/Core.Test/Billing/Mocks/Plans/Enterprise2020Plan.cs similarity index 98% rename from src/Core/Billing/Models/StaticStore/Plans/Enterprise2020Plan.cs rename to test/Core.Test/Billing/Mocks/Plans/Enterprise2020Plan.cs index a1a6113cbc..8f56125fc1 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/Enterprise2020Plan.cs +++ b/test/Core.Test/Billing/Mocks/Plans/Enterprise2020Plan.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record Enterprise2020Plan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/EnterprisePlan.cs b/test/Core.Test/Billing/Mocks/Plans/EnterprisePlan.cs similarity index 98% rename from src/Core/Billing/Models/StaticStore/Plans/EnterprisePlan.cs rename to test/Core.Test/Billing/Mocks/Plans/EnterprisePlan.cs index 8aeca521d1..563adc82a3 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/EnterprisePlan.cs +++ b/test/Core.Test/Billing/Mocks/Plans/EnterprisePlan.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record EnterprisePlan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/EnterprisePlan2023.cs b/test/Core.Test/Billing/Mocks/Plans/EnterprisePlan2023.cs similarity index 98% rename from src/Core/Billing/Models/StaticStore/Plans/EnterprisePlan2023.cs rename to test/Core.Test/Billing/Mocks/Plans/EnterprisePlan2023.cs index dce1719a49..f221821ed3 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/EnterprisePlan2023.cs +++ b/test/Core.Test/Billing/Mocks/Plans/EnterprisePlan2023.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record Enterprise2023Plan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/Families2019Plan.cs b/test/Core.Test/Billing/Mocks/Plans/Families2019Plan.cs similarity index 96% rename from src/Core/Billing/Models/StaticStore/Plans/Families2019Plan.cs rename to test/Core.Test/Billing/Mocks/Plans/Families2019Plan.cs index 93ab2c39a1..a0257d88e9 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/Families2019Plan.cs +++ b/test/Core.Test/Billing/Mocks/Plans/Families2019Plan.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record Families2019Plan : Plan { diff --git a/test/Core.Test/Billing/Mocks/Plans/Families2025Plan.cs b/test/Core.Test/Billing/Mocks/Plans/Families2025Plan.cs new file mode 100644 index 0000000000..5f5424bbcf --- /dev/null +++ b/test/Core.Test/Billing/Mocks/Plans/Families2025Plan.cs @@ -0,0 +1,47 @@ +using Bit.Core.Billing.Enums; +using Bit.Core.Models.StaticStore; + +namespace Bit.Core.Test.Billing.Mocks.Plans; + +public record Families2025Plan : Plan +{ + public Families2025Plan() + { + Type = PlanType.FamiliesAnnually2025; + ProductTier = ProductTierType.Families; + Name = "Families 2025"; + IsAnnual = true; + NameLocalizationKey = "planNameFamilies"; + DescriptionLocalizationKey = "planDescFamilies"; + + TrialPeriodDays = 7; + + HasSelfHost = true; + HasTotp = true; + UsersGetPremium = true; + + UpgradeSortOrder = 1; + DisplaySortOrder = 1; + + PasswordManager = new Families2025PasswordManagerFeatures(); + } + + private record Families2025PasswordManagerFeatures : PasswordManagerPlanFeatures + { + public Families2025PasswordManagerFeatures() + { + BaseSeats = 6; + BaseStorageGb = 1; + MaxSeats = 6; + + HasAdditionalStorageOption = true; + + StripePlanId = "2020-families-org-annually"; + StripeStoragePlanId = "personal-storage-gb-annually"; + BasePrice = 40; + AdditionalStoragePricePerGb = 4; + + AllowSeatAutoscale = false; + } + } +} diff --git a/src/Core/Billing/Models/StaticStore/Plans/FamiliesPlan.cs b/test/Core.Test/Billing/Mocks/Plans/FamiliesPlan.cs similarity index 80% rename from src/Core/Billing/Models/StaticStore/Plans/FamiliesPlan.cs rename to test/Core.Test/Billing/Mocks/Plans/FamiliesPlan.cs index 8c71e50fa4..70aa613ee0 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/FamiliesPlan.cs +++ b/test/Core.Test/Billing/Mocks/Plans/FamiliesPlan.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record FamiliesPlan : Plan { @@ -23,12 +23,12 @@ public record FamiliesPlan : Plan UpgradeSortOrder = 1; DisplaySortOrder = 1; - PasswordManager = new TeamsPasswordManagerFeatures(); + PasswordManager = new FamiliesPasswordManagerFeatures(); } - private record TeamsPasswordManagerFeatures : PasswordManagerPlanFeatures + private record FamiliesPasswordManagerFeatures : PasswordManagerPlanFeatures { - public TeamsPasswordManagerFeatures() + public FamiliesPasswordManagerFeatures() { BaseSeats = 6; BaseStorageGb = 1; diff --git a/src/Core/Billing/Models/StaticStore/Plans/FreePlan.cs b/test/Core.Test/Billing/Mocks/Plans/FreePlan.cs similarity index 95% rename from src/Core/Billing/Models/StaticStore/Plans/FreePlan.cs rename to test/Core.Test/Billing/Mocks/Plans/FreePlan.cs index 3b0a8b7480..307f58c803 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/FreePlan.cs +++ b/test/Core.Test/Billing/Mocks/Plans/FreePlan.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record FreePlan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/Teams2019Plan.cs b/test/Core.Test/Billing/Mocks/Plans/Teams2019Plan.cs similarity index 98% rename from src/Core/Billing/Models/StaticStore/Plans/Teams2019Plan.cs rename to test/Core.Test/Billing/Mocks/Plans/Teams2019Plan.cs index 27ed5e0bf4..f1aad7c16f 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/Teams2019Plan.cs +++ b/test/Core.Test/Billing/Mocks/Plans/Teams2019Plan.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record Teams2019Plan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/Teams2020Plan.cs b/test/Core.Test/Billing/Mocks/Plans/Teams2020Plan.cs similarity index 98% rename from src/Core/Billing/Models/StaticStore/Plans/Teams2020Plan.cs rename to test/Core.Test/Billing/Mocks/Plans/Teams2020Plan.cs index a760b9692e..546f1f84c5 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/Teams2020Plan.cs +++ b/test/Core.Test/Billing/Mocks/Plans/Teams2020Plan.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record Teams2020Plan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/TeamsPlan.cs b/test/Core.Test/Billing/Mocks/Plans/TeamsPlan.cs similarity index 98% rename from src/Core/Billing/Models/StaticStore/Plans/TeamsPlan.cs rename to test/Core.Test/Billing/Mocks/Plans/TeamsPlan.cs index 654792ee0b..e0ecd35346 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/TeamsPlan.cs +++ b/test/Core.Test/Billing/Mocks/Plans/TeamsPlan.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record TeamsPlan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/TeamsPlan2023.cs b/test/Core.Test/Billing/Mocks/Plans/TeamsPlan2023.cs similarity index 98% rename from src/Core/Billing/Models/StaticStore/Plans/TeamsPlan2023.cs rename to test/Core.Test/Billing/Mocks/Plans/TeamsPlan2023.cs index 8498af6b13..5ec2acd61c 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/TeamsPlan2023.cs +++ b/test/Core.Test/Billing/Mocks/Plans/TeamsPlan2023.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record Teams2023Plan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/TeamsStarterPlan.cs b/test/Core.Test/Billing/Mocks/Plans/TeamsStarterPlan.cs similarity index 97% rename from src/Core/Billing/Models/StaticStore/Plans/TeamsStarterPlan.cs rename to test/Core.Test/Billing/Mocks/Plans/TeamsStarterPlan.cs index d78844e429..119f431a56 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/TeamsStarterPlan.cs +++ b/test/Core.Test/Billing/Mocks/Plans/TeamsStarterPlan.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record TeamsStarterPlan : Plan { diff --git a/src/Core/Billing/Models/StaticStore/Plans/TeamsStarterPlan2023.cs b/test/Core.Test/Billing/Mocks/Plans/TeamsStarterPlan2023.cs similarity index 97% rename from src/Core/Billing/Models/StaticStore/Plans/TeamsStarterPlan2023.cs rename to test/Core.Test/Billing/Mocks/Plans/TeamsStarterPlan2023.cs index ea15d9eb95..40952e75fb 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/TeamsStarterPlan2023.cs +++ b/test/Core.Test/Billing/Mocks/Plans/TeamsStarterPlan2023.cs @@ -1,7 +1,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.StaticStore; -namespace Bit.Core.Billing.Models.StaticStore.Plans; +namespace Bit.Core.Test.Billing.Mocks.Plans; public record TeamsStarterPlan2023 : Plan { diff --git a/test/Core.Test/Billing/Models/Business/OrganizationLicenseTests.cs b/test/Core.Test/Billing/Models/Business/OrganizationLicenseTests.cs index b2e94967ce..d1f02af50d 100644 --- a/test/Core.Test/Billing/Models/Business/OrganizationLicenseTests.cs +++ b/test/Core.Test/Billing/Models/Business/OrganizationLicenseTests.cs @@ -213,7 +213,8 @@ If you believe you need to change the version for a valid reason, please discuss LimitCollectionDeletion = true, AllowAdminAccessToAllCollectionItems = true, UseOrganizationDomains = true, - UseAdminSponsoredFamilies = false + UseAdminSponsoredFamilies = false, + UsePhishingBlocker = false, }; } @@ -227,8 +228,16 @@ If you believe you need to change the version for a valid reason, please discuss Status = "active", TrialStart = new DateTime(2024, 1, 1, 0, 0, 0, DateTimeKind.Utc), TrialEnd = new DateTime(2024, 2, 1, 0, 0, 0, DateTimeKind.Utc), - CurrentPeriodStart = new DateTime(2024, 1, 1, 0, 0, 0, DateTimeKind.Utc), - CurrentPeriodEnd = new DateTime(2024, 12, 31, 0, 0, 0, DateTimeKind.Utc) + Items = new StripeList + { + Data = [ + new SubscriptionItem + { + CurrentPeriodStart = new DateTime(2024, 1, 1, 0, 0, 0, DateTimeKind.Utc), + CurrentPeriodEnd = new DateTime(2024, 12, 31, 0, 0, 0, DateTimeKind.Utc) + } + ] + } }; return new SubscriptionInfo diff --git a/test/Core.Test/Billing/Models/Business/UserLicenseTests.cs b/test/Core.Test/Billing/Models/Business/UserLicenseTests.cs index 2d1e21b8c5..90bb619ab4 100644 --- a/test/Core.Test/Billing/Models/Business/UserLicenseTests.cs +++ b/test/Core.Test/Billing/Models/Business/UserLicenseTests.cs @@ -141,8 +141,16 @@ If you believe you need to change the version for a valid reason, please discuss Status = "active", TrialStart = new DateTime(2024, 1, 1, 0, 0, 0, DateTimeKind.Utc), TrialEnd = new DateTime(2024, 2, 1, 0, 0, 0, DateTimeKind.Utc), - CurrentPeriodStart = new DateTime(2024, 1, 1, 0, 0, 0, DateTimeKind.Utc), - CurrentPeriodEnd = new DateTime(2024, 12, 31, 0, 0, 0, DateTimeKind.Utc) + Items = new StripeList + { + Data = [ + new SubscriptionItem + { + CurrentPeriodStart = new DateTime(2024, 1, 1, 0, 0, 0, DateTimeKind.Utc), + CurrentPeriodEnd = new DateTime(2024, 12, 31, 0, 0, 0, DateTimeKind.Utc) + } + ] + } }; return new SubscriptionInfo diff --git a/test/Core.Test/Billing/Organizations/Commands/PreviewOrganizationTaxCommandTests.cs b/test/Core.Test/Billing/Organizations/Commands/PreviewOrganizationTaxCommandTests.cs index 8e3cd5a0fa..2f278dcd20 100644 --- a/test/Core.Test/Billing/Organizations/Commands/PreviewOrganizationTaxCommandTests.cs +++ b/test/Core.Test/Billing/Organizations/Commands/PreviewOrganizationTaxCommandTests.cs @@ -1,11 +1,11 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models.StaticStore.Plans; using Bit.Core.Billing.Organizations.Commands; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; -using Bit.Core.Services; +using Bit.Core.Billing.Services; +using Bit.Core.Test.Billing.Mocks.Plans; using Microsoft.Extensions.Logging; using NSubstitute; using Stripe; @@ -54,11 +54,11 @@ public class PreviewOrganizationTaxCommandTests var invoice = new Invoice { - Tax = 500, + TotalTaxes = [new InvoiceTotalTax { Amount = 500 }], Total = 5500 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(purchase, billingAddress); @@ -68,7 +68,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(55.00m, total); // Verify the correct Stripe API call for sponsored subscription - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -77,7 +77,7 @@ public class PreviewOrganizationTaxCommandTests options.SubscriptionDetails.Items.Count == 1 && options.SubscriptionDetails.Items[0].Price == "2021-family-for-enterprise-annually" && options.SubscriptionDetails.Items[0].Quantity == 1 && - options.Coupon == null)); + options.Discounts == null)); } [Fact] @@ -112,11 +112,11 @@ public class PreviewOrganizationTaxCommandTests var invoice = new Invoice { - Tax = 750, + TotalTaxes = [new InvoiceTotalTax { Amount = 750 }], Total = 8250 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(purchase, billingAddress); @@ -126,7 +126,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(82.50m, total); // Verify the correct Stripe API call for standalone secrets manager - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "CA" && @@ -137,7 +137,9 @@ public class PreviewOrganizationTaxCommandTests item.Price == "2023-teams-org-seat-monthly" && item.Quantity == 5) && options.SubscriptionDetails.Items.Any(item => item.Price == "secrets-manager-teams-seat-monthly" && item.Quantity == 3) && - options.Coupon == CouponIDs.SecretsManagerStandalone)); + options.Discounts != null && + options.Discounts.Count == 1 && + options.Discounts[0].Coupon == CouponIDs.SecretsManagerStandalone)); } [Fact] @@ -173,11 +175,11 @@ public class PreviewOrganizationTaxCommandTests var invoice = new Invoice { - Tax = 1200, + TotalTaxes = [new InvoiceTotalTax { Amount = 1200 }], Total = 12200 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(purchase, billingAddress); @@ -187,7 +189,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(122.00m, total); // Verify the correct Stripe API call for comprehensive purchase with storage and service accounts - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "GB" && @@ -205,7 +207,7 @@ public class PreviewOrganizationTaxCommandTests item.Price == "secrets-manager-enterprise-seat-annually" && item.Quantity == 8) && options.SubscriptionDetails.Items.Any(item => item.Price == "secrets-manager-service-account-2024-annually" && item.Quantity == 3) && - options.Coupon == null)); + options.Discounts == null)); } [Fact] @@ -234,11 +236,11 @@ public class PreviewOrganizationTaxCommandTests var invoice = new Invoice { - Tax = 300, + TotalTaxes = [new InvoiceTotalTax { Amount = 300 }], Total = 3300 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(purchase, billingAddress); @@ -248,7 +250,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(33.00m, total); // Verify the correct Stripe API call for Families tier (non-seat-based plan) - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -257,7 +259,7 @@ public class PreviewOrganizationTaxCommandTests options.SubscriptionDetails.Items.Count == 1 && options.SubscriptionDetails.Items[0].Price == "2020-families-org-annually" && options.SubscriptionDetails.Items[0].Quantity == 6 && - options.Coupon == null)); + options.Discounts == null)); } [Fact] @@ -286,11 +288,11 @@ public class PreviewOrganizationTaxCommandTests var invoice = new Invoice { - Tax = 0, + TotalTaxes = [new InvoiceTotalTax { Amount = 0 }], Total = 2700 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(purchase, billingAddress); @@ -300,7 +302,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(27.00m, total); // Verify the correct Stripe API call for business use in non-US country (tax exempt reverse) - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "DE" && @@ -309,7 +311,7 @@ public class PreviewOrganizationTaxCommandTests options.SubscriptionDetails.Items.Count == 1 && options.SubscriptionDetails.Items[0].Price == "2023-teams-org-seat-monthly" && options.SubscriptionDetails.Items[0].Quantity == 3 && - options.Coupon == null)); + options.Discounts == null)); } [Fact] @@ -339,11 +341,11 @@ public class PreviewOrganizationTaxCommandTests var invoice = new Invoice { - Tax = 2100, + TotalTaxes = [new InvoiceTotalTax { Amount = 2100 }], Total = 12100 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(purchase, billingAddress); @@ -353,7 +355,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(121.00m, total); // Verify the correct Stripe API call for Spanish NIF that adds both Spanish NIF and EU VAT tax IDs - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "ES" && @@ -365,7 +367,7 @@ public class PreviewOrganizationTaxCommandTests options.SubscriptionDetails.Items.Count == 1 && options.SubscriptionDetails.Items[0].Price == "2023-enterprise-seat-monthly" && options.SubscriptionDetails.Items[0].Quantity == 15 && - options.Coupon == null)); + options.Discounts == null)); } #endregion @@ -399,11 +401,11 @@ public class PreviewOrganizationTaxCommandTests var invoice = new Invoice { - Tax = 120, + TotalTaxes = [new InvoiceTotalTax { Amount = 120 }], Total = 1320 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, planChange, billingAddress); @@ -413,7 +415,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(13.20m, total); // Verify the correct Stripe API call for free organization upgrade to Teams - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -422,7 +424,7 @@ public class PreviewOrganizationTaxCommandTests options.SubscriptionDetails.Items.Count == 1 && options.SubscriptionDetails.Items[0].Price == "2023-teams-org-seat-monthly" && options.SubscriptionDetails.Items[0].Quantity == 2 && - options.Coupon == null)); + options.Discounts == null)); } [Fact] @@ -452,11 +454,11 @@ public class PreviewOrganizationTaxCommandTests var invoice = new Invoice { - Tax = 400, + TotalTaxes = [new InvoiceTotalTax { Amount = 400 }], Total = 4400 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, planChange, billingAddress); @@ -466,7 +468,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(44.00m, total); // Verify the correct Stripe API call for free organization upgrade to Families (no SM for Families) - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "CA" && @@ -474,8 +476,158 @@ public class PreviewOrganizationTaxCommandTests options.CustomerDetails.TaxExempt == TaxExempt.None && options.SubscriptionDetails.Items.Count == 1 && options.SubscriptionDetails.Items[0].Price == "2020-families-org-annually" && - options.SubscriptionDetails.Items[0].Quantity == 2 && - options.Coupon == null)); + options.SubscriptionDetails.Items[0].Quantity == 1 && + options.Discounts == null)); + } + + [Fact] + public async Task Run_OrganizationPlanChange_FamiliesOrganizationToTeams_UsesOrganizationSeats() + { + var organization = new Organization + { + Id = Guid.NewGuid(), + PlanType = PlanType.FamiliesAnnually, + GatewayCustomerId = "cus_test123", + GatewaySubscriptionId = "sub_test123", + UseSecretsManager = false, + Seats = 6 + }; + + var planChange = new OrganizationSubscriptionPlanChange + { + Tier = ProductTierType.Teams, + Cadence = PlanCadenceType.Annually + }; + + var billingAddress = new BillingAddress + { + Country = "US", + PostalCode = "10012" + }; + + var currentPlan = new FamiliesPlan(); + var newPlan = new TeamsPlan(true); + _pricingClient.GetPlanOrThrow(organization.PlanType).Returns(currentPlan); + _pricingClient.GetPlanOrThrow(planChange.PlanType).Returns(newPlan); + + var subscriptionItems = new List + { + new() { Price = new Price { Id = "2020-families-org-annually" }, Quantity = 1 } + }; + + var subscription = new Subscription + { + Id = "sub_test123", + Items = new StripeList { Data = subscriptionItems }, + Customer = new Customer { Discount = null } + }; + + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); + + var invoice = new Invoice + { + TotalTaxes = [new InvoiceTotalTax + { + Amount = 900 + } + ], + Total = 9900 + }; + + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); + + var result = await _command.Run(organization, planChange, billingAddress); + + Assert.True(result.IsT0); + var (tax, total) = result.AsT0; + Assert.Equal(9.00m, tax); + Assert.Equal(99.00m, total); + + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => + options.AutomaticTax.Enabled == true && + options.Currency == "usd" && + options.CustomerDetails.Address.Country == "US" && + options.CustomerDetails.Address.PostalCode == "10012" && + options.CustomerDetails.TaxExempt == TaxExempt.None && + options.SubscriptionDetails.Items.Count == 1 && + options.SubscriptionDetails.Items[0].Price == "2023-teams-org-seat-annually" && + options.SubscriptionDetails.Items[0].Quantity == 6 && + options.Discounts == null)); + } + + [Fact] + public async Task Run_OrganizationPlanChange_FamiliesOrganizationToEnterprise_UsesOrganizationSeats() + { + var organization = new Organization + { + Id = Guid.NewGuid(), + PlanType = PlanType.FamiliesAnnually, + GatewayCustomerId = "cus_test123", + GatewaySubscriptionId = "sub_test123", + UseSecretsManager = false, + Seats = 6 + }; + + var planChange = new OrganizationSubscriptionPlanChange + { + Tier = ProductTierType.Enterprise, + Cadence = PlanCadenceType.Annually + }; + + var billingAddress = new BillingAddress + { + Country = "US", + PostalCode = "10012" + }; + + var currentPlan = new FamiliesPlan(); + var newPlan = new EnterprisePlan(true); + _pricingClient.GetPlanOrThrow(organization.PlanType).Returns(currentPlan); + _pricingClient.GetPlanOrThrow(planChange.PlanType).Returns(newPlan); + + var subscriptionItems = new List + { + new() { Price = new Price { Id = "2020-families-org-annually" }, Quantity = 1 } + }; + + var subscription = new Subscription + { + Id = "sub_test123", + Items = new StripeList { Data = subscriptionItems }, + Customer = new Customer { Discount = null } + }; + + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); + + var invoice = new Invoice + { + TotalTaxes = [new InvoiceTotalTax + { + Amount = 1200 + } + ], + Total = 13200 + }; + + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); + + var result = await _command.Run(organization, planChange, billingAddress); + + Assert.True(result.IsT0); + var (tax, total) = result.AsT0; + Assert.Equal(12.00m, tax); + Assert.Equal(132.00m, total); + + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => + options.AutomaticTax.Enabled == true && + options.Currency == "usd" && + options.CustomerDetails.Address.Country == "US" && + options.CustomerDetails.Address.PostalCode == "10012" && + options.CustomerDetails.TaxExempt == TaxExempt.None && + options.SubscriptionDetails.Items.Count == 1 && + options.SubscriptionDetails.Items[0].Price == "2023-enterprise-org-seat-annually" && + options.SubscriptionDetails.Items[0].Quantity == 6 && + options.Discounts == null)); } [Fact] @@ -505,11 +657,11 @@ public class PreviewOrganizationTaxCommandTests var invoice = new Invoice { - Tax = 800, + TotalTaxes = [new InvoiceTotalTax { Amount = 800 }], Total = 8800 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, planChange, billingAddress); @@ -519,7 +671,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(88.00m, total); // Verify the correct Stripe API call for free organization with SM to Enterprise - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "GB" && @@ -530,7 +682,7 @@ public class PreviewOrganizationTaxCommandTests item.Price == "2023-enterprise-org-seat-annually" && item.Quantity == 2) && options.SubscriptionDetails.Items.Any(item => item.Price == "secrets-manager-enterprise-seat-annually" && item.Quantity == 2) && - options.Coupon == null)); + options.Discounts == null)); } [Fact] @@ -578,15 +730,15 @@ public class PreviewOrganizationTaxCommandTests Customer = new Customer { Discount = null } }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { - Tax = 1500, + TotalTaxes = [new InvoiceTotalTax { Amount = 1500 }], Total = 16500 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, planChange, billingAddress); @@ -596,7 +748,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(165.00m, total); // Verify the correct Stripe API call for existing subscription upgrade - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "DE" && @@ -611,7 +763,7 @@ public class PreviewOrganizationTaxCommandTests item.Price == "secrets-manager-enterprise-seat-annually" && item.Quantity == 5) && options.SubscriptionDetails.Items.Any(item => item.Price == "secrets-manager-service-account-2024-annually" && item.Quantity == 10) && - options.Coupon == null)); + options.Discounts == null)); } [Fact] @@ -662,15 +814,15 @@ public class PreviewOrganizationTaxCommandTests } }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { - Tax = 600, + TotalTaxes = [new InvoiceTotalTax { Amount = 600 }], Total = 6600 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, planChange, billingAddress); @@ -680,7 +832,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(66.00m, total); // Verify the correct Stripe API call preserves existing discount - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -689,7 +841,9 @@ public class PreviewOrganizationTaxCommandTests options.SubscriptionDetails.Items.Count == 1 && options.SubscriptionDetails.Items[0].Price == "2023-enterprise-org-seat-annually" && options.SubscriptionDetails.Items[0].Quantity == 5 && - options.Coupon == "EXISTING_DISCOUNT_50")); + options.Discounts != null && + options.Discounts.Count == 1 && + options.Discounts[0].Coupon == "EXISTING_DISCOUNT_50")); } [Fact] @@ -722,8 +876,8 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal("Organization does not have a subscription.", badRequest.Response); // Verify no Stripe API calls were made - await _stripeAdapter.DidNotReceive().InvoiceCreatePreviewAsync(Arg.Any()); - await _stripeAdapter.DidNotReceive().SubscriptionGetAsync(Arg.Any(), Arg.Any()); + await _stripeAdapter.DidNotReceive().CreateInvoicePreviewAsync(Arg.Any()); + await _stripeAdapter.DidNotReceive().GetSubscriptionAsync(Arg.Any(), Arg.Any()); } #endregion @@ -765,15 +919,15 @@ public class PreviewOrganizationTaxCommandTests Customer = customer }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { - Tax = 600, + TotalTaxes = [new InvoiceTotalTax { Amount = 600 }], Total = 6600 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, update); @@ -783,7 +937,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(66.00m, total); // Verify the correct Stripe API call for PM seats only - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -792,7 +946,7 @@ public class PreviewOrganizationTaxCommandTests options.SubscriptionDetails.Items.Count == 1 && options.SubscriptionDetails.Items[0].Price == "2023-teams-org-seat-monthly" && options.SubscriptionDetails.Items[0].Quantity == 10 && - options.Coupon == null)); + options.Discounts == null)); } [Fact] @@ -830,15 +984,15 @@ public class PreviewOrganizationTaxCommandTests Customer = customer }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { - Tax = 1200, + TotalTaxes = [new InvoiceTotalTax { Amount = 1200 }], Total = 13200 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, update); @@ -848,7 +1002,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(132.00m, total); // Verify the correct Stripe API call for PM seats + storage - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "CA" && @@ -859,7 +1013,7 @@ public class PreviewOrganizationTaxCommandTests item.Price == "2023-enterprise-org-seat-annually" && item.Quantity == 15) && options.SubscriptionDetails.Items.Any(item => item.Price == "storage-gb-annually" && item.Quantity == 5) && - options.Coupon == null)); + options.Discounts == null)); } [Fact] @@ -897,15 +1051,15 @@ public class PreviewOrganizationTaxCommandTests Customer = customer }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { - Tax = 800, + TotalTaxes = [new InvoiceTotalTax { Amount = 800 }], Total = 8800 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, update); @@ -915,7 +1069,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(88.00m, total); // Verify the correct Stripe API call for SM seats only - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "DE" && @@ -924,7 +1078,7 @@ public class PreviewOrganizationTaxCommandTests options.SubscriptionDetails.Items.Count == 1 && options.SubscriptionDetails.Items[0].Price == "secrets-manager-teams-seat-annually" && options.SubscriptionDetails.Items[0].Quantity == 8 && - options.Coupon == null)); + options.Discounts == null)); } [Fact] @@ -956,10 +1110,7 @@ public class PreviewOrganizationTaxCommandTests Discount = null, TaxIds = new StripeList { - Data = new List - { - new() { Type = "gb_vat", Value = "GB123456789" } - } + Data = [new TaxId { Type = "gb_vat", Value = "GB123456789" }] } }; @@ -968,15 +1119,15 @@ public class PreviewOrganizationTaxCommandTests Customer = customer }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { - Tax = 1500, + TotalTaxes = [new InvoiceTotalTax { Amount = 1500 }], Total = 16500 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, update); @@ -986,7 +1137,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(165.00m, total); // Verify the correct Stripe API call for SM seats + service accounts with tax ID - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "GB" && @@ -1000,7 +1151,7 @@ public class PreviewOrganizationTaxCommandTests item.Price == "secrets-manager-enterprise-seat-monthly" && item.Quantity == 12) && options.SubscriptionDetails.Items.Any(item => item.Price == "secrets-manager-service-account-2024-monthly" && item.Quantity == 20) && - options.Coupon == null)); + options.Discounts == null)); } [Fact] @@ -1040,10 +1191,7 @@ public class PreviewOrganizationTaxCommandTests }, TaxIds = new StripeList { - Data = new List - { - new() { Type = TaxIdType.SpanishNIF, Value = "12345678Z" } - } + Data = [new TaxId { Type = TaxIdType.SpanishNIF, Value = "12345678Z" }] } }; @@ -1052,15 +1200,15 @@ public class PreviewOrganizationTaxCommandTests Customer = customer }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { - Tax = 2500, + TotalTaxes = [new InvoiceTotalTax { Amount = 2500 }], Total = 27500 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, update); @@ -1070,7 +1218,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(275.00m, total); // Verify the correct Stripe API call for comprehensive update with discount and Spanish tax ID - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "ES" && @@ -1088,7 +1236,9 @@ public class PreviewOrganizationTaxCommandTests item.Price == "secrets-manager-enterprise-seat-annually" && item.Quantity == 15) && options.SubscriptionDetails.Items.Any(item => item.Price == "secrets-manager-service-account-2024-annually" && item.Quantity == 30) && - options.Coupon == "ENTERPRISE_DISCOUNT_20")); + options.Discounts != null && + options.Discounts.Count == 1 && + options.Discounts[0].Coupon == "ENTERPRISE_DISCOUNT_20")); } [Fact] @@ -1126,15 +1276,15 @@ public class PreviewOrganizationTaxCommandTests Customer = customer }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { - Tax = 500, + TotalTaxes = [new InvoiceTotalTax { Amount = 500 }], Total = 5500 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, update); @@ -1144,7 +1294,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(55.00m, total); // Verify the correct Stripe API call for Families tier (personal usage, no business tax exemption) - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "AU" && @@ -1155,7 +1305,7 @@ public class PreviewOrganizationTaxCommandTests item.Price == "2020-families-org-annually" && item.Quantity == 6) && options.SubscriptionDetails.Items.Any(item => item.Price == "personal-storage-gb-annually" && item.Quantity == 2) && - options.Coupon == null)); + options.Discounts == null)); } [Fact] @@ -1184,8 +1334,8 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal("Organization does not have a subscription.", badRequest.Response); // Verify no Stripe API calls were made - await _stripeAdapter.DidNotReceive().InvoiceCreatePreviewAsync(Arg.Any()); - await _stripeAdapter.DidNotReceive().SubscriptionGetAsync(Arg.Any(), Arg.Any()); + await _stripeAdapter.DidNotReceive().CreateInvoicePreviewAsync(Arg.Any()); + await _stripeAdapter.DidNotReceive().GetSubscriptionAsync(Arg.Any(), Arg.Any()); } [Fact] @@ -1228,15 +1378,15 @@ public class PreviewOrganizationTaxCommandTests Customer = customer }; - _stripeAdapter.SubscriptionGetAsync("sub_test123", Arg.Any()).Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_test123", Arg.Any()).Returns(subscription); var invoice = new Invoice { - Tax = 300, + TotalTaxes = [new InvoiceTotalTax { Amount = 300 }], Total = 3300 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(organization, update); @@ -1246,7 +1396,7 @@ public class PreviewOrganizationTaxCommandTests Assert.Equal(33.00m, total); // Verify only PM seats are included (storage=0 excluded, SM seats=0 so entire SM excluded) - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -1255,7 +1405,7 @@ public class PreviewOrganizationTaxCommandTests options.SubscriptionDetails.Items.Count == 1 && options.SubscriptionDetails.Items[0].Price == "2023-teams-org-seat-monthly" && options.SubscriptionDetails.Items[0].Quantity == 5 && - options.Coupon == null)); + options.Discounts == null)); } #endregion diff --git a/test/Core.Test/Billing/Organizations/Commands/UpdateOrganizationLicenseCommandTests.cs b/test/Core.Test/Billing/Organizations/Commands/UpdateOrganizationLicenseCommandTests.cs index 8570dfc6be..4cb4caae46 100644 --- a/test/Core.Test/Billing/Organizations/Commands/UpdateOrganizationLicenseCommandTests.cs +++ b/test/Core.Test/Billing/Organizations/Commands/UpdateOrganizationLicenseCommandTests.cs @@ -88,7 +88,7 @@ public class UpdateOrganizationLicenseCommandTests "Hash", "Signature", "SignatureBytes", "InstallationId", "Expires", "ExpirationWithoutGracePeriod", "Token", "LimitCollectionCreationDeletion", "LimitCollectionCreation", "LimitCollectionDeletion", "AllowAdminAccessToAllCollectionItems", - "UseOrganizationDomains", "UseAdminSponsoredFamilies") && + "UseOrganizationDomains", "UseAdminSponsoredFamilies", "UseAutomaticUserConfirmation", "UsePhishingBlocker") && // Same property but different name, use explicit mapping org.ExpirationDate == license.Expires)); } diff --git a/test/Core.Test/Billing/Organizations/Queries/GetCloudOrganizationLicenseQueryTests.cs b/test/Core.Test/Billing/Organizations/Queries/GetCloudOrganizationLicenseQueryTests.cs index ed3698fb1d..0ceb257c88 100644 --- a/test/Core.Test/Billing/Organizations/Queries/GetCloudOrganizationLicenseQueryTests.cs +++ b/test/Core.Test/Billing/Organizations/Queries/GetCloudOrganizationLicenseQueryTests.cs @@ -8,7 +8,6 @@ using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Business; using Bit.Core.Platform.Installations; -using Bit.Core.Services; using Bit.Core.Test.AutoFixture; using Bit.Core.Test.Billing.AutoFixture; using Bit.Test.Common.AutoFixture; @@ -27,25 +26,27 @@ public class GetCloudOrganizationLicenseQueryTests { [Theory] [BitAutoData] - public async Task GetLicenseAsync_InvalidInstallationId_Throws(SutProvider sutProvider, + public async Task GetLicenseAsync_InvalidInstallationId_Throws( + SutProvider sutProvider, Organization organization, Guid installationId, int version) { sutProvider.GetDependency().GetByIdAsync(installationId).ReturnsNull(); - var exception = await Assert.ThrowsAsync( - async () => await sutProvider.Sut.GetLicenseAsync(organization, installationId, version)); + var exception = await Assert.ThrowsAsync(async () => + await sutProvider.Sut.GetLicenseAsync(organization, installationId, version)); Assert.Contains("Invalid installation id", exception.Message); } [Theory] [BitAutoData] - public async Task GetLicenseAsync_DisabledOrganization_Throws(SutProvider sutProvider, + public async Task GetLicenseAsync_DisabledOrganization_Throws( + SutProvider sutProvider, Organization organization, Guid installationId, Installation installation) { installation.Enabled = false; sutProvider.GetDependency().GetByIdAsync(installationId).Returns(installation); - var exception = await Assert.ThrowsAsync( - async () => await sutProvider.Sut.GetLicenseAsync(organization, installationId)); + var exception = await Assert.ThrowsAsync(async () => + await sutProvider.Sut.GetLicenseAsync(organization, installationId)); Assert.Contains("Invalid installation id", exception.Message); } @@ -57,7 +58,7 @@ public class GetCloudOrganizationLicenseQueryTests { installation.Enabled = true; sutProvider.GetDependency().GetByIdAsync(installationId).Returns(installation); - sutProvider.GetDependency().GetSubscriptionAsync(organization).Returns(subInfo); + sutProvider.GetDependency().GetSubscriptionAsync(organization).Returns(subInfo); sutProvider.GetDependency().SignLicense(Arg.Any()).Returns(licenseSignature); var result = await sutProvider.Sut.GetLicenseAsync(organization, installationId); @@ -71,13 +72,14 @@ public class GetCloudOrganizationLicenseQueryTests [Theory] [BitAutoData] - public async Task GetLicenseAsync_WhenFeatureFlagEnabled_CreatesToken(SutProvider sutProvider, + public async Task GetLicenseAsync_WhenFeatureFlagEnabled_CreatesToken( + SutProvider sutProvider, Organization organization, Guid installationId, Installation installation, SubscriptionInfo subInfo, byte[] licenseSignature, string token) { installation.Enabled = true; sutProvider.GetDependency().GetByIdAsync(installationId).Returns(installation); - sutProvider.GetDependency().GetSubscriptionAsync(organization).Returns(subInfo); + sutProvider.GetDependency().GetSubscriptionAsync(organization).Returns(subInfo); sutProvider.GetDependency().SignLicense(Arg.Any()).Returns(licenseSignature); sutProvider.GetDependency() .CreateOrganizationTokenAsync(organization, installationId, subInfo) @@ -90,7 +92,8 @@ public class GetCloudOrganizationLicenseQueryTests [Theory] [BitAutoData] - public async Task GetLicenseAsync_MSPManagedOrganization_UsesProviderSubscription(SutProvider sutProvider, + public async Task GetLicenseAsync_MSPManagedOrganization_UsesProviderSubscription( + SutProvider sutProvider, Organization organization, Guid installationId, Installation installation, SubscriptionInfo subInfo, byte[] licenseSignature, Provider provider) { @@ -99,14 +102,23 @@ public class GetCloudOrganizationLicenseQueryTests subInfo.Subscription = new SubscriptionInfo.BillingSubscription(new Subscription { - CurrentPeriodStart = DateTime.UtcNow, - CurrentPeriodEnd = DateTime.UtcNow.AddMonths(1) + Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodStart = DateTime.UtcNow, + CurrentPeriodEnd = DateTime.UtcNow.AddMonths(1) + } + ] + } }); installation.Enabled = true; sutProvider.GetDependency().GetByIdAsync(installationId).Returns(installation); sutProvider.GetDependency().GetByOrganizationIdAsync(organization.Id).Returns(provider); - sutProvider.GetDependency().GetSubscriptionAsync(provider).Returns(subInfo); + sutProvider.GetDependency().GetSubscriptionAsync(provider).Returns(subInfo); sutProvider.GetDependency().SignLicense(Arg.Any()).Returns(licenseSignature); var result = await sutProvider.Sut.GetLicenseAsync(organization, installationId); diff --git a/test/Core.Test/Billing/Organizations/Queries/GetOrganizationMetadataQueryTests.cs b/test/Core.Test/Billing/Organizations/Queries/GetOrganizationMetadataQueryTests.cs new file mode 100644 index 0000000000..e4cb0b0109 --- /dev/null +++ b/test/Core.Test/Billing/Organizations/Queries/GetOrganizationMetadataQueryTests.cs @@ -0,0 +1,360 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Organizations.Models; +using Bit.Core.Billing.Organizations.Queries; +using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; +using Bit.Core.Repositories; +using Bit.Core.Settings; +using Bit.Core.Test.Billing.Mocks; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using NSubstitute.ReturnsExtensions; +using Stripe; +using Xunit; + +namespace Bit.Core.Test.Billing.Organizations.Queries; + +[SutProviderCustomize] +public class GetOrganizationMetadataQueryTests +{ + [Theory, BitAutoData] + public async Task Run_SelfHosted_ReturnsDefault( + Organization organization, + SutProvider sutProvider) + { + sutProvider.GetDependency().SelfHosted.Returns(true); + + var result = await sutProvider.Sut.Run(organization); + + Assert.Equal(OrganizationMetadata.Default, result); + } + + [Theory, BitAutoData] + public async Task Run_NoGatewaySubscriptionId_ReturnsDefaultWithOccupiedSeats( + Organization organization, + SutProvider sutProvider) + { + organization.GatewaySubscriptionId = null; + + sutProvider.GetDependency().SelfHosted.Returns(false); + sutProvider.GetDependency() + .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id) + .Returns(new OrganizationSeatCounts { Users = 10, Sponsored = 0 }); + + var result = await sutProvider.Sut.Run(organization); + + Assert.NotNull(result); + Assert.False(result.IsOnSecretsManagerStandalone); + Assert.Equal(10, result.OrganizationOccupiedSeats); + } + + [Theory, BitAutoData] + public async Task Run_NullCustomer_ReturnsDefaultWithOccupiedSeats( + Organization organization, + SutProvider sutProvider) + { + organization.GatewaySubscriptionId = "sub_123"; + + sutProvider.GetDependency().SelfHosted.Returns(false); + sutProvider.GetDependency() + .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id) + .Returns(new OrganizationSeatCounts { Users = 5, Sponsored = 0 }); + + sutProvider.GetDependency() + .GetCustomer(organization) + .ReturnsNull(); + + var result = await sutProvider.Sut.Run(organization); + + Assert.NotNull(result); + Assert.False(result.IsOnSecretsManagerStandalone); + Assert.Equal(5, result.OrganizationOccupiedSeats); + } + + [Theory, BitAutoData] + public async Task Run_NullSubscription_ReturnsDefaultWithOccupiedSeats( + Organization organization, + SutProvider sutProvider) + { + organization.GatewaySubscriptionId = "sub_123"; + + var customer = new Customer(); + + sutProvider.GetDependency().SelfHosted.Returns(false); + sutProvider.GetDependency() + .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id) + .Returns(new OrganizationSeatCounts { Users = 7, Sponsored = 0 }); + + sutProvider.GetDependency() + .GetCustomer(organization) + .Returns(customer); + + sutProvider.GetDependency() + .GetSubscription(organization, Arg.Is(options => + options.Expand.Contains("discounts.coupon.applies_to"))) + .ReturnsNull(); + + var result = await sutProvider.Sut.Run(organization); + + Assert.NotNull(result); + Assert.False(result.IsOnSecretsManagerStandalone); + Assert.Equal(7, result.OrganizationOccupiedSeats); + } + + [Theory, BitAutoData] + public async Task Run_WithSecretsManagerStandaloneCoupon_ReturnsMetadataWithFlag( + Organization organization, + SutProvider sutProvider) + { + organization.GatewaySubscriptionId = "sub_123"; + organization.PlanType = PlanType.EnterpriseAnnually; + + var productId = "product_123"; + var customer = new Customer(); + + var subscription = new Subscription + { + Discounts = + [ + new Discount + { + Coupon = new Coupon + { + Id = StripeConstants.CouponIDs.SecretsManagerStandalone, + AppliesTo = new CouponAppliesTo + { + Products = [productId] + } + } + } + ], + Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + Plan = new Plan + { + ProductId = productId + } + } + ] + } + }; + + sutProvider.GetDependency().SelfHosted.Returns(false); + sutProvider.GetDependency() + .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id) + .Returns(new OrganizationSeatCounts { Users = 15, Sponsored = 0 }); + + sutProvider.GetDependency() + .GetCustomer(organization) + .Returns(customer); + + sutProvider.GetDependency() + .GetSubscription(organization, Arg.Is(options => + options.Expand.Contains("discounts.coupon.applies_to"))) + .Returns(subscription); + + sutProvider.GetDependency() + .GetPlanOrThrow(organization.PlanType) + .Returns(MockPlans.Get(organization.PlanType)); + + var result = await sutProvider.Sut.Run(organization); + + Assert.NotNull(result); + Assert.True(result.IsOnSecretsManagerStandalone); + Assert.Equal(15, result.OrganizationOccupiedSeats); + } + + [Theory, BitAutoData] + public async Task Run_WithoutSecretsManagerStandaloneCoupon_ReturnsMetadataWithoutFlag( + Organization organization, + SutProvider sutProvider) + { + organization.GatewaySubscriptionId = "sub_123"; + organization.PlanType = PlanType.TeamsAnnually; + + var customer = new Customer(); + + var subscription = new Subscription + { + Discounts = null, + Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + Plan = new Plan + { + ProductId = "product_123" + } + } + ] + } + }; + + sutProvider.GetDependency().SelfHosted.Returns(false); + sutProvider.GetDependency() + .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id) + .Returns(new OrganizationSeatCounts { Users = 20, Sponsored = 0 }); + + sutProvider.GetDependency() + .GetCustomer(organization) + .Returns(customer); + + sutProvider.GetDependency() + .GetSubscription(organization, Arg.Is(options => + options.Expand.Contains("discounts.coupon.applies_to"))) + .Returns(subscription); + + sutProvider.GetDependency() + .GetPlanOrThrow(organization.PlanType) + .Returns(MockPlans.Get(organization.PlanType)); + + var result = await sutProvider.Sut.Run(organization); + + Assert.NotNull(result); + Assert.False(result.IsOnSecretsManagerStandalone); + Assert.Equal(20, result.OrganizationOccupiedSeats); + } + + [Theory, BitAutoData] + public async Task Run_CouponDoesNotApplyToSubscriptionProducts_ReturnsFalseForStandaloneFlag( + Organization organization, + SutProvider sutProvider) + { + organization.GatewaySubscriptionId = "sub_123"; + organization.PlanType = PlanType.EnterpriseAnnually; + + var customer = new Customer(); + + var subscription = new Subscription + { + Discounts = + [ + new Discount + { + Coupon = new Coupon + { + Id = StripeConstants.CouponIDs.SecretsManagerStandalone, + AppliesTo = new CouponAppliesTo + { + Products = ["different_product_id"] + } + } + } + ], + Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + Plan = new Plan + { + ProductId = "product_123" + } + } + ] + } + }; + + sutProvider.GetDependency().SelfHosted.Returns(false); + sutProvider.GetDependency() + .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id) + .Returns(new OrganizationSeatCounts { Users = 12, Sponsored = 0 }); + + sutProvider.GetDependency() + .GetCustomer(organization) + .Returns(customer); + + sutProvider.GetDependency() + .GetSubscription(organization, Arg.Is(options => + options.Expand.Contains("discounts.coupon.applies_to"))) + .Returns(subscription); + + sutProvider.GetDependency() + .GetPlanOrThrow(organization.PlanType) + .Returns(MockPlans.Get(organization.PlanType)); + + var result = await sutProvider.Sut.Run(organization); + + Assert.NotNull(result); + Assert.False(result.IsOnSecretsManagerStandalone); + Assert.Equal(12, result.OrganizationOccupiedSeats); + } + + [Theory, BitAutoData] + public async Task Run_PlanDoesNotSupportSecretsManager_ReturnsFalseForStandaloneFlag( + Organization organization, + SutProvider sutProvider) + { + organization.GatewaySubscriptionId = "sub_123"; + organization.PlanType = PlanType.FamiliesAnnually; + + var productId = "product_123"; + var customer = new Customer(); + + var subscription = new Subscription + { + Discounts = + [ + new Discount + { + Coupon = new Coupon + { + Id = StripeConstants.CouponIDs.SecretsManagerStandalone, + AppliesTo = new CouponAppliesTo + { + Products = [productId] + } + } + } + ], + Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + Plan = new Plan + { + ProductId = productId + } + } + ] + } + }; + + sutProvider.GetDependency().SelfHosted.Returns(false); + sutProvider.GetDependency() + .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id) + .Returns(new OrganizationSeatCounts { Users = 8, Sponsored = 0 }); + + sutProvider.GetDependency() + .GetCustomer(organization) + .Returns(customer); + + sutProvider.GetDependency() + .GetSubscription(organization, Arg.Is(options => + options.Expand.Contains("discounts.coupon.applies_to"))) + .Returns(subscription); + + sutProvider.GetDependency() + .GetPlanOrThrow(organization.PlanType) + .Returns(MockPlans.Get(organization.PlanType)); + + var result = await sutProvider.Sut.Run(organization); + + Assert.NotNull(result); + Assert.False(result.IsOnSecretsManagerStandalone); + Assert.Equal(8, result.OrganizationOccupiedSeats); + } +} diff --git a/test/Core.Test/Billing/Organizations/Queries/GetOrganizationWarningsQueryTests.cs b/test/Core.Test/Billing/Organizations/Queries/GetOrganizationWarningsQueryTests.cs index 5234d500d1..a7284410fe 100644 --- a/test/Core.Test/Billing/Organizations/Queries/GetOrganizationWarningsQueryTests.cs +++ b/test/Core.Test/Billing/Organizations/Queries/GetOrganizationWarningsQueryTests.cs @@ -2,13 +2,12 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Organizations.Queries; +using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Services; using Bit.Core.Context; -using Bit.Core.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -75,7 +74,7 @@ public class GetOrganizationWarningsQueryTests }); sutProvider.GetDependency().EditSubscription(organization.Id).Returns(true); - sutProvider.GetDependency().GetSetupIntentIdForSubscriber(organization.Id).Returns((string?)null); + sutProvider.GetDependency().Run(organization).Returns(false); var response = await sutProvider.Sut.Run(organization); @@ -86,12 +85,11 @@ public class GetOrganizationWarningsQueryTests } [Theory, BitAutoData] - public async Task Run_Has_FreeTrialWarning_WithUnverifiedBankAccount_NoWarning( + public async Task Run_Has_FreeTrialWarning_WithPaymentMethod_NoWarning( Organization organization, SutProvider sutProvider) { var now = DateTime.UtcNow; - const string setupIntentId = "setup_intent_id"; sutProvider.GetDependency() .GetSubscription(organization, Arg.Is(options => @@ -113,20 +111,7 @@ public class GetOrganizationWarningsQueryTests }); sutProvider.GetDependency().EditSubscription(organization.Id).Returns(true); - sutProvider.GetDependency().GetSetupIntentIdForSubscriber(organization.Id).Returns(setupIntentId); - sutProvider.GetDependency().SetupIntentGet(setupIntentId, Arg.Is( - options => options.Expand.Contains("payment_method"))).Returns(new SetupIntent - { - Status = "requires_action", - NextAction = new SetupIntentNextAction - { - VerifyWithMicrodeposits = new SetupIntentNextActionVerifyWithMicrodeposits() - }, - PaymentMethod = new PaymentMethod - { - UsBankAccount = new PaymentMethodUsBankAccount() - } - }); + sutProvider.GetDependency().Run(organization).Returns(true); var response = await sutProvider.Sut.Run(organization); @@ -286,7 +271,16 @@ public class GetOrganizationWarningsQueryTests CollectionMethod = CollectionMethod.SendInvoice, Customer = new Customer(), Status = SubscriptionStatus.Active, - CurrentPeriodEnd = now.AddDays(10), + Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = now.AddDays(10) + } + ] + }, TestClock = new TestClock { FrozenTime = now @@ -387,7 +381,7 @@ public class GetOrganizationWarningsQueryTests var dueDate = now.AddDays(-10); - sutProvider.GetDependency().InvoiceSearchAsync(Arg.Is(options => + sutProvider.GetDependency().SearchInvoiceAsync(Arg.Is(options => options.Query == $"subscription:'{subscriptionId}' status:'open'")).Returns([ new Invoice { DueDate = dueDate, Created = dueDate.AddDays(-30) } ]); @@ -547,7 +541,7 @@ public class GetOrganizationWarningsQueryTests .Returns(true); sutProvider.GetDependency() - .TaxRegistrationsListAsync(Arg.Any()) + .ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = new List @@ -588,7 +582,7 @@ public class GetOrganizationWarningsQueryTests .Returns(true); sutProvider.GetDependency() - .TaxRegistrationsListAsync(Arg.Any()) + .ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = new List @@ -640,7 +634,7 @@ public class GetOrganizationWarningsQueryTests .Returns(true); sutProvider.GetDependency() - .TaxRegistrationsListAsync(Arg.Any()) + .ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = new List @@ -692,7 +686,7 @@ public class GetOrganizationWarningsQueryTests .Returns(true); sutProvider.GetDependency() - .TaxRegistrationsListAsync(Arg.Any()) + .ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = new List @@ -744,7 +738,7 @@ public class GetOrganizationWarningsQueryTests .Returns(true); sutProvider.GetDependency() - .TaxRegistrationsListAsync(Arg.Any()) + .ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = new List @@ -790,7 +784,7 @@ public class GetOrganizationWarningsQueryTests .Returns(true); sutProvider.GetDependency() - .TaxRegistrationsListAsync(Arg.Any()) + .ListTaxRegistrationsAsync(Arg.Any()) .Returns(new StripeList { Data = new List diff --git a/test/Core.Test/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommandTests.cs b/test/Core.Test/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommandTests.cs index 800c3ec3ae..c933306399 100644 --- a/test/Core.Test/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommandTests.cs +++ b/test/Core.Test/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommandTests.cs @@ -1,5 +1,6 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.Billing.Constants; using Bit.Core.Billing.Payment.Clients; using Bit.Core.Billing.Payment.Commands; using Bit.Core.Entities; @@ -11,12 +12,18 @@ using Invoice = BitPayLight.Models.Invoice.Invoice; namespace Bit.Core.Test.Billing.Payment.Commands; +using static BitPayConstants; + public class CreateBitPayInvoiceForCreditCommandTests { private readonly IBitPayClient _bitPayClient = Substitute.For(); private readonly GlobalSettings _globalSettings = new() { - BitPay = new GlobalSettings.BitPaySettings { NotificationUrl = "https://example.com/bitpay/notification" } + BitPay = new GlobalSettings.BitPaySettings + { + NotificationUrl = "https://example.com/bitpay/notification", + WebhookKey = "test-webhook-key" + } }; private const string _redirectUrl = "https://bitwarden.com/redirect"; private readonly CreateBitPayInvoiceForCreditCommand _command; @@ -37,8 +44,8 @@ public class CreateBitPayInvoiceForCreditCommandTests _bitPayClient.CreateInvoice(Arg.Is(options => options.Buyer.Email == user.Email && options.Buyer.Name == user.Email && - options.NotificationUrl == _globalSettings.BitPay.NotificationUrl && - options.PosData == $"userId:{user.Id},accountCredit:1" && + options.NotificationUrl == $"{_globalSettings.BitPay.NotificationUrl}?key={_globalSettings.BitPay.WebhookKey}" && + options.PosData == $"userId:{user.Id},{PosDataKeys.AccountCredit}" && // ReSharper disable once CompareOfFloatsByEqualityOperator options.Price == Convert.ToDouble(10M) && options.RedirectUrl == _redirectUrl)).Returns(new Invoice { Url = "https://bitpay.com/invoice/123" }); @@ -58,8 +65,8 @@ public class CreateBitPayInvoiceForCreditCommandTests _bitPayClient.CreateInvoice(Arg.Is(options => options.Buyer.Email == organization.BillingEmail && options.Buyer.Name == organization.Name && - options.NotificationUrl == _globalSettings.BitPay.NotificationUrl && - options.PosData == $"organizationId:{organization.Id},accountCredit:1" && + options.NotificationUrl == $"{_globalSettings.BitPay.NotificationUrl}?key={_globalSettings.BitPay.WebhookKey}" && + options.PosData == $"organizationId:{organization.Id},{PosDataKeys.AccountCredit}" && // ReSharper disable once CompareOfFloatsByEqualityOperator options.Price == Convert.ToDouble(10M) && options.RedirectUrl == _redirectUrl)).Returns(new Invoice { Url = "https://bitpay.com/invoice/123" }); @@ -79,8 +86,8 @@ public class CreateBitPayInvoiceForCreditCommandTests _bitPayClient.CreateInvoice(Arg.Is(options => options.Buyer.Email == provider.BillingEmail && options.Buyer.Name == provider.Name && - options.NotificationUrl == _globalSettings.BitPay.NotificationUrl && - options.PosData == $"providerId:{provider.Id},accountCredit:1" && + options.NotificationUrl == $"{_globalSettings.BitPay.NotificationUrl}?key={_globalSettings.BitPay.WebhookKey}" && + options.PosData == $"providerId:{provider.Id},{PosDataKeys.AccountCredit}" && // ReSharper disable once CompareOfFloatsByEqualityOperator options.Price == Convert.ToDouble(10M) && options.RedirectUrl == _redirectUrl)).Returns(new Invoice { Url = "https://bitpay.com/invoice/123" }); diff --git a/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs b/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs index c42049d5bb..5854d1c3b5 100644 --- a/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs +++ b/test/Core.Test/Billing/Payment/Commands/UpdateBillingAddressCommandTests.cs @@ -4,7 +4,6 @@ using Bit.Core.Billing.Enums; using Bit.Core.Billing.Payment.Commands; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Services; -using Bit.Core.Services; using Bit.Core.Test.Billing.Extensions; using Microsoft.Extensions.Logging; using NSubstitute; @@ -73,7 +72,7 @@ public class UpdateBillingAddressCommandTests } }; - _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Address.Matches(input) && options.HasExpansions("subscriptions") )).Returns(customer); @@ -84,7 +83,7 @@ public class UpdateBillingAddressCommandTests var output = result.AsT0; Assert.Equivalent(input, output); - await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + await _stripeAdapter.Received(1).UpdateSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Is(options => options.AutomaticTax.Enabled == true)); } @@ -131,7 +130,7 @@ public class UpdateBillingAddressCommandTests } }; - _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Address.Matches(input) && options.HasExpansions("subscriptions") )).Returns(customer); @@ -144,7 +143,7 @@ public class UpdateBillingAddressCommandTests await _subscriberService.Received(1).CreateStripeCustomer(organization); - await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + await _stripeAdapter.Received(1).UpdateSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Is(options => options.AutomaticTax.Enabled == true)); } @@ -192,7 +191,7 @@ public class UpdateBillingAddressCommandTests } }; - _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Address.Matches(input) && options.HasExpansions("subscriptions", "tax_ids") && options.TaxExempt == TaxExempt.None @@ -204,7 +203,7 @@ public class UpdateBillingAddressCommandTests var output = result.AsT0; Assert.Equivalent(input, output); - await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + await _stripeAdapter.Received(1).UpdateSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Is(options => options.AutomaticTax.Enabled == true)); } @@ -260,7 +259,7 @@ public class UpdateBillingAddressCommandTests } }; - _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Address.Matches(input) && options.HasExpansions("subscriptions", "tax_ids") && options.TaxExempt == TaxExempt.None @@ -272,10 +271,10 @@ public class UpdateBillingAddressCommandTests var output = result.AsT0; Assert.Equivalent(input, output); - await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + await _stripeAdapter.Received(1).UpdateSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Is(options => options.AutomaticTax.Enabled == true)); - await _stripeAdapter.Received(1).TaxIdDeleteAsync(customer.Id, "tax_id_123"); + await _stripeAdapter.Received(1).DeleteTaxIdAsync(customer.Id, "tax_id_123"); } [Fact] @@ -322,7 +321,7 @@ public class UpdateBillingAddressCommandTests } }; - _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Address.Matches(input) && options.HasExpansions("subscriptions", "tax_ids") && options.TaxExempt == TaxExempt.Reverse @@ -334,7 +333,7 @@ public class UpdateBillingAddressCommandTests var output = result.AsT0; Assert.Equivalent(input, output); - await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + await _stripeAdapter.Received(1).UpdateSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Is(options => options.AutomaticTax.Enabled == true)); } @@ -384,14 +383,14 @@ public class UpdateBillingAddressCommandTests } }; - _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, Arg.Is(options => + _stripeAdapter.UpdateCustomerAsync(organization.GatewayCustomerId, Arg.Is(options => options.Address.Matches(input) && options.HasExpansions("subscriptions", "tax_ids") && options.TaxExempt == TaxExempt.Reverse )).Returns(customer); _stripeAdapter - .TaxIdCreateAsync(customer.Id, + .CreateTaxIdAsync(customer.Id, Arg.Is(options => options.Type == TaxIdType.EUVAT)) .Returns(new TaxId { Type = TaxIdType.EUVAT, Value = "ESA12345678" }); @@ -401,10 +400,10 @@ public class UpdateBillingAddressCommandTests var output = result.AsT0; Assert.Equivalent(input with { TaxId = new TaxID(TaxIdType.EUVAT, "ESA12345678") }, output); - await _stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, + await _stripeAdapter.Received(1).UpdateSubscriptionAsync(organization.GatewaySubscriptionId, Arg.Is(options => options.AutomaticTax.Enabled == true)); - await _stripeAdapter.Received(1).TaxIdCreateAsync(organization.GatewayCustomerId, Arg.Is( + await _stripeAdapter.Received(1).CreateTaxIdAsync(organization.GatewayCustomerId, Arg.Is( options => options.Type == TaxIdType.SpanishNIF && options.Value == input.TaxId.Value)); } diff --git a/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs b/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs index 72280c4c77..da42127f33 100644 --- a/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs +++ b/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs @@ -4,7 +4,6 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Payment.Commands; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Services; -using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Test.Billing.Extensions; using Braintree; @@ -82,7 +81,7 @@ public class UpdatePaymentMethodCommandTests Status = "requires_action" }; - _stripeAdapter.SetupIntentList(Arg.Is(options => + _stripeAdapter.ListSetupIntentsAsync(Arg.Is(options => options.PaymentMethod == token && options.HasExpansions("data.payment_method"))).Returns([setupIntent]); var result = await _command.Run(organization, @@ -144,7 +143,7 @@ public class UpdatePaymentMethodCommandTests Status = "requires_action" }; - _stripeAdapter.SetupIntentList(Arg.Is(options => + _stripeAdapter.ListSetupIntentsAsync(Arg.Is(options => options.PaymentMethod == token && options.HasExpansions("data.payment_method"))).Returns([setupIntent]); var result = await _command.Run(organization, @@ -213,7 +212,7 @@ public class UpdatePaymentMethodCommandTests Status = "requires_action" }; - _stripeAdapter.SetupIntentList(Arg.Is(options => + _stripeAdapter.ListSetupIntentsAsync(Arg.Is(options => options.PaymentMethod == token && options.HasExpansions("data.payment_method"))).Returns([setupIntent]); var result = await _command.Run(organization, @@ -232,7 +231,7 @@ public class UpdatePaymentMethodCommandTests Assert.Equal("https://example.com", maskedBankAccount.HostedVerificationUrl); await _setupIntentCache.Received(1).Set(organization.Id, setupIntent.Id); - await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id, Arg.Is(options => + await _stripeAdapter.Received(1).UpdateCustomerAsync(customer.Id, Arg.Is(options => options.Metadata[MetadataKeys.BraintreeCustomerId] == string.Empty && options.Metadata[MetadataKeys.RetiredBraintreeCustomerId] == "braintree_customer_id")); } @@ -262,7 +261,7 @@ public class UpdatePaymentMethodCommandTests const string token = "TOKEN"; _stripeAdapter - .PaymentMethodAttachAsync(token, + .AttachPaymentMethodAsync(token, Arg.Is(options => options.Customer == customer.Id)) .Returns(new PaymentMethod { @@ -291,7 +290,7 @@ public class UpdatePaymentMethodCommandTests Assert.Equal("9999", maskedCard.Last4); Assert.Equal("01/2028", maskedCard.Expiration); - await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id, + await _stripeAdapter.Received(1).UpdateCustomerAsync(customer.Id, Arg.Is(options => options.InvoiceSettings.DefaultPaymentMethod == token)); } @@ -315,7 +314,7 @@ public class UpdatePaymentMethodCommandTests const string token = "TOKEN"; _stripeAdapter - .PaymentMethodAttachAsync(token, + .AttachPaymentMethodAsync(token, Arg.Is(options => options.Customer == customer.Id)) .Returns(new PaymentMethod { @@ -344,10 +343,10 @@ public class UpdatePaymentMethodCommandTests Assert.Equal("9999", maskedCard.Last4); Assert.Equal("01/2028", maskedCard.Expiration); - await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id, + await _stripeAdapter.Received(1).UpdateCustomerAsync(customer.Id, Arg.Is(options => options.InvoiceSettings.DefaultPaymentMethod == token)); - await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id, + await _stripeAdapter.Received(1).UpdateCustomerAsync(customer.Id, Arg.Is(options => options.Address.Country == "US" && options.Address.PostalCode == "12345")); } @@ -468,7 +467,7 @@ public class UpdatePaymentMethodCommandTests var maskedPayPalAccount = maskedPaymentMethod.AsT2; Assert.Equal("user@gmail.com", maskedPayPalAccount.Email); - await _stripeAdapter.Received(1).CustomerUpdateAsync(customer.Id, + await _stripeAdapter.Received(1).UpdateCustomerAsync(customer.Id, Arg.Is(options => options.Metadata[MetadataKeys.BraintreeCustomerId] == "braintree_customer_id")); } diff --git a/test/Core.Test/Billing/Payment/Models/PaymentMethodTests.cs b/test/Core.Test/Billing/Payment/Models/PaymentMethodTests.cs new file mode 100644 index 0000000000..e3953cd152 --- /dev/null +++ b/test/Core.Test/Billing/Payment/Models/PaymentMethodTests.cs @@ -0,0 +1,112 @@ +using System.Text.Json; +using Bit.Core.Billing.Payment.Models; +using Xunit; + +namespace Bit.Core.Test.Billing.Payment.Models; + +public class PaymentMethodTests +{ + [Theory] + [InlineData("{\"cardNumber\":\"1234\"}")] + [InlineData("{\"type\":\"unknown_type\",\"data\":\"value\"}")] + [InlineData("{\"type\":\"invalid\",\"token\":\"test-token\"}")] + [InlineData("{\"type\":\"invalid\"}")] + public void Read_ShouldThrowJsonException_OnInvalidOrMissingType(string json) + { + // Arrange + var options = new JsonSerializerOptions { Converters = { new PaymentMethodJsonConverter() } }; + + // Act & Assert + Assert.Throws(() => JsonSerializer.Deserialize(json, options)); + } + + [Theory] + [InlineData("{\"type\":\"card\"}")] + [InlineData("{\"type\":\"card\",\"token\":\"\"}")] + [InlineData("{\"type\":\"card\",\"token\":null}")] + public void Read_ShouldThrowJsonException_OnInvalidTokenizedPaymentMethodToken(string json) + { + // Arrange + var options = new JsonSerializerOptions { Converters = { new PaymentMethodJsonConverter() } }; + + // Act & Assert + Assert.Throws(() => JsonSerializer.Deserialize(json, options)); + } + + // Tokenized payment method deserialization + [Theory] + [InlineData("bankAccount", TokenizablePaymentMethodType.BankAccount)] + [InlineData("card", TokenizablePaymentMethodType.Card)] + [InlineData("payPal", TokenizablePaymentMethodType.PayPal)] + public void Read_ShouldDeserializeTokenizedPaymentMethods(string typeString, TokenizablePaymentMethodType expectedType) + { + // Arrange + var json = $"{{\"type\":\"{typeString}\",\"token\":\"test-token\"}}"; + var options = new JsonSerializerOptions { Converters = { new PaymentMethodJsonConverter() } }; + + // Act + var result = JsonSerializer.Deserialize(json, options); + + // Assert + Assert.True(result.IsTokenized); + Assert.Equal(expectedType, result.AsT0.Type); + Assert.Equal("test-token", result.AsT0.Token); + } + + // Non-tokenized payment method deserialization + [Theory] + [InlineData("accountcredit", NonTokenizablePaymentMethodType.AccountCredit)] + public void Read_ShouldDeserializeNonTokenizedPaymentMethods(string typeString, NonTokenizablePaymentMethodType expectedType) + { + // Arrange + var json = $"{{\"type\":\"{typeString}\"}}"; + var options = new JsonSerializerOptions { Converters = { new PaymentMethodJsonConverter() } }; + + // Act + var result = JsonSerializer.Deserialize(json, options); + + // Assert + Assert.True(result.IsNonTokenized); + Assert.Equal(expectedType, result.AsT1.Type); + } + + // Tokenized payment method serialization + [Theory] + [InlineData(TokenizablePaymentMethodType.BankAccount, "bankaccount")] + [InlineData(TokenizablePaymentMethodType.Card, "card")] + [InlineData(TokenizablePaymentMethodType.PayPal, "paypal")] + public void Write_ShouldSerializeTokenizedPaymentMethods(TokenizablePaymentMethodType type, string expectedTypeString) + { + // Arrange + var paymentMethod = new PaymentMethod(new TokenizedPaymentMethod + { + Type = type, + Token = "test-token" + }); + var options = new JsonSerializerOptions { Converters = { new PaymentMethodJsonConverter() } }; + + // Act + var json = JsonSerializer.Serialize(paymentMethod, options); + + // Assert + Assert.Contains($"\"type\":\"{expectedTypeString}\"", json); + Assert.Contains("\"token\":\"test-token\"", json); + } + + // Non-tokenized payment method serialization + [Theory] + [InlineData(NonTokenizablePaymentMethodType.AccountCredit, "accountcredit")] + public void Write_ShouldSerializeNonTokenizedPaymentMethods(NonTokenizablePaymentMethodType type, string expectedTypeString) + { + // Arrange + var paymentMethod = new PaymentMethod(new NonTokenizedPaymentMethod { Type = type }); + var options = new JsonSerializerOptions { Converters = { new PaymentMethodJsonConverter() } }; + + // Act + var json = JsonSerializer.Serialize(paymentMethod, options); + + // Assert + Assert.Contains($"\"type\":\"{expectedTypeString}\"", json); + Assert.DoesNotContain("token", json); + } +} diff --git a/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs b/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs index b6b0d596b3..4e4c5199e2 100644 --- a/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs +++ b/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs @@ -3,7 +3,6 @@ using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Services; -using Bit.Core.Services; using Bit.Core.Test.Billing.Extensions; using Braintree; using Microsoft.Extensions.Logging; @@ -166,7 +165,7 @@ public class GetPaymentMethodQueryTests _setupIntentCache.GetSetupIntentIdForSubscriber(organization.Id).Returns("seti_123"); _stripeAdapter - .SetupIntentGet("seti_123", + .GetSetupIntentAsync("seti_123", Arg.Is(options => options.HasExpansions("payment_method"))).Returns( new SetupIntent { diff --git a/test/Core.Test/Billing/Payment/Queries/HasPaymentMethodQueryTests.cs b/test/Core.Test/Billing/Payment/Queries/HasPaymentMethodQueryTests.cs new file mode 100644 index 0000000000..9ade4d0979 --- /dev/null +++ b/test/Core.Test/Billing/Payment/Queries/HasPaymentMethodQueryTests.cs @@ -0,0 +1,263 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Caches; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Payment.Queries; +using Bit.Core.Billing.Services; +using Bit.Core.Test.Billing.Extensions; +using NSubstitute; +using NSubstitute.ReturnsExtensions; +using Stripe; +using Xunit; + +namespace Bit.Core.Test.Billing.Payment.Queries; + +using static StripeConstants; + +public class HasPaymentMethodQueryTests +{ + private readonly ISetupIntentCache _setupIntentCache = Substitute.For(); + private readonly IStripeAdapter _stripeAdapter = Substitute.For(); + private readonly ISubscriberService _subscriberService = Substitute.For(); + private readonly HasPaymentMethodQuery _query; + + public HasPaymentMethodQueryTests() + { + _query = new HasPaymentMethodQuery( + _setupIntentCache, + _stripeAdapter, + _subscriberService); + } + + [Fact] + public async Task Run_NoCustomer_ReturnsFalse() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + _subscriberService.GetCustomer(organization).ReturnsNull(); + _setupIntentCache.GetSetupIntentIdForSubscriber(organization.Id).Returns((string)null); + + var hasPaymentMethod = await _query.Run(organization); + + Assert.False(hasPaymentMethod); + } + + [Fact] + public async Task Run_NoCustomer_WithUnverifiedBankAccount_ReturnsTrue() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + _subscriberService.GetCustomer(organization).ReturnsNull(); + _setupIntentCache.GetSetupIntentIdForSubscriber(organization.Id).Returns("seti_123"); + + _stripeAdapter + .GetSetupIntentAsync("seti_123", + Arg.Is(options => options.HasExpansions("payment_method"))) + .Returns(new SetupIntent + { + Status = "requires_action", + NextAction = new SetupIntentNextAction + { + VerifyWithMicrodeposits = new SetupIntentNextActionVerifyWithMicrodeposits() + }, + PaymentMethod = new PaymentMethod + { + UsBankAccount = new PaymentMethodUsBankAccount() + } + }); + + var hasPaymentMethod = await _query.Run(organization); + + Assert.True(hasPaymentMethod); + } + + [Fact] + public async Task Run_NoPaymentMethod_ReturnsFalse() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + var customer = new Customer + { + InvoiceSettings = new CustomerInvoiceSettings(), + Metadata = new Dictionary() + }; + + _subscriberService.GetCustomer(organization).Returns(customer); + + var hasPaymentMethod = await _query.Run(organization); + + Assert.False(hasPaymentMethod); + } + + [Fact] + public async Task Run_HasDefaultPaymentMethodId_ReturnsTrue() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + var customer = new Customer + { + InvoiceSettings = new CustomerInvoiceSettings + { + DefaultPaymentMethodId = "pm_123" + }, + Metadata = new Dictionary() + }; + + _subscriberService.GetCustomer(organization).Returns(customer); + + var hasPaymentMethod = await _query.Run(organization); + + Assert.True(hasPaymentMethod); + } + + [Fact] + public async Task Run_HasDefaultSourceId_ReturnsTrue() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + var customer = new Customer + { + DefaultSourceId = "card_123", + InvoiceSettings = new CustomerInvoiceSettings(), + Metadata = new Dictionary() + }; + + _subscriberService.GetCustomer(organization).Returns(customer); + + var hasPaymentMethod = await _query.Run(organization); + + Assert.True(hasPaymentMethod); + } + + [Fact] + public async Task Run_HasUnverifiedBankAccount_ReturnsTrue() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + var customer = new Customer + { + InvoiceSettings = new CustomerInvoiceSettings(), + Metadata = new Dictionary() + }; + + _subscriberService.GetCustomer(organization).Returns(customer); + _setupIntentCache.GetSetupIntentIdForSubscriber(organization.Id).Returns("seti_123"); + + _stripeAdapter + .GetSetupIntentAsync("seti_123", + Arg.Is(options => options.HasExpansions("payment_method"))) + .Returns(new SetupIntent + { + Status = "requires_action", + NextAction = new SetupIntentNextAction + { + VerifyWithMicrodeposits = new SetupIntentNextActionVerifyWithMicrodeposits() + }, + PaymentMethod = new PaymentMethod + { + UsBankAccount = new PaymentMethodUsBankAccount() + } + }); + + var hasPaymentMethod = await _query.Run(organization); + + Assert.True(hasPaymentMethod); + } + + [Fact] + public async Task Run_HasBraintreeCustomerId_ReturnsTrue() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + var customer = new Customer + { + InvoiceSettings = new CustomerInvoiceSettings(), + Metadata = new Dictionary + { + [MetadataKeys.BraintreeCustomerId] = "braintree_customer_id" + } + }; + + _subscriberService.GetCustomer(organization).Returns(customer); + + var hasPaymentMethod = await _query.Run(organization); + + Assert.True(hasPaymentMethod); + } + + [Fact] + public async Task Run_NoSetupIntentId_ReturnsFalse() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + var customer = new Customer + { + InvoiceSettings = new CustomerInvoiceSettings(), + Metadata = new Dictionary() + }; + + _subscriberService.GetCustomer(organization).Returns(customer); + _setupIntentCache.GetSetupIntentIdForSubscriber(organization.Id).Returns((string)null); + + var hasPaymentMethod = await _query.Run(organization); + + Assert.False(hasPaymentMethod); + } + + [Fact] + public async Task Run_SetupIntentNotBankAccount_ReturnsFalse() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + var customer = new Customer + { + InvoiceSettings = new CustomerInvoiceSettings(), + Metadata = new Dictionary() + }; + + _subscriberService.GetCustomer(organization).Returns(customer); + _setupIntentCache.GetSetupIntentIdForSubscriber(organization.Id).Returns("seti_123"); + + _stripeAdapter + .GetSetupIntentAsync("seti_123", + Arg.Is(options => options.HasExpansions("payment_method"))) + .Returns(new SetupIntent + { + PaymentMethod = new PaymentMethod + { + Type = "card" + }, + Status = "succeeded" + }); + + var hasPaymentMethod = await _query.Run(organization); + + Assert.False(hasPaymentMethod); + } +} diff --git a/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs b/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs index e808fb10b0..b58b5cd250 100644 --- a/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs +++ b/test/Core.Test/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommandTests.cs @@ -1,6 +1,12 @@ -using Bit.Core.Billing.Caches; +using Bit.Core.Billing; +using Bit.Core.Billing.Caches; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Payment.Commands; using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Premium.Commands; +using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Platform.Push; @@ -13,6 +19,8 @@ using NSubstitute; using Stripe; using Xunit; using Address = Stripe.Address; +using PremiumPlan = Bit.Core.Billing.Pricing.Premium.Plan; +using PremiumPurchasable = Bit.Core.Billing.Pricing.Premium.Purchasable; using StripeCustomer = Stripe.Customer; using StripeSubscription = Stripe.Subscription; @@ -27,6 +35,9 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests private readonly ISubscriberService _subscriberService = Substitute.For(); private readonly IUserService _userService = Substitute.For(); private readonly IPushNotificationService _pushNotificationService = Substitute.For(); + private readonly IPricingClient _pricingClient = Substitute.For(); + private readonly IHasPaymentMethodQuery _hasPaymentMethodQuery = Substitute.For(); + private readonly IUpdatePaymentMethodCommand _updatePaymentMethodCommand = Substitute.For(); private readonly CreatePremiumCloudHostedSubscriptionCommand _command; public CreatePremiumCloudHostedSubscriptionCommandTests() @@ -35,6 +46,17 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests baseServiceUri.CloudRegion.Returns("US"); _globalSettings.BaseServiceUri.Returns(baseServiceUri); + // Setup default premium plan with standard pricing + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new PremiumPurchasable { Price = 10M, StripePriceId = StripeConstants.Prices.PremiumAnnually }, + Storage = new PremiumPurchasable { Price = 4M, StripePriceId = StripeConstants.Prices.StoragePlanPersonal, Provided = 1 } + }; + _pricingClient.GetAvailablePremiumPlan().Returns(premiumPlan); + _command = new CreatePremiumCloudHostedSubscriptionCommand( _braintreeGateway, _globalSettings, @@ -43,7 +65,10 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests _subscriberService, _userService, _pushNotificationService, - Substitute.For>()); + Substitute.For>(), + _pricingClient, + _hasPaymentMethodQuery, + _updatePaymentMethodCommand); } [Theory, BitAutoData] @@ -105,17 +130,27 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockSubscription = Substitute.For(); mockSubscription.Id = "sub_123"; mockSubscription.Status = "active"; + mockSubscription.Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) + } + ] + }; var mockInvoice = Substitute.For(); var mockSetupIntent = Substitute.For(); mockSetupIntent.Id = "seti_123"; - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); - _stripeAdapter.SetupIntentList(Arg.Any()).Returns(Task.FromResult(new List { mockSetupIntent })); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.ListSetupIntentsAsync(Arg.Any()).Returns(Task.FromResult(new List { mockSetupIntent })); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); // Act @@ -123,8 +158,8 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests // Assert Assert.True(result.IsT0); - await _stripeAdapter.Received(1).CustomerCreateAsync(Arg.Any()); - await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateCustomerAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Any()); await _userService.Received(1).SaveUserAsync(user); await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id); } @@ -152,13 +187,23 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockSubscription = Substitute.For(); mockSubscription.Id = "sub_123"; mockSubscription.Status = "active"; + mockSubscription.Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) + } + ] + }; var mockInvoice = Substitute.For(); - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); // Act @@ -166,8 +211,8 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests // Assert Assert.True(result.IsT0); - await _stripeAdapter.Received(1).CustomerCreateAsync(Arg.Any()); - await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateCustomerAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Any()); await _userService.Received(1).SaveUserAsync(user); await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id); } @@ -198,10 +243,10 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockInvoice = Substitute.For(); - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); _subscriberService.CreateBraintreeCustomer(Arg.Any(), Arg.Any()).Returns("bt_customer_123"); @@ -210,8 +255,8 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests // Assert Assert.True(result.IsT0); - await _stripeAdapter.Received(1).CustomerCreateAsync(Arg.Any()); - await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateCustomerAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Any()); await _subscriberService.Received(1).CreateBraintreeCustomer(user, paymentMethod.Token); await _userService.Received(1).SaveUserAsync(user); await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id); @@ -241,14 +286,23 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockSubscription = Substitute.For(); mockSubscription.Id = "sub_123"; mockSubscription.Status = "active"; - mockSubscription.CurrentPeriodEnd = DateTime.UtcNow.AddDays(30); + mockSubscription.Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) + } + ] + }; var mockInvoice = Substitute.For(); - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); // Act @@ -266,7 +320,7 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests } [Theory, BitAutoData] - public async Task Run_UserHasExistingGatewayCustomerId_UsesExistingCustomer( + public async Task Run_UserHasExistingGatewayCustomerIdAndPaymentMethod_UsesExistingCustomer( User user, TokenizedPaymentMethod paymentMethod, BillingAddress billingAddress) @@ -286,12 +340,24 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockSubscription = Substitute.For(); mockSubscription.Id = "sub_123"; mockSubscription.Status = "active"; + mockSubscription.Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) + } + ] + }; var mockInvoice = Substitute.For(); + // Mock that the user has a payment method (this is the key difference from the credit purchase case) + _hasPaymentMethodQuery.Run(Arg.Any()).Returns(true); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); // Act var result = await _command.Run(user, paymentMethod, billingAddress, 0); @@ -299,7 +365,76 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests // Assert Assert.True(result.IsT0); await _subscriberService.Received(1).GetCustomerOrThrow(Arg.Any(), Arg.Any()); - await _stripeAdapter.DidNotReceive().CustomerCreateAsync(Arg.Any()); + await _stripeAdapter.DidNotReceive().CreateCustomerAsync(Arg.Any()); + await _updatePaymentMethodCommand.DidNotReceive().Run(Arg.Any(), Arg.Any(), Arg.Any()); + } + + [Theory, BitAutoData] + public async Task Run_UserPreviouslyPurchasedCreditWithoutPaymentMethod_UpdatesPaymentMethodAndCreatesSubscription( + User user, + TokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress) + { + // Arrange + user.Premium = false; + user.GatewayCustomerId = "existing_customer_123"; // Customer exists from previous credit purchase + paymentMethod.Type = TokenizablePaymentMethodType.Card; + paymentMethod.Token = "card_token_123"; + billingAddress.Country = "US"; + billingAddress.PostalCode = "12345"; + + var mockCustomer = Substitute.For(); + mockCustomer.Id = "existing_customer_123"; + mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" }; + mockCustomer.Metadata = new Dictionary(); + + var mockSubscription = Substitute.For(); + mockSubscription.Id = "sub_123"; + mockSubscription.Status = "active"; + mockSubscription.Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) + } + ] + }; + + var mockInvoice = Substitute.For(); + MaskedPaymentMethod mockMaskedPaymentMethod = new MaskedCard + { + Brand = "visa", + Last4 = "1234", + Expiration = "12/2025" + }; + + // Mock that the user does NOT have a payment method (simulating credit purchase scenario) + _hasPaymentMethodQuery.Run(Arg.Any()).Returns(false); + _updatePaymentMethodCommand.Run(Arg.Any(), Arg.Any(), Arg.Any()) + .Returns(mockMaskedPaymentMethod); + _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + + // Act + var result = await _command.Run(user, paymentMethod, billingAddress, 0); + + // Assert + Assert.True(result.IsT0); + // Verify that update payment method was called (new behavior for credit purchase case) + await _updatePaymentMethodCommand.Received(1).Run(user, paymentMethod, billingAddress); + // Verify GetCustomerOrThrow was called after updating payment method + await _subscriberService.Received(1).GetCustomerOrThrow(Arg.Any(), Arg.Any()); + // Verify no new customer was created + await _stripeAdapter.DidNotReceive().CreateCustomerAsync(Arg.Any()); + // Verify subscription was created + await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Any()); + // Verify user was updated correctly + Assert.True(user.Premium); + await _userService.Received(1).SaveUserAsync(user); + await _pushNotificationService.Received(1).PushSyncVaultAsync(user.Id); } [Theory, BitAutoData] @@ -326,14 +461,23 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockSubscription = Substitute.For(); mockSubscription.Id = "sub_123"; mockSubscription.Status = "incomplete"; - mockSubscription.CurrentPeriodEnd = DateTime.UtcNow.AddDays(30); + mockSubscription.Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) + } + ] + }; var mockInvoice = Substitute.For(); - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); _subscriberService.CreateBraintreeCustomer(Arg.Any(), Arg.Any()).Returns("bt_customer_123"); // Act @@ -342,7 +486,7 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests // Assert Assert.True(result.IsT0); Assert.True(user.Premium); - Assert.Equal(mockSubscription.CurrentPeriodEnd, user.PremiumExpirationDate); + Assert.Equal(mockSubscription.GetCurrentPeriodEnd(), user.PremiumExpirationDate); } [Theory, BitAutoData] @@ -368,14 +512,23 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockSubscription = Substitute.For(); mockSubscription.Id = "sub_123"; mockSubscription.Status = "active"; - mockSubscription.CurrentPeriodEnd = DateTime.UtcNow.AddDays(30); + mockSubscription.Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) + } + ] + }; var mockInvoice = Substitute.For(); - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); // Act @@ -384,7 +537,7 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests // Assert Assert.True(result.IsT0); Assert.True(user.Premium); - Assert.Equal(mockSubscription.CurrentPeriodEnd, user.PremiumExpirationDate); + Assert.Equal(mockSubscription.GetCurrentPeriodEnd(), user.PremiumExpirationDate); } [Theory, BitAutoData] @@ -411,14 +564,23 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockSubscription = Substitute.For(); mockSubscription.Id = "sub_123"; mockSubscription.Status = "active"; // PayPal + active doesn't match pattern - mockSubscription.CurrentPeriodEnd = DateTime.UtcNow.AddDays(30); + mockSubscription.Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) + } + ] + }; var mockInvoice = Substitute.For(); - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); _subscriberService.CreateBraintreeCustomer(Arg.Any(), Arg.Any()).Returns("bt_customer_123"); // Act @@ -453,17 +615,26 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var mockSubscription = Substitute.For(); mockSubscription.Id = "sub_123"; mockSubscription.Status = "incomplete"; - mockSubscription.CurrentPeriodEnd = DateTime.UtcNow.AddDays(30); + mockSubscription.Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) + } + ] + }; var mockInvoice = Substitute.For(); - _stripeAdapter.CustomerCreateAsync(Arg.Any()).Returns(mockCustomer); - _stripeAdapter.CustomerUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(mockSubscription); - _stripeAdapter.InvoiceUpdateAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); - _stripeAdapter.SetupIntentList(Arg.Any()) + _stripeAdapter.ListSetupIntentsAsync(Arg.Any()) .Returns(Task.FromResult(new List())); // Empty list - no setup intent found // Act @@ -474,4 +645,138 @@ public class CreatePremiumCloudHostedSubscriptionCommandTests var unhandled = result.AsT3; Assert.Equal("Something went wrong with your request. Please contact support for assistance.", unhandled.Response); } + + [Theory, BitAutoData] + public async Task Run_AccountCredit_WithExistingCustomer_Success( + User user, + NonTokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress) + { + // Arrange + user.Premium = false; + user.GatewayCustomerId = "existing_customer_123"; + paymentMethod.Type = NonTokenizablePaymentMethodType.AccountCredit; + billingAddress.Country = "US"; + billingAddress.PostalCode = "12345"; + + var mockCustomer = Substitute.For(); + mockCustomer.Id = "existing_customer_123"; + mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" }; + mockCustomer.Metadata = new Dictionary(); + + var mockSubscription = Substitute.For(); + mockSubscription.Id = "sub_123"; + mockSubscription.Status = "active"; + mockSubscription.Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) + } + ] + }; + + var mockInvoice = Substitute.For(); + + _subscriberService.GetCustomerOrThrow(Arg.Any(), Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateInvoiceAsync(Arg.Any(), Arg.Any()).Returns(mockInvoice); + + // Act + var result = await _command.Run(user, paymentMethod, billingAddress, 0); + + // Assert + Assert.True(result.IsT0); + await _subscriberService.Received(1).GetCustomerOrThrow(Arg.Any(), Arg.Any()); + await _stripeAdapter.DidNotReceive().CreateCustomerAsync(Arg.Any()); + Assert.True(user.Premium); + Assert.Equal(mockSubscription.GetCurrentPeriodEnd(), user.PremiumExpirationDate); + } + + [Theory, BitAutoData] + public async Task Run_NonTokenizedPaymentWithoutExistingCustomer_ThrowsBillingException( + User user, + NonTokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress) + { + // Arrange + user.Premium = false; + // No existing gateway customer ID + user.GatewayCustomerId = null; + paymentMethod.Type = NonTokenizablePaymentMethodType.AccountCredit; + billingAddress.Country = "US"; + billingAddress.PostalCode = "12345"; + + // Act + var result = await _command.Run(user, paymentMethod, billingAddress, 0); + + //Assert + Assert.True(result.IsT3); // Assuming T3 is the Unhandled result + Assert.IsType(result.AsT3.Exception); + // Verify no customer was created or subscription attempted + await _stripeAdapter.DidNotReceive().CreateCustomerAsync(Arg.Any()); + await _stripeAdapter.DidNotReceive().CreateSubscriptionAsync(Arg.Any()); + await _userService.DidNotReceive().SaveUserAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task Run_WithAdditionalStorage_SetsCorrectMaxStorageGb( + User user, + TokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress) + { + // Arrange + user.Premium = false; + user.GatewayCustomerId = null; + user.Email = "test@example.com"; + paymentMethod.Type = TokenizablePaymentMethodType.Card; + paymentMethod.Token = "card_token_123"; + billingAddress.Country = "US"; + billingAddress.PostalCode = "12345"; + const short additionalStorage = 2; + + // Setup premium plan with 5GB provided storage + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new PremiumPurchasable { Price = 10M, StripePriceId = StripeConstants.Prices.PremiumAnnually }, + Storage = new PremiumPurchasable { Price = 4M, StripePriceId = StripeConstants.Prices.StoragePlanPersonal, Provided = 1 } + }; + _pricingClient.GetAvailablePremiumPlan().Returns(premiumPlan); + + var mockCustomer = Substitute.For(); + mockCustomer.Id = "cust_123"; + mockCustomer.Address = new Address { Country = "US", PostalCode = "12345" }; + mockCustomer.Metadata = new Dictionary(); + + var mockSubscription = Substitute.For(); + mockSubscription.Id = "sub_123"; + mockSubscription.Status = "active"; + mockSubscription.Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) + } + ] + }; + + _stripeAdapter.CreateCustomerAsync(Arg.Any()).Returns(mockCustomer); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(mockSubscription); + + // Act + var result = await _command.Run(user, paymentMethod, billingAddress, additionalStorage); + + // Assert + Assert.True(result.IsT0); + Assert.Equal((short)3, user.MaxStorageGb); // 1 (provided) + 2 (additional) = 3 + await _userService.Received(1).SaveUserAsync(user); + } + } diff --git a/test/Core.Test/Billing/Premium/Commands/PreviewPremiumTaxCommandTests.cs b/test/Core.Test/Billing/Premium/Commands/PreviewPremiumTaxCommandTests.cs index bf7d093dc7..b5afaf65cd 100644 --- a/test/Core.Test/Billing/Premium/Commands/PreviewPremiumTaxCommandTests.cs +++ b/test/Core.Test/Billing/Premium/Commands/PreviewPremiumTaxCommandTests.cs @@ -1,23 +1,38 @@ using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Premium.Commands; -using Bit.Core.Services; +using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Microsoft.Extensions.Logging; using NSubstitute; using Stripe; using Xunit; using static Bit.Core.Billing.Constants.StripeConstants; +using PremiumPlan = Bit.Core.Billing.Pricing.Premium.Plan; +using PremiumPurchasable = Bit.Core.Billing.Pricing.Premium.Purchasable; namespace Bit.Core.Test.Billing.Premium.Commands; public class PreviewPremiumTaxCommandTests { private readonly ILogger _logger = Substitute.For>(); + private readonly IPricingClient _pricingClient = Substitute.For(); private readonly IStripeAdapter _stripeAdapter = Substitute.For(); private readonly PreviewPremiumTaxCommand _command; public PreviewPremiumTaxCommandTests() { - _command = new PreviewPremiumTaxCommand(_logger, _stripeAdapter); + // Setup default premium plan with standard pricing + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new PremiumPurchasable { Price = 10M, StripePriceId = Prices.PremiumAnnually }, + Storage = new PremiumPurchasable { Price = 4M, StripePriceId = Prices.StoragePlanPersonal } + }; + _pricingClient.GetAvailablePremiumPlan().Returns(premiumPlan); + + _command = new PreviewPremiumTaxCommand(_logger, _pricingClient, _stripeAdapter); } [Fact] @@ -31,11 +46,11 @@ public class PreviewPremiumTaxCommandTests var invoice = new Invoice { - Tax = 300, + TotalTaxes = [new InvoiceTotalTax { Amount = 300 }], Total = 3300 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(0, billingAddress); @@ -44,7 +59,7 @@ public class PreviewPremiumTaxCommandTests Assert.Equal(3.00m, tax); Assert.Equal(33.00m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -65,11 +80,11 @@ public class PreviewPremiumTaxCommandTests var invoice = new Invoice { - Tax = 500, + TotalTaxes = [new InvoiceTotalTax { Amount = 500 }], Total = 5500 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(5, billingAddress); @@ -78,7 +93,7 @@ public class PreviewPremiumTaxCommandTests Assert.Equal(5.00m, tax); Assert.Equal(55.00m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "CA" && @@ -101,11 +116,11 @@ public class PreviewPremiumTaxCommandTests var invoice = new Invoice { - Tax = 250, + TotalTaxes = [new InvoiceTotalTax { Amount = 250 }], Total = 2750 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(0, billingAddress); @@ -114,7 +129,7 @@ public class PreviewPremiumTaxCommandTests Assert.Equal(2.50m, tax); Assert.Equal(27.50m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "GB" && @@ -135,11 +150,11 @@ public class PreviewPremiumTaxCommandTests var invoice = new Invoice { - Tax = 800, + TotalTaxes = [new InvoiceTotalTax { Amount = 800 }], Total = 8800 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(20, billingAddress); @@ -148,7 +163,7 @@ public class PreviewPremiumTaxCommandTests Assert.Equal(8.00m, tax); Assert.Equal(88.00m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "DE" && @@ -171,11 +186,11 @@ public class PreviewPremiumTaxCommandTests var invoice = new Invoice { - Tax = 450, + TotalTaxes = [new InvoiceTotalTax { Amount = 450 }], Total = 4950 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(10, billingAddress); @@ -184,7 +199,7 @@ public class PreviewPremiumTaxCommandTests Assert.Equal(4.50m, tax); Assert.Equal(49.50m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "AU" && @@ -207,11 +222,11 @@ public class PreviewPremiumTaxCommandTests var invoice = new Invoice { - Tax = 0, + TotalTaxes = [new InvoiceTotalTax { Amount = 0 }], Total = 3000 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(0, billingAddress); @@ -220,7 +235,7 @@ public class PreviewPremiumTaxCommandTests Assert.Equal(0.00m, tax); Assert.Equal(30.00m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "US" && @@ -241,11 +256,11 @@ public class PreviewPremiumTaxCommandTests var invoice = new Invoice { - Tax = 600, + TotalTaxes = [new InvoiceTotalTax { Amount = 600 }], Total = 6600 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(-5, billingAddress); @@ -254,7 +269,7 @@ public class PreviewPremiumTaxCommandTests Assert.Equal(6.00m, tax); Assert.Equal(66.00m, total); - await _stripeAdapter.Received(1).InvoiceCreatePreviewAsync(Arg.Is(options => + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.Currency == "usd" && options.CustomerDetails.Address.Country == "FR" && @@ -276,11 +291,11 @@ public class PreviewPremiumTaxCommandTests // Stripe amounts are in cents var invoice = new Invoice { - Tax = 123, // $1.23 + TotalTaxes = [new InvoiceTotalTax { Amount = 123 }], // $1.23 Total = 3123 // $31.23 }; - _stripeAdapter.InvoiceCreatePreviewAsync(Arg.Any()).Returns(invoice); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()).Returns(invoice); var result = await _command.Run(0, billingAddress); diff --git a/test/Core.Test/Billing/Premium/Queries/HasPremiumAccessQueryTests.cs b/test/Core.Test/Billing/Premium/Queries/HasPremiumAccessQueryTests.cs new file mode 100644 index 0000000000..31547dffbe --- /dev/null +++ b/test/Core.Test/Billing/Premium/Queries/HasPremiumAccessQueryTests.cs @@ -0,0 +1,234 @@ +using Bit.Core.Billing.Premium.Models; +using Bit.Core.Billing.Premium.Queries; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.Billing.Premium.Queries; + +[SutProviderCustomize] +public class HasPremiumAccessQueryTests +{ + [Theory, BitAutoData] + public async Task HasPremiumAccessAsync_WhenUserHasPersonalPremium_ReturnsTrue( + UserPremiumAccess user, + SutProvider sutProvider) + { + // Arrange + user.PersonalPremium = true; + user.OrganizationPremium = false; + + sutProvider.GetDependency() + .GetPremiumAccessAsync(user.Id) + .Returns(user); + + // Act + var result = await sutProvider.Sut.HasPremiumAccessAsync(user.Id); + + // Assert + Assert.True(result); + } + + [Theory, BitAutoData] + public async Task HasPremiumAccessAsync_WhenUserHasNoPersonalPremiumButHasOrgPremium_ReturnsTrue( + UserPremiumAccess user, + SutProvider sutProvider) + { + // Arrange + user.PersonalPremium = false; + user.OrganizationPremium = true; // Has org premium + + sutProvider.GetDependency() + .GetPremiumAccessAsync(user.Id) + .Returns(user); + + // Act + var result = await sutProvider.Sut.HasPremiumAccessAsync(user.Id); + + // Assert + Assert.True(result); + } + + [Theory, BitAutoData] + public async Task HasPremiumAccessAsync_WhenUserHasNoPersonalPremiumAndNoOrgPremium_ReturnsFalse( + UserPremiumAccess user, + SutProvider sutProvider) + { + // Arrange + user.PersonalPremium = false; + user.OrganizationPremium = false; + + sutProvider.GetDependency() + .GetPremiumAccessAsync(user.Id) + .Returns(user); + + // Act + var result = await sutProvider.Sut.HasPremiumAccessAsync(user.Id); + + // Assert + Assert.False(result); + } + + [Theory, BitAutoData] + public async Task HasPremiumAccessAsync_WhenUserNotFound_ThrowsNotFoundException( + Guid userId, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .GetPremiumAccessAsync(userId) + .Returns((UserPremiumAccess?)null); + + // Act & Assert + await Assert.ThrowsAsync( + () => sutProvider.Sut.HasPremiumAccessAsync(userId)); + } + + [Theory, BitAutoData] + public async Task HasPremiumFromOrganizationAsync_WhenUserHasNoOrganizations_ReturnsFalse( + UserPremiumAccess user, + SutProvider sutProvider) + { + // Arrange + user.PersonalPremium = false; + user.OrganizationPremium = false; // No premium from anywhere + + sutProvider.GetDependency() + .GetPremiumAccessAsync(user.Id) + .Returns(user); + + // Act + var result = await sutProvider.Sut.HasPremiumFromOrganizationAsync(user.Id); + + // Assert + Assert.False(result); + } + + [Theory, BitAutoData] + public async Task HasPremiumFromOrganizationAsync_WhenUserHasPremiumFromOrg_ReturnsTrue( + UserPremiumAccess user, + SutProvider sutProvider) + { + // Arrange + user.PersonalPremium = false; // No personal premium + user.OrganizationPremium = true; // But has premium from org + + sutProvider.GetDependency() + .GetPremiumAccessAsync(user.Id) + .Returns(user); + + // Act + var result = await sutProvider.Sut.HasPremiumFromOrganizationAsync(user.Id); + + // Assert + Assert.True(result); + } + + [Theory, BitAutoData] + public async Task HasPremiumFromOrganizationAsync_WhenUserHasOnlyPersonalPremium_ReturnsFalse( + UserPremiumAccess user, + SutProvider sutProvider) + { + // Arrange + user.PersonalPremium = true; // Has personal premium + user.OrganizationPremium = false; // Not in any org that grants premium + + sutProvider.GetDependency() + .GetPremiumAccessAsync(user.Id) + .Returns(user); + + // Act + var result = await sutProvider.Sut.HasPremiumFromOrganizationAsync(user.Id); + + // Assert + Assert.False(result); // Should return false because user is not in an org that grants premium + } + + [Theory, BitAutoData] + public async Task HasPremiumFromOrganizationAsync_WhenUserHasBothPersonalAndOrgPremium_ReturnsTrue( + UserPremiumAccess user, + SutProvider sutProvider) + { + // Arrange + user.PersonalPremium = true; // Has personal premium + user.OrganizationPremium = true; // Also in an org that grants premium + + sutProvider.GetDependency() + .GetPremiumAccessAsync(user.Id) + .Returns(user); + + // Act + var result = await sutProvider.Sut.HasPremiumFromOrganizationAsync(user.Id); + + // Assert + Assert.True(result); // Should return true because user IS in an org that grants premium (regardless of personal premium) + } + + [Theory, BitAutoData] + public async Task HasPremiumFromOrganizationAsync_WhenUserNotFound_ThrowsNotFoundException( + Guid userId, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .GetPremiumAccessAsync(userId) + .Returns((UserPremiumAccess?)null); + + // Act & Assert + await Assert.ThrowsAsync( + () => sutProvider.Sut.HasPremiumFromOrganizationAsync(userId)); + } + + [Theory, BitAutoData] + public async Task HasPremiumAccessAsync_Bulk_WhenEmptyList_ReturnsEmptyDictionary( + SutProvider sutProvider) + { + // Arrange + var userIds = new List(); + + sutProvider.GetDependency() + .GetPremiumAccessByIdsAsync(userIds) + .Returns(new List()); + + // Act + var result = await sutProvider.Sut.HasPremiumAccessAsync(userIds); + + // Assert + Assert.Empty(result); + } + + [Theory, BitAutoData] + public async Task HasPremiumAccessAsync_Bulk_ReturnsCorrectStatus( + UserPremiumAccess user1, + UserPremiumAccess user2, + UserPremiumAccess user3, + SutProvider sutProvider) + { + // Arrange + user1.PersonalPremium = true; + user1.OrganizationPremium = false; + user2.PersonalPremium = false; + user2.OrganizationPremium = false; + user3.PersonalPremium = false; + user3.OrganizationPremium = true; + + var users = new List { user1, user2, user3 }; + var userIds = users.Select(u => u.Id).ToList(); + + sutProvider.GetDependency() + .GetPremiumAccessByIdsAsync(Arg.Is>(ids => ids.SequenceEqual(userIds))) + .Returns(users); + + // Act + var result = await sutProvider.Sut.HasPremiumAccessAsync(userIds); + + // Assert + Assert.Equal(3, result.Count); + Assert.True(result[user1.Id]); // Personal premium + Assert.False(result[user2.Id]); // No premium + Assert.True(result[user3.Id]); // Organization premium + } +} diff --git a/test/Core.Test/Billing/Pricing/PricingClientTests.cs b/test/Core.Test/Billing/Pricing/PricingClientTests.cs new file mode 100644 index 0000000000..43329e9c2e --- /dev/null +++ b/test/Core.Test/Billing/Pricing/PricingClientTests.cs @@ -0,0 +1,474 @@ +using System.Net; +using Bit.Core.Billing; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Pricing; +using Bit.Core.Services; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.Extensions.Logging; +using NSubstitute; +using RichardSzalay.MockHttp; +using Xunit; +using GlobalSettings = Bit.Core.Settings.GlobalSettings; + +namespace Bit.Core.Test.Billing.Pricing; + +[SutProviderCustomize] +public class PricingClientTests +{ + #region GetLookupKey Tests (via GetPlan) + + [Fact] + public async Task GetPlan_WithFamiliesAnnually2025AndFeatureFlagEnabled_UsesFamilies2025LookupKey() + { + // Arrange + var mockHttp = new MockHttpMessageHandler(); + var planJson = CreatePlanJson("families-2025", "Families 2025", "families", 40M, "price_id"); + + mockHttp.Expect(HttpMethod.Get, "https://test.com/plans/organization/families-2025") + .Respond("application/json", planJson); + + mockHttp.When(HttpMethod.Get, "*/plans/organization/*") + .Respond("application/json", planJson); + + var featureService = Substitute.For(); + featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); + + var globalSettings = new GlobalSettings { SelfHosted = false }; + + var httpClient = new HttpClient(mockHttp) + { + BaseAddress = new Uri("https://test.com/") + }; + + var logger = Substitute.For>(); + var pricingClient = new PricingClient(featureService, globalSettings, httpClient, logger); + + // Act + var result = await pricingClient.GetPlan(PlanType.FamiliesAnnually2025); + + // Assert + Assert.NotNull(result); + Assert.Equal(PlanType.FamiliesAnnually2025, result.Type); + mockHttp.VerifyNoOutstandingExpectation(); + } + + [Fact] + public async Task GetPlan_WithFamiliesAnnually2025AndFeatureFlagDisabled_UsesFamiliesLookupKey() + { + // Arrange + var mockHttp = new MockHttpMessageHandler(); + var planJson = CreatePlanJson("families", "Families", "families", 40M, "price_id"); + + mockHttp.Expect(HttpMethod.Get, "https://test.com/plans/organization/families") + .Respond("application/json", planJson); + + mockHttp.When(HttpMethod.Get, "*/plans/organization/*") + .Respond("application/json", planJson); + + var featureService = Substitute.For(); + featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(false); + + var globalSettings = new GlobalSettings { SelfHosted = false }; + + var httpClient = new HttpClient(mockHttp) + { + BaseAddress = new Uri("https://test.com/") + }; + + var logger = Substitute.For>(); + var pricingClient = new PricingClient(featureService, globalSettings, httpClient, logger); + + // Act + var result = await pricingClient.GetPlan(PlanType.FamiliesAnnually2025); + + // Assert + Assert.NotNull(result); + // PreProcessFamiliesPreMigrationPlan should change "families" to "families-2025" when FF is disabled + Assert.Equal(PlanType.FamiliesAnnually2025, result.Type); + mockHttp.VerifyNoOutstandingExpectation(); + } + + #endregion + + #region PreProcessFamiliesPreMigrationPlan Tests (via GetPlan) + + [Fact] + public async Task GetPlan_WithFamiliesAnnually2025AndFeatureFlagDisabled_ReturnsFamiliesAnnually2025PlanType() + { + // Arrange + var mockHttp = new MockHttpMessageHandler(); + // billing-pricing returns "families" lookup key because the flag is off + var planJson = CreatePlanJson("families", "Families", "families", 40M, "price_id"); + + mockHttp.When(HttpMethod.Get, "*/plans/organization/*") + .Respond("application/json", planJson); + + var featureService = Substitute.For(); + featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(false); + + var globalSettings = new GlobalSettings { SelfHosted = false }; + + var httpClient = new HttpClient(mockHttp) + { + BaseAddress = new Uri("https://test.com/") + }; + + var logger = Substitute.For>(); + var pricingClient = new PricingClient(featureService, globalSettings, httpClient, logger); + + // Act + var result = await pricingClient.GetPlan(PlanType.FamiliesAnnually2025); + + // Assert + Assert.NotNull(result); + // PreProcessFamiliesPreMigrationPlan should convert the families lookup key to families-2025 + // and the PlanAdapter should assign the correct FamiliesAnnually2025 plan type + Assert.Equal(PlanType.FamiliesAnnually2025, result.Type); + mockHttp.VerifyNoOutstandingExpectation(); + } + + [Fact] + public async Task GetPlan_WithFamiliesAnnually2025AndFeatureFlagEnabled_ReturnsFamiliesAnnually2025PlanType() + { + // Arrange + var mockHttp = new MockHttpMessageHandler(); + var planJson = CreatePlanJson("families-2025", "Families", "families", 40M, "price_id"); + + mockHttp.When(HttpMethod.Get, "*/plans/organization/*") + .Respond("application/json", planJson); + + var featureService = Substitute.For(); + featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); + + var globalSettings = new GlobalSettings { SelfHosted = false }; + + var httpClient = new HttpClient(mockHttp) + { + BaseAddress = new Uri("https://test.com/") + }; + + var logger = Substitute.For>(); + var pricingClient = new PricingClient(featureService, globalSettings, httpClient, logger); + + // Act + var result = await pricingClient.GetPlan(PlanType.FamiliesAnnually2025); + + // Assert + Assert.NotNull(result); + // PreProcessFamiliesPreMigrationPlan should ignore the lookup key because the flag is on + // and the PlanAdapter should assign the correct FamiliesAnnually2025 plan type + Assert.Equal(PlanType.FamiliesAnnually2025, result.Type); + mockHttp.VerifyNoOutstandingExpectation(); + } + + [Fact] + public async Task GetPlan_WithFamiliesAnnuallyAndFeatureFlagEnabled_ReturnsFamiliesAnnuallyPlanType() + { + // Arrange + var mockHttp = new MockHttpMessageHandler(); + var planJson = CreatePlanJson("families", "Families", "families", 40M, "price_id"); + + mockHttp.When(HttpMethod.Get, "*/plans/organization/*") + .Respond("application/json", planJson); + + var featureService = Substitute.For(); + featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); + + var globalSettings = new GlobalSettings { SelfHosted = false }; + + var httpClient = new HttpClient(mockHttp) + { + BaseAddress = new Uri("https://test.com/") + }; + + var logger = Substitute.For>(); + var pricingClient = new PricingClient(featureService, globalSettings, httpClient, logger); + + // Act + var result = await pricingClient.GetPlan(PlanType.FamiliesAnnually); + + // Assert + Assert.NotNull(result); + // PreProcessFamiliesPreMigrationPlan should ignore the lookup key because the flag is on + // and the PlanAdapter should assign the correct FamiliesAnnually plan type + Assert.Equal(PlanType.FamiliesAnnually, result.Type); + mockHttp.VerifyNoOutstandingExpectation(); + } + + [Fact] + public async Task GetPlan_WithOtherLookupKey_KeepsLookupKeyUnchanged() + { + // Arrange + var mockHttp = new MockHttpMessageHandler(); + var planJson = CreatePlanJson("enterprise-annually", "Enterprise", "enterprise", 144M, "price_id"); + + mockHttp.Expect(HttpMethod.Get, "https://test.com/plans/organization/enterprise-annually") + .Respond("application/json", planJson); + + mockHttp.When(HttpMethod.Get, "*/plans/organization/*") + .Respond("application/json", planJson); + + var featureService = Substitute.For(); + featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(false); + + var globalSettings = new GlobalSettings { SelfHosted = false }; + + var httpClient = new HttpClient(mockHttp) + { + BaseAddress = new Uri("https://test.com/") + }; + + var logger = Substitute.For>(); + var pricingClient = new PricingClient(featureService, globalSettings, httpClient, logger); + + // Act + var result = await pricingClient.GetPlan(PlanType.EnterpriseAnnually); + + // Assert + Assert.NotNull(result); + Assert.Equal(PlanType.EnterpriseAnnually, result.Type); + mockHttp.VerifyNoOutstandingExpectation(); + } + + #endregion + + #region ListPlans Tests + + [Fact] + public async Task ListPlans_WithFeatureFlagDisabled_ReturnsListWithPreProcessing() + { + // Arrange + var mockHttp = new MockHttpMessageHandler(); + // biling-pricing would return "families" because the flag is disabled + var plansJson = $@"[ + {CreatePlanJson("families", "Families", "families", 40M, "price_id")}, + {CreatePlanJson("enterprise-annually", "Enterprise", "enterprise", 144M, "price_id")} + ]"; + + mockHttp.When(HttpMethod.Get, "*/plans/organization") + .Respond("application/json", plansJson); + + var featureService = Substitute.For(); + featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(false); + + var globalSettings = new GlobalSettings { SelfHosted = false }; + + var httpClient = new HttpClient(mockHttp) + { + BaseAddress = new Uri("https://test.com/") + }; + + var logger = Substitute.For>(); + var pricingClient = new PricingClient(featureService, globalSettings, httpClient, logger); + + // Act + var result = await pricingClient.ListPlans(); + + // Assert + Assert.NotNull(result); + Assert.Equal(2, result.Count); + // First plan should have been preprocessed from "families" to "families-2025" + Assert.Equal(PlanType.FamiliesAnnually2025, result[0].Type); + // Second plan should remain unchanged + Assert.Equal(PlanType.EnterpriseAnnually, result[1].Type); + mockHttp.VerifyNoOutstandingExpectation(); + } + + [Fact] + public async Task ListPlans_WithFeatureFlagEnabled_ReturnsListWithoutPreProcessing() + { + // Arrange + var mockHttp = new MockHttpMessageHandler(); + var plansJson = $@"[ + {CreatePlanJson("families", "Families", "families", 40M, "price_id")} + ]"; + + mockHttp.When(HttpMethod.Get, "*/plans/organization") + .Respond("application/json", plansJson); + + var featureService = Substitute.For(); + featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); + + var globalSettings = new GlobalSettings { SelfHosted = false }; + + var httpClient = new HttpClient(mockHttp) + { + BaseAddress = new Uri("https://test.com/") + }; + + var logger = Substitute.For>(); + var pricingClient = new PricingClient(featureService, globalSettings, httpClient, logger); + + // Act + var result = await pricingClient.ListPlans(); + + // Assert + Assert.NotNull(result); + Assert.Single(result); + // Plan should remain as FamiliesAnnually when FF is enabled + Assert.Equal(PlanType.FamiliesAnnually, result[0].Type); + mockHttp.VerifyNoOutstandingExpectation(); + } + + #endregion + + #region GetPlan - Additional Coverage + + [Theory, BitAutoData] + public async Task GetPlan_WhenSelfHosted_ReturnsNull( + SutProvider sutProvider) + { + // Arrange + var globalSettings = sutProvider.GetDependency(); + globalSettings.SelfHosted = true; + + // Act + var result = await sutProvider.Sut.GetPlan(PlanType.FamiliesAnnually2025); + + // Assert + Assert.Null(result); + } + + [Theory, BitAutoData] + public async Task GetPlan_WhenLookupKeyNotFound_ReturnsNull( + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency().SelfHosted = false; + + // Act - Using PlanType that doesn't have a lookup key mapping + var result = await sutProvider.Sut.GetPlan(unchecked((PlanType)999)); + + // Assert + Assert.Null(result); + } + + [Fact] + public async Task GetPlan_WhenPricingServiceReturnsNotFound_ReturnsNull() + { + // Arrange + var mockHttp = new MockHttpMessageHandler(); + mockHttp.When(HttpMethod.Get, "*/plans/organization/*") + .Respond(HttpStatusCode.NotFound); + + var featureService = Substitute.For(); + featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); + + var globalSettings = new GlobalSettings { SelfHosted = false }; + + var httpClient = new HttpClient(mockHttp) + { + BaseAddress = new Uri("https://test.com/") + }; + + var logger = Substitute.For>(); + var pricingClient = new PricingClient(featureService, globalSettings, httpClient, logger); + + // Act + var result = await pricingClient.GetPlan(PlanType.FamiliesAnnually2025); + + // Assert + Assert.Null(result); + } + + [Fact] + public async Task GetPlan_WhenPricingServiceReturnsError_ThrowsBillingException() + { + // Arrange + var mockHttp = new MockHttpMessageHandler(); + mockHttp.When(HttpMethod.Get, "*/plans/organization/*") + .Respond(HttpStatusCode.InternalServerError); + + var featureService = Substitute.For(); + featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3).Returns(true); + + var globalSettings = new GlobalSettings { SelfHosted = false }; + + var httpClient = new HttpClient(mockHttp) + { + BaseAddress = new Uri("https://test.com/") + }; + + var logger = Substitute.For>(); + var pricingClient = new PricingClient(featureService, globalSettings, httpClient, logger); + + // Act & Assert + await Assert.ThrowsAsync(() => + pricingClient.GetPlan(PlanType.FamiliesAnnually2025)); + } + + #endregion + + #region ListPlans - Additional Coverage + + [Theory, BitAutoData] + public async Task ListPlans_WhenSelfHosted_ReturnsEmptyList( + SutProvider sutProvider) + { + // Arrange + var globalSettings = sutProvider.GetDependency(); + globalSettings.SelfHosted = true; + + // Act + var result = await sutProvider.Sut.ListPlans(); + + // Assert + Assert.NotNull(result); + Assert.Empty(result); + } + + [Fact] + public async Task ListPlans_WhenPricingServiceReturnsError_ThrowsBillingException() + { + // Arrange + var mockHttp = new MockHttpMessageHandler(); + mockHttp.When(HttpMethod.Get, "*/plans/organization") + .Respond(HttpStatusCode.InternalServerError); + + var featureService = Substitute.For(); + + var globalSettings = new GlobalSettings { SelfHosted = false }; + + var httpClient = new HttpClient(mockHttp) + { + BaseAddress = new Uri("https://test.com/") + }; + + var logger = Substitute.For>(); + var pricingClient = new PricingClient(featureService, globalSettings, httpClient, logger); + + // Act & Assert + await Assert.ThrowsAsync(() => + pricingClient.ListPlans()); + } + + #endregion + + private static string CreatePlanJson( + string lookupKey, + string name, + string tier, + decimal seatsPrice, + string seatsStripePriceId, + int seatsQuantity = 1) + { + return $@"{{ + ""lookupKey"": ""{lookupKey}"", + ""name"": ""{name}"", + ""tier"": ""{tier}"", + ""features"": [], + ""seats"": {{ + ""type"": ""packaged"", + ""quantity"": {seatsQuantity}, + ""price"": {seatsPrice}, + ""stripePriceId"": ""{seatsStripePriceId}"" + }}, + ""canUpgradeTo"": [], + ""additionalData"": {{ + ""nameLocalizationKey"": ""{lookupKey}Name"", + ""descriptionLocalizationKey"": ""{lookupKey}Description"" + }} + }}"; + } +} diff --git a/test/Core.Test/Billing/Services/OrganizationBillingServiceTests.cs b/test/Core.Test/Billing/Services/OrganizationBillingServiceTests.cs index 7edc60a26a..f1b9446b6d 100644 --- a/test/Core.Test/Billing/Services/OrganizationBillingServiceTests.cs +++ b/test/Core.Test/Billing/Services/OrganizationBillingServiceTests.cs @@ -1,11 +1,15 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Models.Sales; +using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Organizations.Services; +using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -26,38 +30,39 @@ public class OrganizationBillingServiceTests SutProvider sutProvider) { sutProvider.GetDependency().GetByIdAsync(organizationId).Returns(organization); - sutProvider.GetDependency().ListPlans().Returns(StaticStore.Plans.ToList()); + sutProvider.GetDependency().ListPlans().Returns(MockPlans.Plans.ToList()); sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); var subscriberService = sutProvider.GetDependency(); var organizationSeatCount = new OrganizationSeatCounts { Users = 1, Sponsored = 0 }; - var customer = new Customer - { - Discount = new Discount - { - Coupon = new Coupon - { - Id = StripeConstants.CouponIDs.SecretsManagerStandalone, - AppliesTo = new CouponAppliesTo - { - Products = ["product_id"] - } - } - } - }; + var customer = new Customer(); subscriberService - .GetCustomer(organization, Arg.Is(options => - options.Expand.Contains("discount.coupon.applies_to"))) + .GetCustomer(organization) .Returns(customer); - subscriberService.GetSubscription(organization).Returns(new Subscription - { - Items = new StripeList + subscriberService.GetSubscription(organization, Arg.Is(options => + options.Expand.Contains("discounts.coupon.applies_to"))).Returns(new Subscription { - Data = + Discounts = + [ + new Discount + { + Coupon = new Coupon + { + Id = StripeConstants.CouponIDs.SecretsManagerStandalone, + AppliesTo = new CouponAppliesTo + { + Products = ["product_id"] + } + } + } + ], + Items = new StripeList + { + Data = [ new SubscriptionItem { @@ -67,8 +72,8 @@ public class OrganizationBillingServiceTests } } ] - } - }); + } + }); sutProvider.GetDependency() .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id) @@ -91,33 +96,429 @@ public class OrganizationBillingServiceTests { sutProvider.GetDependency().GetByIdAsync(organizationId).Returns(organization); - sutProvider.GetDependency().ListPlans().Returns(StaticStore.Plans.ToList()); + sutProvider.GetDependency().ListPlans().Returns(MockPlans.Plans.ToList()); sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); + + sutProvider.GetDependency() + .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id) + .Returns(new OrganizationSeatCounts { Users = 1, Sponsored = 0 }); var subscriberService = sutProvider.GetDependency(); // Set up subscriber service to return null for customer subscriberService - .GetCustomer(organization, Arg.Is(options => options.Expand.FirstOrDefault() == "discount.coupon.applies_to")) + .GetCustomer(organization) .Returns((Customer)null); // Set up subscriber service to return null for subscription - subscriberService.GetSubscription(organization).Returns((Subscription)null); + subscriberService.GetSubscription(organization, Arg.Is(options => + options.Expand.Contains("discounts.coupon.applies_to"))).Returns((Subscription)null); var metadata = await sutProvider.Sut.GetMetadata(organizationId); Assert.NotNull(metadata); Assert.False(metadata!.IsOnSecretsManagerStandalone); - Assert.False(metadata.HasSubscription); - Assert.False(metadata.IsSubscriptionUnpaid); - Assert.False(metadata.HasOpenInvoice); - Assert.False(metadata.IsSubscriptionCanceled); - Assert.Null(metadata.InvoiceDueDate); - Assert.Null(metadata.InvoiceCreatedDate); - Assert.Null(metadata.SubPeriodEndDate); + Assert.Equal(1, metadata.OrganizationOccupiedSeats); } #endregion + + #region Finalize - Trial Settings + + [Theory, BitAutoData] + public async Task NoPaymentMethodAndTrialPeriod_SetsMissingPaymentMethodCancelBehavior( + Organization organization, + SutProvider sutProvider) + { + // Arrange + var plan = MockPlans.Get(PlanType.TeamsAnnually); + organization.PlanType = PlanType.TeamsAnnually; + organization.GatewayCustomerId = "cus_test123"; + organization.GatewaySubscriptionId = null; + + var subscriptionSetup = new SubscriptionSetup + { + PlanType = PlanType.TeamsAnnually, + PasswordManagerOptions = new SubscriptionSetup.PasswordManager + { + Seats = 5, + Storage = null, + PremiumAccess = false + }, + SecretsManagerOptions = null, + SkipTrial = false + }; + + var sale = new OrganizationSale + { + Organization = organization, + SubscriptionSetup = subscriptionSetup + }; + + sutProvider.GetDependency() + .GetPlanOrThrow(PlanType.TeamsAnnually) + .Returns(plan); + + sutProvider.GetDependency() + .Run(organization) + .Returns(false); + + var customer = new Customer + { + Id = "cus_test123", + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }; + + sutProvider.GetDependency() + .GetCustomerOrThrow(organization, Arg.Any()) + .Returns(customer); + + SubscriptionCreateOptions capturedOptions = null; + sutProvider.GetDependency() + .CreateSubscriptionAsync(Arg.Do(options => capturedOptions = options)) + .Returns(new Subscription + { + Id = "sub_test123", + Status = StripeConstants.SubscriptionStatus.Trialing + }); + + sutProvider.GetDependency() + .ReplaceAsync(organization) + .Returns(Task.CompletedTask); + + // Act + await sutProvider.Sut.Finalize(sale); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .CreateSubscriptionAsync(Arg.Any()); + + Assert.NotNull(capturedOptions); + Assert.Equal(7, capturedOptions.TrialPeriodDays); + Assert.NotNull(capturedOptions.TrialSettings); + Assert.NotNull(capturedOptions.TrialSettings.EndBehavior); + Assert.Equal("cancel", capturedOptions.TrialSettings.EndBehavior.MissingPaymentMethod); + } + + [Theory, BitAutoData] + public async Task NoPaymentMethodButNoTrial_DoesNotSetMissingPaymentMethodBehavior( + Organization organization, + SutProvider sutProvider) + { + // Arrange + var plan = MockPlans.Get(PlanType.TeamsAnnually); + organization.PlanType = PlanType.TeamsAnnually; + organization.GatewayCustomerId = "cus_test123"; + organization.GatewaySubscriptionId = null; + + var subscriptionSetup = new SubscriptionSetup + { + PlanType = PlanType.TeamsAnnually, + PasswordManagerOptions = new SubscriptionSetup.PasswordManager + { + Seats = 5, + Storage = null, + PremiumAccess = false + }, + SecretsManagerOptions = null, + SkipTrial = true // This will result in TrialPeriodDays = 0 + }; + + var sale = new OrganizationSale + { + Organization = organization, + SubscriptionSetup = subscriptionSetup + }; + + sutProvider.GetDependency() + .GetPlanOrThrow(PlanType.TeamsAnnually) + .Returns(plan); + + sutProvider.GetDependency() + .Run(organization) + .Returns(false); + + var customer = new Customer + { + Id = "cus_test123", + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }; + + sutProvider.GetDependency() + .GetCustomerOrThrow(organization, Arg.Any()) + .Returns(customer); + + SubscriptionCreateOptions capturedOptions = null; + sutProvider.GetDependency() + .CreateSubscriptionAsync(Arg.Do(options => capturedOptions = options)) + .Returns(new Subscription + { + Id = "sub_test123", + Status = StripeConstants.SubscriptionStatus.Active + }); + + sutProvider.GetDependency() + .ReplaceAsync(organization) + .Returns(Task.CompletedTask); + + // Act + await sutProvider.Sut.Finalize(sale); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .CreateSubscriptionAsync(Arg.Any()); + + Assert.NotNull(capturedOptions); + Assert.Equal(0, capturedOptions.TrialPeriodDays); + Assert.Null(capturedOptions.TrialSettings); + } + + [Theory, BitAutoData] + public async Task HasPaymentMethodAndTrialPeriod_DoesNotSetMissingPaymentMethodBehavior( + Organization organization, + SutProvider sutProvider) + { + // Arrange + var plan = MockPlans.Get(PlanType.TeamsAnnually); + organization.PlanType = PlanType.TeamsAnnually; + organization.GatewayCustomerId = "cus_test123"; + organization.GatewaySubscriptionId = null; + + var subscriptionSetup = new SubscriptionSetup + { + PlanType = PlanType.TeamsAnnually, + PasswordManagerOptions = new SubscriptionSetup.PasswordManager + { + Seats = 5, + Storage = null, + PremiumAccess = false + }, + SecretsManagerOptions = null, + SkipTrial = false + }; + + var sale = new OrganizationSale + { + Organization = organization, + SubscriptionSetup = subscriptionSetup + }; + + sutProvider.GetDependency() + .GetPlanOrThrow(PlanType.TeamsAnnually) + .Returns(plan); + + sutProvider.GetDependency() + .Run(organization) + .Returns(true); // Has payment method + + var customer = new Customer + { + Id = "cus_test123", + Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } + }; + + sutProvider.GetDependency() + .GetCustomerOrThrow(organization, Arg.Any()) + .Returns(customer); + + SubscriptionCreateOptions capturedOptions = null; + sutProvider.GetDependency() + .CreateSubscriptionAsync(Arg.Do(options => capturedOptions = options)) + .Returns(new Subscription + { + Id = "sub_test123", + Status = StripeConstants.SubscriptionStatus.Trialing + }); + + sutProvider.GetDependency() + .ReplaceAsync(organization) + .Returns(Task.CompletedTask); + + // Act + await sutProvider.Sut.Finalize(sale); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .CreateSubscriptionAsync(Arg.Any()); + + Assert.NotNull(capturedOptions); + Assert.Equal(7, capturedOptions.TrialPeriodDays); + Assert.Null(capturedOptions.TrialSettings); + } + + #endregion + + [Theory, BitAutoData] + public async Task UpdateOrganizationNameAndEmail_UpdatesStripeCustomer( + Organization organization, + SutProvider sutProvider) + { + organization.Name = "Short name"; + + CustomerUpdateOptions capturedOptions = null; + sutProvider.GetDependency() + .UpdateCustomerAsync( + Arg.Is(id => id == organization.GatewayCustomerId), + Arg.Do(options => capturedOptions = options)) + .Returns(new Customer()); + + // Act + await sutProvider.Sut.UpdateOrganizationNameAndEmail(organization); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .UpdateCustomerAsync( + organization.GatewayCustomerId, + Arg.Any()); + + Assert.NotNull(capturedOptions); + Assert.Equal(organization.BillingEmail, capturedOptions.Email); + Assert.Equal(organization.DisplayName(), capturedOptions.Description); + Assert.NotNull(capturedOptions.InvoiceSettings); + Assert.NotNull(capturedOptions.InvoiceSettings.CustomFields); + Assert.Single(capturedOptions.InvoiceSettings.CustomFields); + + var customField = capturedOptions.InvoiceSettings.CustomFields.First(); + Assert.Equal(organization.SubscriberType(), customField.Name); + Assert.Equal(organization.DisplayName(), customField.Value); + } + + [Theory, BitAutoData] + public async Task UpdateOrganizationNameAndEmail_WhenNameIsLong_UsesFullName( + Organization organization, + SutProvider sutProvider) + { + // Arrange + var longName = "This is a very long organization name that exceeds thirty characters"; + organization.Name = longName; + + CustomerUpdateOptions capturedOptions = null; + sutProvider.GetDependency() + .UpdateCustomerAsync( + Arg.Is(id => id == organization.GatewayCustomerId), + Arg.Do(options => capturedOptions = options)) + .Returns(new Customer()); + + // Act + await sutProvider.Sut.UpdateOrganizationNameAndEmail(organization); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .UpdateCustomerAsync( + organization.GatewayCustomerId, + Arg.Any()); + + Assert.NotNull(capturedOptions); + Assert.NotNull(capturedOptions.InvoiceSettings); + Assert.NotNull(capturedOptions.InvoiceSettings.CustomFields); + + var customField = capturedOptions.InvoiceSettings.CustomFields.First(); + Assert.Equal(longName, customField.Value); + } + + [Theory, BitAutoData] + public async Task UpdateOrganizationNameAndEmail_WhenGatewayCustomerIdIsNull_LogsWarningAndReturns( + Organization organization, + SutProvider sutProvider) + { + // Arrange + organization.GatewayCustomerId = null; + organization.Name = "Test Organization"; + organization.BillingEmail = "billing@example.com"; + var stripeAdapter = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.UpdateOrganizationNameAndEmail(organization); + + // Assert + await stripeAdapter.DidNotReceive().UpdateCustomerAsync( + Arg.Any(), + Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateOrganizationNameAndEmail_WhenGatewayCustomerIdIsEmpty_LogsWarningAndReturns( + Organization organization, + SutProvider sutProvider) + { + // Arrange + organization.GatewayCustomerId = ""; + organization.Name = "Test Organization"; + var stripeAdapter = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.UpdateOrganizationNameAndEmail(organization); + + // Assert + await stripeAdapter.DidNotReceive().UpdateCustomerAsync( + Arg.Any(), + Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateOrganizationNameAndEmail_WhenNameIsNull_LogsWarningAndReturns( + Organization organization, + SutProvider sutProvider) + { + // Arrange + organization.Name = null; + organization.GatewayCustomerId = "cus_test123"; + var stripeAdapter = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.UpdateOrganizationNameAndEmail(organization); + + // Assert + await stripeAdapter.DidNotReceive().UpdateCustomerAsync( + Arg.Any(), + Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateOrganizationNameAndEmail_WhenNameIsEmpty_LogsWarningAndReturns( + Organization organization, + SutProvider sutProvider) + { + // Arrange + organization.Name = ""; + organization.GatewayCustomerId = "cus_test123"; + var stripeAdapter = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.UpdateOrganizationNameAndEmail(organization); + + // Assert + await stripeAdapter.DidNotReceive().UpdateCustomerAsync( + Arg.Any(), + Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateOrganizationNameAndEmail_WhenBillingEmailIsNull_UpdatesWithNull( + Organization organization, + SutProvider sutProvider) + { + // Arrange + organization.Name = "Test Organization"; + organization.BillingEmail = null; + organization.GatewayCustomerId = "cus_test123"; + var stripeAdapter = sutProvider.GetDependency(); + + // Act + await sutProvider.Sut.UpdateOrganizationNameAndEmail(organization); + + // Assert + await stripeAdapter.Received(1).UpdateCustomerAsync( + organization.GatewayCustomerId, + Arg.Is(options => + options.Email == null && + options.Description == organization.Name)); + } } diff --git a/test/Core.Test/Billing/Services/PaymentHistoryServiceTests.cs b/test/Core.Test/Billing/Services/PaymentHistoryServiceTests.cs index 06a408c5a8..cd4c5effbe 100644 --- a/test/Core.Test/Billing/Services/PaymentHistoryServiceTests.cs +++ b/test/Core.Test/Billing/Services/PaymentHistoryServiceTests.cs @@ -1,9 +1,9 @@ using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Implementations; using Bit.Core.Entities; using Bit.Core.Models.BitStripe; using Bit.Core.Repositories; -using Bit.Core.Services; using NSubstitute; using Stripe; using Xunit; @@ -19,7 +19,7 @@ public class PaymentHistoryServiceTests var subscriber = new Organization { GatewayCustomerId = "cus_id", GatewaySubscriptionId = "sub_id" }; var invoices = new List { new() { Id = "in_id" } }; var stripeAdapter = Substitute.For(); - stripeAdapter.InvoiceListAsync(Arg.Any()).Returns(invoices); + stripeAdapter.ListInvoicesAsync(Arg.Any()).Returns(invoices); var transactionRepository = Substitute.For(); var paymentHistoryService = new PaymentHistoryService(stripeAdapter, transactionRepository); @@ -29,7 +29,7 @@ public class PaymentHistoryServiceTests // Assert Assert.NotEmpty(result); Assert.Single(result); - await stripeAdapter.Received(1).InvoiceListAsync(Arg.Any()); + await stripeAdapter.Received(1).ListInvoicesAsync(Arg.Any()); } [Fact] diff --git a/test/Core.Test/Billing/Services/StripePaymentServiceTests.cs b/test/Core.Test/Billing/Services/StripePaymentServiceTests.cs new file mode 100644 index 0000000000..73f28113ca --- /dev/null +++ b/test/Core.Test/Billing/Services/StripePaymentServiceTests.cs @@ -0,0 +1,411 @@ +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Services; +using Bit.Core.Billing.Services.Implementations; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Stripe; +using Xunit; + +namespace Bit.Core.Test.Services; + +[SutProviderCustomize] +public class StripePaymentServiceTests +{ + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_WithCustomerDiscount_ReturnsDiscountFromCustomer( + SutProvider sutProvider, + User subscriber) + { + // Arrange + subscriber.Gateway = GatewayType.Stripe; + subscriber.GatewayCustomerId = "cus_test123"; + subscriber.GatewaySubscriptionId = "sub_test123"; + + var customerDiscount = new Discount + { + Coupon = new Coupon + { + Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, + PercentOff = 20m, + AmountOff = 1400 + }, + End = null + }; + + var subscription = new Subscription + { + Id = "sub_test123", + Status = "active", + CollectionMethod = "charge_automatically", + Customer = new Customer + { + Discount = customerDiscount + }, + Discounts = new List(), // Empty list + Items = new StripeList { Data = [] } + }; + + sutProvider.GetDependency() + .GetSubscriptionAsync( + subscriber.GatewaySubscriptionId, + Arg.Any()) + .Returns(subscription); + + // Act + var result = await sutProvider.Sut.GetSubscriptionAsync(subscriber); + + // Assert + Assert.NotNull(result.CustomerDiscount); + Assert.Equal(StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, result.CustomerDiscount.Id); + Assert.Equal(20m, result.CustomerDiscount.PercentOff); + Assert.Equal(14.00m, result.CustomerDiscount.AmountOff); // Converted from cents + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_WithoutCustomerDiscount_FallsBackToSubscriptionDiscounts( + SutProvider sutProvider, + User subscriber) + { + // Arrange + subscriber.Gateway = GatewayType.Stripe; + subscriber.GatewayCustomerId = "cus_test123"; + subscriber.GatewaySubscriptionId = "sub_test123"; + + var subscriptionDiscount = new Discount + { + Coupon = new Coupon + { + Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, + PercentOff = 15m, + AmountOff = null + }, + End = null + }; + + var subscription = new Subscription + { + Id = "sub_test123", + Status = "active", + CollectionMethod = "charge_automatically", + Customer = new Customer + { + Discount = null // No customer discount + }, + Discounts = new List { subscriptionDiscount }, + Items = new StripeList { Data = [] } + }; + + sutProvider.GetDependency() + .GetSubscriptionAsync( + subscriber.GatewaySubscriptionId, + Arg.Any()) + .Returns(subscription); + + // Act + var result = await sutProvider.Sut.GetSubscriptionAsync(subscriber); + + // Assert - Should use subscription discount as fallback + Assert.NotNull(result.CustomerDiscount); + Assert.Equal(StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, result.CustomerDiscount.Id); + Assert.Equal(15m, result.CustomerDiscount.PercentOff); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_WithBothDiscounts_PrefersCustomerDiscount( + SutProvider sutProvider, + User subscriber) + { + // Arrange + subscriber.Gateway = GatewayType.Stripe; + subscriber.GatewayCustomerId = "cus_test123"; + subscriber.GatewaySubscriptionId = "sub_test123"; + + var customerDiscount = new Discount + { + Coupon = new Coupon + { + Id = StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, + PercentOff = 25m + }, + End = null + }; + + var subscriptionDiscount = new Discount + { + Coupon = new Coupon + { + Id = "different-coupon-id", + PercentOff = 10m + }, + End = null + }; + + var subscription = new Subscription + { + Id = "sub_test123", + Status = "active", + CollectionMethod = "charge_automatically", + Customer = new Customer + { + Discount = customerDiscount // Should prefer this + }, + Discounts = new List { subscriptionDiscount }, + Items = new StripeList { Data = [] } + }; + + sutProvider.GetDependency() + .GetSubscriptionAsync( + subscriber.GatewaySubscriptionId, + Arg.Any()) + .Returns(subscription); + + // Act + var result = await sutProvider.Sut.GetSubscriptionAsync(subscriber); + + // Assert - Should prefer customer discount over subscription discount + Assert.NotNull(result.CustomerDiscount); + Assert.Equal(StripeConstants.CouponIDs.Milestone2SubscriptionDiscount, result.CustomerDiscount.Id); + Assert.Equal(25m, result.CustomerDiscount.PercentOff); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_WithNoDiscounts_ReturnsNullDiscount( + SutProvider sutProvider, + User subscriber) + { + // Arrange + subscriber.Gateway = GatewayType.Stripe; + subscriber.GatewayCustomerId = "cus_test123"; + subscriber.GatewaySubscriptionId = "sub_test123"; + + var subscription = new Subscription + { + Id = "sub_test123", + Status = "active", + CollectionMethod = "charge_automatically", + Customer = new Customer + { + Discount = null + }, + Discounts = new List(), // Empty list, no discounts + Items = new StripeList { Data = [] } + }; + + sutProvider.GetDependency() + .GetSubscriptionAsync( + subscriber.GatewaySubscriptionId, + Arg.Any()) + .Returns(subscription); + + // Act + var result = await sutProvider.Sut.GetSubscriptionAsync(subscriber); + + // Assert + Assert.Null(result.CustomerDiscount); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_WithMultipleSubscriptionDiscounts_SelectsFirstDiscount( + SutProvider sutProvider, + User subscriber) + { + // Arrange - Multiple subscription-level discounts, no customer discount + subscriber.Gateway = GatewayType.Stripe; + subscriber.GatewayCustomerId = "cus_test123"; + subscriber.GatewaySubscriptionId = "sub_test123"; + + var firstDiscount = new Discount + { + Coupon = new Coupon + { + Id = "coupon-10-percent", + PercentOff = 10m + }, + End = null + }; + + var secondDiscount = new Discount + { + Coupon = new Coupon + { + Id = "coupon-20-percent", + PercentOff = 20m + }, + End = null + }; + + var subscription = new Subscription + { + Id = "sub_test123", + Status = "active", + CollectionMethod = "charge_automatically", + Customer = new Customer + { + Discount = null // No customer discount + }, + // Multiple subscription discounts - FirstOrDefault() should select the first one + Discounts = new List { firstDiscount, secondDiscount }, + Items = new StripeList { Data = [] } + }; + + sutProvider.GetDependency() + .GetSubscriptionAsync( + subscriber.GatewaySubscriptionId, + Arg.Any()) + .Returns(subscription); + + // Act + var result = await sutProvider.Sut.GetSubscriptionAsync(subscriber); + + // Assert - Should select the first discount from the list (FirstOrDefault() behavior) + Assert.NotNull(result.CustomerDiscount); + Assert.Equal("coupon-10-percent", result.CustomerDiscount.Id); + Assert.Equal(10m, result.CustomerDiscount.PercentOff); + // Verify the second discount was not selected + Assert.NotEqual("coupon-20-percent", result.CustomerDiscount.Id); + Assert.NotEqual(20m, result.CustomerDiscount.PercentOff); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_WithNullCustomer_HandlesGracefully( + SutProvider sutProvider, + User subscriber) + { + // Arrange - Subscription with null Customer (defensive null check scenario) + subscriber.Gateway = GatewayType.Stripe; + subscriber.GatewayCustomerId = "cus_test123"; + subscriber.GatewaySubscriptionId = "sub_test123"; + + var subscription = new Subscription + { + Id = "sub_test123", + Status = "active", + CollectionMethod = "charge_automatically", + Customer = null, // Customer not expanded or null + Discounts = new List(), // Empty discounts + Items = new StripeList { Data = [] } + }; + + sutProvider.GetDependency() + .GetSubscriptionAsync( + subscriber.GatewaySubscriptionId, + Arg.Any()) + .Returns(subscription); + + // Act + var result = await sutProvider.Sut.GetSubscriptionAsync(subscriber); + + // Assert - Should handle null Customer gracefully without throwing NullReferenceException + Assert.Null(result.CustomerDiscount); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_WithNullDiscounts_HandlesGracefully( + SutProvider sutProvider, + User subscriber) + { + // Arrange - Subscription with null Discounts (defensive null check scenario) + subscriber.Gateway = GatewayType.Stripe; + subscriber.GatewayCustomerId = "cus_test123"; + subscriber.GatewaySubscriptionId = "sub_test123"; + + var subscription = new Subscription + { + Id = "sub_test123", + Status = "active", + CollectionMethod = "charge_automatically", + Customer = new Customer + { + Discount = null // No customer discount + }, + Discounts = null, // Discounts not expanded or null + Items = new StripeList { Data = [] } + }; + + sutProvider.GetDependency() + .GetSubscriptionAsync( + subscriber.GatewaySubscriptionId, + Arg.Any()) + .Returns(subscription); + + // Act + var result = await sutProvider.Sut.GetSubscriptionAsync(subscriber); + + // Assert - Should handle null Discounts gracefully without throwing NullReferenceException + Assert.Null(result.CustomerDiscount); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_VerifiesCorrectExpandOptions( + SutProvider sutProvider, + User subscriber) + { + // Arrange + subscriber.Gateway = GatewayType.Stripe; + subscriber.GatewayCustomerId = "cus_test123"; + subscriber.GatewaySubscriptionId = "sub_test123"; + + var subscription = new Subscription + { + Id = "sub_test123", + Status = "active", + CollectionMethod = "charge_automatically", + Customer = new Customer { Discount = null }, + Discounts = new List(), // Empty list + Items = new StripeList { Data = [] } + }; + + var stripeAdapter = sutProvider.GetDependency(); + stripeAdapter + .GetSubscriptionAsync( + Arg.Any(), + Arg.Any()) + .Returns(subscription); + + // Act + await sutProvider.Sut.GetSubscriptionAsync(subscriber); + + // Assert - Verify expand options are correct + await stripeAdapter.Received(1).GetSubscriptionAsync( + subscriber.GatewaySubscriptionId, + Arg.Is(o => + o.Expand.Contains("customer.discount.coupon.applies_to") && + o.Expand.Contains("discounts.coupon.applies_to") && + o.Expand.Contains("test_clock"))); + } + + [Theory] + [BitAutoData] + public async Task GetSubscriptionAsync_WithEmptyGatewaySubscriptionId_ReturnsEmptySubscriptionInfo( + SutProvider sutProvider, + User subscriber) + { + // Arrange + subscriber.GatewaySubscriptionId = null; + + // Act + var result = await sutProvider.Sut.GetSubscriptionAsync(subscriber); + + // Assert + Assert.NotNull(result); + Assert.Null(result.Subscription); + Assert.Null(result.CustomerDiscount); + Assert.Null(result.UpcomingInvoice); + + // Verify no Stripe API calls were made + await sutProvider.GetDependency() + .DidNotReceive() + .GetSubscriptionAsync(Arg.Any(), Arg.Any()); + } +} diff --git a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs index 2569ffff00..2f938065e5 100644 --- a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs +++ b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs @@ -3,10 +3,10 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; +using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Implementations; using Bit.Core.Billing.Tax.Models; using Bit.Core.Enums; -using Bit.Core.Services; using Bit.Core.Settings; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; @@ -44,7 +44,7 @@ public class SubscriberServiceTests var stripeAdapter = sutProvider.GetDependency(); stripeAdapter - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .Returns(subscription); await ThrowsBillingExceptionAsync(() => @@ -52,11 +52,11 @@ public class SubscriberServiceTests await stripeAdapter .DidNotReceiveWithAnyArgs() - .SubscriptionUpdateAsync(Arg.Any(), Arg.Any()); + .UpdateSubscriptionAsync(Arg.Any(), Arg.Any()); await stripeAdapter .DidNotReceiveWithAnyArgs() - .SubscriptionCancelAsync(Arg.Any(), Arg.Any()); + .CancelSubscriptionAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] @@ -81,7 +81,7 @@ public class SubscriberServiceTests var stripeAdapter = sutProvider.GetDependency(); stripeAdapter - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .Returns(subscription); var offboardingSurveyResponse = new OffboardingSurveyResponse @@ -95,12 +95,12 @@ public class SubscriberServiceTests await stripeAdapter .Received(1) - .SubscriptionUpdateAsync(subscriptionId, Arg.Is( + .UpdateSubscriptionAsync(subscriptionId, Arg.Is( options => options.Metadata["cancellingUserId"] == userId.ToString())); await stripeAdapter .Received(1) - .SubscriptionCancelAsync(subscriptionId, Arg.Is(options => + .CancelSubscriptionAsync(subscriptionId, Arg.Is(options => options.CancellationDetails.Comment == offboardingSurveyResponse.Feedback && options.CancellationDetails.Feedback == offboardingSurveyResponse.Reason)); } @@ -127,7 +127,7 @@ public class SubscriberServiceTests var stripeAdapter = sutProvider.GetDependency(); stripeAdapter - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .Returns(subscription); var offboardingSurveyResponse = new OffboardingSurveyResponse @@ -141,11 +141,11 @@ public class SubscriberServiceTests await stripeAdapter .DidNotReceiveWithAnyArgs() - .SubscriptionUpdateAsync(Arg.Any(), Arg.Any()); + .UpdateSubscriptionAsync(Arg.Any(), Arg.Any()); await stripeAdapter .Received(1) - .SubscriptionCancelAsync(subscriptionId, Arg.Is(options => + .CancelSubscriptionAsync(subscriptionId, Arg.Is(options => options.CancellationDetails.Comment == offboardingSurveyResponse.Feedback && options.CancellationDetails.Feedback == offboardingSurveyResponse.Reason)); } @@ -170,7 +170,7 @@ public class SubscriberServiceTests var stripeAdapter = sutProvider.GetDependency(); stripeAdapter - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .Returns(subscription); var offboardingSurveyResponse = new OffboardingSurveyResponse @@ -184,7 +184,7 @@ public class SubscriberServiceTests await stripeAdapter .Received(1) - .SubscriptionUpdateAsync(subscriptionId, Arg.Is(options => + .UpdateSubscriptionAsync(subscriptionId, Arg.Is(options => options.CancelAtPeriodEnd == true && options.CancellationDetails.Comment == offboardingSurveyResponse.Feedback && options.CancellationDetails.Feedback == offboardingSurveyResponse.Reason && @@ -192,7 +192,7 @@ public class SubscriberServiceTests await stripeAdapter .DidNotReceiveWithAnyArgs() - .SubscriptionCancelAsync(Arg.Any(), Arg.Any()); + .CancelSubscriptionAsync(Arg.Any(), Arg.Any()); } #endregion @@ -223,7 +223,7 @@ public class SubscriberServiceTests SutProvider sutProvider) { sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId) + .GetCustomerAsync(organization.GatewayCustomerId) .ReturnsNull(); var customer = await sutProvider.Sut.GetCustomer(organization); @@ -237,7 +237,7 @@ public class SubscriberServiceTests SutProvider sutProvider) { sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId) + .GetCustomerAsync(organization.GatewayCustomerId) .ThrowsAsync(); var customer = await sutProvider.Sut.GetCustomer(organization); @@ -253,7 +253,7 @@ public class SubscriberServiceTests var customer = new Customer(); sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId) + .GetCustomerAsync(organization.GatewayCustomerId) .Returns(customer); var gotCustomer = await sutProvider.Sut.GetCustomer(organization); @@ -287,7 +287,7 @@ public class SubscriberServiceTests SutProvider sutProvider) { sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId) + .GetCustomerAsync(organization.GatewayCustomerId) .ReturnsNull(); await ThrowsBillingExceptionAsync(async () => await sutProvider.Sut.GetCustomerOrThrow(organization)); @@ -301,7 +301,7 @@ public class SubscriberServiceTests var stripeException = new StripeException(); sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId) + .GetCustomerAsync(organization.GatewayCustomerId) .ThrowsAsync(stripeException); await ThrowsBillingExceptionAsync( @@ -318,7 +318,7 @@ public class SubscriberServiceTests var customer = new Customer(); sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId) + .GetCustomerAsync(organization.GatewayCustomerId) .Returns(customer); var gotCustomer = await sutProvider.Sut.GetCustomerOrThrow(organization); @@ -328,157 +328,6 @@ public class SubscriberServiceTests #endregion - #region GetPaymentMethod - - [Theory, BitAutoData] - public async Task GetPaymentMethod_NullSubscriber_ThrowsArgumentNullException( - SutProvider sutProvider) => - await Assert.ThrowsAsync(() => sutProvider.Sut.GetPaymentSource(null)); - - [Theory, BitAutoData] - public async Task GetPaymentMethod_WithNegativeStripeAccountBalance_ReturnsCorrectAccountCreditAmount(Organization organization, - SutProvider sutProvider) - { - // Arrange - // Stripe reports balance in cents as a negative number for credit - const int stripeAccountBalance = -593; // $5.93 credit (negative cents) - const decimal creditAmount = 5.93M; // Same value in dollars - - - var customer = new Customer - { - Balance = stripeAccountBalance, - Subscriptions = new StripeList() - { - Data = - [new Subscription { Id = organization.GatewaySubscriptionId, Status = "active" }] - }, - InvoiceSettings = new CustomerInvoiceSettings - { - DefaultPaymentMethod = new PaymentMethod - { - Type = StripeConstants.PaymentMethodTypes.USBankAccount, - UsBankAccount = new PaymentMethodUsBankAccount { BankName = "Chase", Last4 = "9999" } - } - } - }; - sutProvider.GetDependency().CustomerGetAsync(organization.GatewayCustomerId, - Arg.Is(options => options.Expand.Contains("default_source") && - options.Expand.Contains("invoice_settings.default_payment_method") - && options.Expand.Contains("subscriptions") - && options.Expand.Contains("tax_ids"))) - .Returns(customer); - - // Act - var result = await sutProvider.Sut.GetPaymentMethod(organization); - - // Assert - Assert.NotNull(result); - Assert.Equal(creditAmount, result.AccountCredit); - await sutProvider.GetDependency().Received(1).CustomerGetAsync( - organization.GatewayCustomerId, - Arg.Is(options => - options.Expand.Contains("default_source") && - options.Expand.Contains("invoice_settings.default_payment_method") && - options.Expand.Contains("subscriptions") && - options.Expand.Contains("tax_ids"))); - - } - - [Theory, BitAutoData] - public async Task GetPaymentMethod_WithZeroStripeAccountBalance_ReturnsCorrectAccountCreditAmount( - Organization organization, SutProvider sutProvider) - { - // Arrange - const int stripeAccountBalance = 0; - - var customer = new Customer - { - Balance = stripeAccountBalance, - Subscriptions = new StripeList() - { - Data = - [new Subscription { Id = organization.GatewaySubscriptionId, Status = "active" }] - }, - InvoiceSettings = new CustomerInvoiceSettings - { - DefaultPaymentMethod = new PaymentMethod - { - Type = StripeConstants.PaymentMethodTypes.USBankAccount, - UsBankAccount = new PaymentMethodUsBankAccount { BankName = "Chase", Last4 = "9999" } - } - } - }; - sutProvider.GetDependency().CustomerGetAsync(organization.GatewayCustomerId, - Arg.Is(options => options.Expand.Contains("default_source") && - options.Expand.Contains("invoice_settings.default_payment_method") - && options.Expand.Contains("subscriptions") - && options.Expand.Contains("tax_ids"))) - .Returns(customer); - - // Act - var result = await sutProvider.Sut.GetPaymentMethod(organization); - - // Assert - Assert.NotNull(result); - Assert.Equal(0, result.AccountCredit); - await sutProvider.GetDependency().Received(1).CustomerGetAsync( - organization.GatewayCustomerId, - Arg.Is(options => - options.Expand.Contains("default_source") && - options.Expand.Contains("invoice_settings.default_payment_method") && - options.Expand.Contains("subscriptions") && - options.Expand.Contains("tax_ids"))); - } - - [Theory, BitAutoData] - public async Task GetPaymentMethod_WithPositiveStripeAccountBalance_ReturnsCorrectAccountCreditAmount( - Organization organization, SutProvider sutProvider) - { - // Arrange - const int stripeAccountBalance = 593; // $5.93 charge balance - const decimal accountBalance = -5.93M; // account balance - var customer = new Customer - { - Balance = stripeAccountBalance, - Subscriptions = new StripeList() - { - Data = - [new Subscription { Id = organization.GatewaySubscriptionId, Status = "active" }] - }, - InvoiceSettings = new CustomerInvoiceSettings - { - DefaultPaymentMethod = new PaymentMethod - { - Type = StripeConstants.PaymentMethodTypes.USBankAccount, - UsBankAccount = new PaymentMethodUsBankAccount { BankName = "Chase", Last4 = "9999" } - } - } - }; - sutProvider.GetDependency().CustomerGetAsync(organization.GatewayCustomerId, - Arg.Is(options => options.Expand.Contains("default_source") && - options.Expand.Contains("invoice_settings.default_payment_method") - && options.Expand.Contains("subscriptions") - && options.Expand.Contains("tax_ids"))) - .Returns(customer); - - // Act - var result = await sutProvider.Sut.GetPaymentMethod(organization); - - // Assert - Assert.NotNull(result); - Assert.Equal(accountBalance, result.AccountCredit); - await sutProvider.GetDependency().Received(1).CustomerGetAsync( - organization.GatewayCustomerId, - Arg.Is(options => - options.Expand.Contains("default_source") && - options.Expand.Contains("invoice_settings.default_payment_method") && - options.Expand.Contains("subscriptions") && - options.Expand.Contains("tax_ids"))); - - } - #endregion - #region GetPaymentSource [Theory, BitAutoData] @@ -502,7 +351,7 @@ public class SubscriberServiceTests } }; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId, + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Expand.Contains("default_source") && options.Expand.Contains("invoice_settings.default_payment_method"))) @@ -539,7 +388,7 @@ public class SubscriberServiceTests } }; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId, + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Expand.Contains("default_source") && options.Expand.Contains("invoice_settings.default_payment_method"))) @@ -593,7 +442,7 @@ public class SubscriberServiceTests } }; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId, + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Expand.Contains("default_source") && options.Expand.Contains("invoice_settings.default_payment_method"))) @@ -629,7 +478,7 @@ public class SubscriberServiceTests } }; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId, + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Expand.Contains("default_source") && options.Expand.Contains("invoice_settings.default_payment_method"))) @@ -649,7 +498,7 @@ public class SubscriberServiceTests { var customer = new Customer { Id = provider.GatewayCustomerId }; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId, + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId, Arg.Is(options => options.Expand.Contains("default_source") && options.Expand.Contains( "invoice_settings.default_payment_method"))) @@ -672,7 +521,7 @@ public class SubscriberServiceTests sutProvider.GetDependency().GetSetupIntentIdForSubscriber(provider.Id).Returns(setupIntent.Id); - sutProvider.GetDependency().SetupIntentGet(setupIntent.Id, + sutProvider.GetDependency().GetSetupIntentAsync(setupIntent.Id, Arg.Is(options => options.Expand.Contains("payment_method"))).Returns(setupIntent); var paymentMethod = await sutProvider.Sut.GetPaymentSource(provider); @@ -692,7 +541,7 @@ public class SubscriberServiceTests DefaultSource = new BankAccount { Status = "verified", BankName = "Chase", Last4 = "9999" } }; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId, + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId, Arg.Is(options => options.Expand.Contains("default_source") && options.Expand.Contains( "invoice_settings.default_payment_method"))) @@ -715,7 +564,7 @@ public class SubscriberServiceTests DefaultSource = new Card { Brand = "Visa", Last4 = "9999", ExpMonth = 9, ExpYear = 2028 } }; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId, + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId, Arg.Is(options => options.Expand.Contains("default_source") && options.Expand.Contains( "invoice_settings.default_payment_method"))) @@ -747,7 +596,7 @@ public class SubscriberServiceTests } }; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId, + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Expand.Contains("default_source") && options.Expand.Contains("invoice_settings.default_payment_method"))) @@ -787,7 +636,7 @@ public class SubscriberServiceTests SutProvider sutProvider) { sutProvider.GetDependency() - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .ReturnsNull(); var subscription = await sutProvider.Sut.GetSubscription(organization); @@ -801,7 +650,7 @@ public class SubscriberServiceTests SutProvider sutProvider) { sutProvider.GetDependency() - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .ThrowsAsync(); var subscription = await sutProvider.Sut.GetSubscription(organization); @@ -817,7 +666,7 @@ public class SubscriberServiceTests var subscription = new Subscription(); sutProvider.GetDependency() - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .Returns(subscription); var gotSubscription = await sutProvider.Sut.GetSubscription(organization); @@ -849,7 +698,7 @@ public class SubscriberServiceTests SutProvider sutProvider) { sutProvider.GetDependency() - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .ReturnsNull(); await ThrowsBillingExceptionAsync(async () => await sutProvider.Sut.GetSubscriptionOrThrow(organization)); @@ -863,7 +712,7 @@ public class SubscriberServiceTests var stripeException = new StripeException(); sutProvider.GetDependency() - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .ThrowsAsync(stripeException); await ThrowsBillingExceptionAsync( @@ -880,7 +729,7 @@ public class SubscriberServiceTests var subscription = new Subscription(); sutProvider.GetDependency() - .SubscriptionGetAsync(organization.GatewaySubscriptionId) + .GetSubscriptionAsync(organization.GatewaySubscriptionId) .Returns(subscription); var gotSubscription = await sutProvider.Sut.GetSubscriptionOrThrow(organization); @@ -889,65 +738,6 @@ public class SubscriberServiceTests } #endregion - #region GetTaxInformation - - [Theory, BitAutoData] - public async Task GetTaxInformation_NullSubscriber_ThrowsArgumentNullException( - SutProvider sutProvider) => - await Assert.ThrowsAsync(() => sutProvider.Sut.GetTaxInformation(null)); - - [Theory, BitAutoData] - public async Task GetTaxInformation_NullAddress_ReturnsNull( - Organization organization, - SutProvider sutProvider) - { - sutProvider.GetDependency().CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) - .Returns(new Customer()); - - var taxInformation = await sutProvider.Sut.GetTaxInformation(organization); - - Assert.Null(taxInformation); - } - - [Theory, BitAutoData] - public async Task GetTaxInformation_Success( - Organization organization, - SutProvider sutProvider) - { - var address = new Address - { - Country = "US", - PostalCode = "12345", - Line1 = "123 Example St.", - Line2 = "Unit 1", - City = "Example Town", - State = "NY" - }; - - sutProvider.GetDependency().CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) - .Returns(new Customer - { - Address = address, - TaxIds = new StripeList - { - Data = [new TaxId { Value = "tax_id" }] - } - }); - - var taxInformation = await sutProvider.Sut.GetTaxInformation(organization); - - Assert.NotNull(taxInformation); - Assert.Equal(address.Country, taxInformation.Country); - Assert.Equal(address.PostalCode, taxInformation.PostalCode); - Assert.Equal("tax_id", taxInformation.TaxId); - Assert.Equal(address.Line1, taxInformation.Line1); - Assert.Equal(address.Line2, taxInformation.Line2); - Assert.Equal(address.City, taxInformation.City); - Assert.Equal(address.State, taxInformation.State); - } - - #endregion - #region RemovePaymentMethod [Theory, BitAutoData] public async Task RemovePaymentMethod_NullSubscriber_ThrowsArgumentNullException( @@ -970,7 +760,7 @@ public class SubscriberServiceTests }; sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) + .GetCustomerAsync(organization.GatewayCustomerId, Arg.Any()) .Returns(stripeCustomer); var (braintreeGateway, customerGateway, paymentMethodGateway) = SetupBraintree(sutProvider.GetDependency()); @@ -1005,7 +795,7 @@ public class SubscriberServiceTests }; sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) + .GetCustomerAsync(organization.GatewayCustomerId, Arg.Any()) .Returns(stripeCustomer); var (_, customerGateway, paymentMethodGateway) = SetupBraintree(sutProvider.GetDependency()); @@ -1042,7 +832,7 @@ public class SubscriberServiceTests }; sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) + .GetCustomerAsync(organization.GatewayCustomerId, Arg.Any()) .Returns(stripeCustomer); var (_, customerGateway, paymentMethodGateway) = SetupBraintree(sutProvider.GetDependency()); @@ -1097,7 +887,7 @@ public class SubscriberServiceTests }; sutProvider.GetDependency() - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) + .GetCustomerAsync(organization.GatewayCustomerId, Arg.Any()) .Returns(stripeCustomer); var (_, customerGateway, paymentMethodGateway) = SetupBraintree(sutProvider.GetDependency()); @@ -1156,21 +946,21 @@ public class SubscriberServiceTests var stripeAdapter = sutProvider.GetDependency(); stripeAdapter - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) + .GetCustomerAsync(organization.GatewayCustomerId, Arg.Any()) .Returns(stripeCustomer); stripeAdapter - .PaymentMethodListAutoPagingAsync(Arg.Any()) + .ListPaymentMethodsAutoPagingAsync(Arg.Any()) .Returns(GetPaymentMethodsAsync(new List())); await sutProvider.Sut.RemovePaymentSource(organization); - await stripeAdapter.Received(1).BankAccountDeleteAsync(stripeCustomer.Id, bankAccountId); + await stripeAdapter.Received(1).DeleteBankAccountAsync(stripeCustomer.Id, bankAccountId); - await stripeAdapter.Received(1).CardDeleteAsync(stripeCustomer.Id, cardId); + await stripeAdapter.Received(1).DeleteCardAsync(stripeCustomer.Id, cardId); await stripeAdapter.DidNotReceiveWithAnyArgs() - .PaymentMethodDetachAsync(Arg.Any(), Arg.Any()); + .DetachPaymentMethodAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] @@ -1188,11 +978,11 @@ public class SubscriberServiceTests var stripeAdapter = sutProvider.GetDependency(); stripeAdapter - .CustomerGetAsync(organization.GatewayCustomerId, Arg.Any()) + .GetCustomerAsync(organization.GatewayCustomerId, Arg.Any()) .Returns(stripeCustomer); stripeAdapter - .PaymentMethodListAutoPagingAsync(Arg.Any()) + .ListPaymentMethodsAutoPagingAsync(Arg.Any()) .Returns(GetPaymentMethodsAsync(new List { new () @@ -1207,15 +997,15 @@ public class SubscriberServiceTests await sutProvider.Sut.RemovePaymentSource(organization); - await stripeAdapter.DidNotReceiveWithAnyArgs().BankAccountDeleteAsync(Arg.Any(), Arg.Any()); + await stripeAdapter.DidNotReceiveWithAnyArgs().DeleteBankAccountAsync(Arg.Any(), Arg.Any()); - await stripeAdapter.DidNotReceiveWithAnyArgs().CardDeleteAsync(Arg.Any(), Arg.Any()); + await stripeAdapter.DidNotReceiveWithAnyArgs().DeleteCardAsync(Arg.Any(), Arg.Any()); await stripeAdapter.Received(1) - .PaymentMethodDetachAsync(bankAccountId); + .DetachPaymentMethodAsync(bankAccountId); await stripeAdapter.Received(1) - .PaymentMethodDetachAsync(cardId); + .DetachPaymentMethodAsync(cardId); } private static async IAsyncEnumerable GetPaymentMethodsAsync( @@ -1260,7 +1050,7 @@ public class SubscriberServiceTests Provider provider, SutProvider sutProvider) { - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId) + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId) .Returns(new Customer()); await ThrowsBillingExceptionAsync(() => @@ -1272,7 +1062,7 @@ public class SubscriberServiceTests Provider provider, SutProvider sutProvider) { - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId) + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId) .Returns(new Customer()); await ThrowsBillingExceptionAsync(() => @@ -1286,10 +1076,10 @@ public class SubscriberServiceTests { var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId) + stripeAdapter.GetCustomerAsync(provider.GatewayCustomerId) .Returns(new Customer()); - stripeAdapter.SetupIntentList(Arg.Is(options => options.PaymentMethod == "TOKEN")) + stripeAdapter.ListSetupIntentsAsync(Arg.Is(options => options.PaymentMethod == "TOKEN")) .Returns([new SetupIntent(), new SetupIntent()]); await ThrowsBillingExceptionAsync(() => @@ -1303,7 +1093,7 @@ public class SubscriberServiceTests { var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerGetAsync( + stripeAdapter.GetCustomerAsync( provider.GatewayCustomerId, Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))) .Returns(new Customer @@ -1317,10 +1107,10 @@ public class SubscriberServiceTests var matchingSetupIntent = new SetupIntent { Id = "setup_intent_1" }; - stripeAdapter.SetupIntentList(Arg.Is(options => options.PaymentMethod == "TOKEN")) + stripeAdapter.ListSetupIntentsAsync(Arg.Is(options => options.PaymentMethod == "TOKEN")) .Returns([matchingSetupIntent]); - stripeAdapter.CustomerListPaymentMethods(provider.GatewayCustomerId).Returns([ + stripeAdapter.ListCustomerPaymentMethodsAsync(provider.GatewayCustomerId).Returns([ new PaymentMethod { Id = "payment_method_1" } ]); @@ -1329,12 +1119,12 @@ public class SubscriberServiceTests await sutProvider.GetDependency().Received(1).Set(provider.Id, "setup_intent_1"); - await stripeAdapter.DidNotReceive().SetupIntentCancel(Arg.Any(), + await stripeAdapter.DidNotReceive().CancelSetupIntentAsync(Arg.Any(), Arg.Any()); - await stripeAdapter.Received(1).PaymentMethodDetachAsync("payment_method_1"); + await stripeAdapter.Received(1).DetachPaymentMethodAsync("payment_method_1"); - await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is( + await stripeAdapter.Received(1).UpdateCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Metadata[Core.Billing.Utilities.BraintreeCustomerIdKey] == null)); } @@ -1345,7 +1135,7 @@ public class SubscriberServiceTests { var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerGetAsync( + stripeAdapter.GetCustomerAsync( provider.GatewayCustomerId, Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids")) ) @@ -1358,22 +1148,22 @@ public class SubscriberServiceTests } }); - stripeAdapter.CustomerListPaymentMethods(provider.GatewayCustomerId).Returns([ + stripeAdapter.ListCustomerPaymentMethodsAsync(provider.GatewayCustomerId).Returns([ new PaymentMethod { Id = "payment_method_1" } ]); await sutProvider.Sut.UpdatePaymentSource(provider, new TokenizedPaymentSource(PaymentMethodType.Card, "TOKEN")); - await stripeAdapter.DidNotReceive().SetupIntentCancel(Arg.Any(), + await stripeAdapter.DidNotReceive().CancelSetupIntentAsync(Arg.Any(), Arg.Any()); - await stripeAdapter.Received(1).PaymentMethodDetachAsync("payment_method_1"); + await stripeAdapter.Received(1).DetachPaymentMethodAsync("payment_method_1"); - await stripeAdapter.Received(1).PaymentMethodAttachAsync("TOKEN", Arg.Is( + await stripeAdapter.Received(1).AttachPaymentMethodAsync("TOKEN", Arg.Is( options => options.Customer == provider.GatewayCustomerId)); - await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is( + await stripeAdapter.Received(1).UpdateCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.InvoiceSettings.DefaultPaymentMethod == "TOKEN" && options.Metadata[Core.Billing.Utilities.BraintreeCustomerIdKey] == null)); @@ -1386,7 +1176,7 @@ public class SubscriberServiceTests { const string braintreeCustomerId = "braintree_customer_id"; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId) + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId) .Returns(new Customer { Id = provider.GatewayCustomerId, @@ -1412,7 +1202,7 @@ public class SubscriberServiceTests { const string braintreeCustomerId = "braintree_customer_id"; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId) + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId) .Returns(new Customer { Id = provider.GatewayCustomerId, @@ -1450,7 +1240,7 @@ public class SubscriberServiceTests { const string braintreeCustomerId = "braintree_customer_id"; - sutProvider.GetDependency().CustomerGetAsync( + sutProvider.GetDependency().GetCustomerAsync( provider.GatewayCustomerId, Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))) .Returns(new Customer @@ -1504,7 +1294,7 @@ public class SubscriberServiceTests { const string braintreeCustomerId = "braintree_customer_id"; - sutProvider.GetDependency().CustomerGetAsync( + sutProvider.GetDependency().GetCustomerAsync( provider.GatewayCustomerId, Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))) .Returns(new Customer @@ -1573,7 +1363,7 @@ public class SubscriberServiceTests { const string braintreeCustomerId = "braintree_customer_id"; - sutProvider.GetDependency().CustomerGetAsync(provider.GatewayCustomerId) + sutProvider.GetDependency().GetCustomerAsync(provider.GatewayCustomerId) .Returns(new Customer { Id = provider.GatewayCustomerId @@ -1605,7 +1395,7 @@ public class SubscriberServiceTests new TokenizedPaymentSource(PaymentMethodType.PayPal, "TOKEN"))); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CustomerUpdateAsync(Arg.Any(), Arg.Any()); + .UpdateCustomerAsync(Arg.Any(), Arg.Any()); } [Theory, BitAutoData] @@ -1615,7 +1405,7 @@ public class SubscriberServiceTests { const string braintreeCustomerId = "braintree_customer_id"; - sutProvider.GetDependency().CustomerGetAsync( + sutProvider.GetDependency().GetCustomerAsync( provider.GatewayCustomerId, Arg.Is(p => p.Expand.Contains("tax") || p.Expand.Contains("tax_ids"))) .Returns(new Customer @@ -1652,7 +1442,7 @@ public class SubscriberServiceTests await sutProvider.Sut.UpdatePaymentSource(provider, new TokenizedPaymentSource(PaymentMethodType.PayPal, "TOKEN")); - await sutProvider.GetDependency().Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, + await sutProvider.GetDependency().Received(1).UpdateCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Metadata[Core.Billing.Utilities.BraintreeCustomerIdKey] == braintreeCustomerId)); } @@ -1683,7 +1473,7 @@ public class SubscriberServiceTests var customer = new Customer { Id = provider.GatewayCustomerId, TaxIds = new StripeList { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] } }; - stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId, Arg.Is( + stripeAdapter.GetCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Expand.Contains("tax_ids"))).Returns(customer); var taxInformation = new TaxInformation( @@ -1697,7 +1487,7 @@ public class SubscriberServiceTests "NY"); sutProvider.GetDependency() - .CustomerUpdateAsync( + .UpdateCustomerAsync( Arg.Is(p => p == provider.GatewayCustomerId), Arg.Is(options => options.Address.Country == "US" && @@ -1732,12 +1522,12 @@ public class SubscriberServiceTests }); var subscription = new Subscription { Items = new StripeList() }; - sutProvider.GetDependency().SubscriptionGetAsync(Arg.Any()) + sutProvider.GetDependency().GetSubscriptionAsync(Arg.Any()) .Returns(subscription); await sutProvider.Sut.UpdateTaxInformation(provider, taxInformation); - await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is( + await stripeAdapter.Received(1).UpdateCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Address.Country == taxInformation.Country && options.Address.PostalCode == taxInformation.PostalCode && @@ -1746,13 +1536,13 @@ public class SubscriberServiceTests options.Address.City == taxInformation.City && options.Address.State == taxInformation.State)); - await stripeAdapter.Received(1).TaxIdDeleteAsync(provider.GatewayCustomerId, "tax_id_1"); + await stripeAdapter.Received(1).DeleteTaxIdAsync(provider.GatewayCustomerId, "tax_id_1"); - await stripeAdapter.Received(1).TaxIdCreateAsync(provider.GatewayCustomerId, Arg.Is( + await stripeAdapter.Received(1).CreateTaxIdAsync(provider.GatewayCustomerId, Arg.Is( options => options.Type == "us_ein" && options.Value == taxInformation.TaxId)); - await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + await stripeAdapter.Received(1).UpdateSubscriptionAsync(provider.GatewaySubscriptionId, Arg.Is(options => options.AutomaticTax.Enabled == true)); } @@ -1765,7 +1555,7 @@ public class SubscriberServiceTests var customer = new Customer { Id = provider.GatewayCustomerId, TaxIds = new StripeList { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] } }; - stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId, Arg.Is( + stripeAdapter.GetCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Expand.Contains("tax_ids"))).Returns(customer); var taxInformation = new TaxInformation( @@ -1779,7 +1569,7 @@ public class SubscriberServiceTests "NY"); sutProvider.GetDependency() - .CustomerUpdateAsync( + .UpdateCustomerAsync( Arg.Is(p => p == provider.GatewayCustomerId), Arg.Is(options => options.Address.Country == "CA" && @@ -1815,12 +1605,12 @@ public class SubscriberServiceTests }); var subscription = new Subscription { Items = new StripeList() }; - sutProvider.GetDependency().SubscriptionGetAsync(Arg.Any()) + sutProvider.GetDependency().GetSubscriptionAsync(Arg.Any()) .Returns(subscription); await sutProvider.Sut.UpdateTaxInformation(provider, taxInformation); - await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is( + await stripeAdapter.Received(1).UpdateCustomerAsync(provider.GatewayCustomerId, Arg.Is( options => options.Address.Country == taxInformation.Country && options.Address.PostalCode == taxInformation.PostalCode && @@ -1829,63 +1619,21 @@ public class SubscriberServiceTests options.Address.City == taxInformation.City && options.Address.State == taxInformation.State)); - await stripeAdapter.Received(1).TaxIdDeleteAsync(provider.GatewayCustomerId, "tax_id_1"); + await stripeAdapter.Received(1).DeleteTaxIdAsync(provider.GatewayCustomerId, "tax_id_1"); - await stripeAdapter.Received(1).TaxIdCreateAsync(provider.GatewayCustomerId, Arg.Is( + await stripeAdapter.Received(1).CreateTaxIdAsync(provider.GatewayCustomerId, Arg.Is( options => options.Type == "us_ein" && options.Value == taxInformation.TaxId)); - await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, + await stripeAdapter.Received(1).UpdateCustomerAsync(provider.GatewayCustomerId, Arg.Is(options => options.TaxExempt == StripeConstants.TaxExempt.Reverse)); - await stripeAdapter.Received(1).SubscriptionUpdateAsync(provider.GatewaySubscriptionId, + await stripeAdapter.Received(1).UpdateSubscriptionAsync(provider.GatewaySubscriptionId, Arg.Is(options => options.AutomaticTax.Enabled == true)); } #endregion - #region VerifyBankAccount - - [Theory, BitAutoData] - public async Task VerifyBankAccount_NoSetupIntentId_ThrowsBillingException( - Provider provider, - SutProvider sutProvider) => await ThrowsBillingExceptionAsync(() => sutProvider.Sut.VerifyBankAccount(provider, "")); - - [Theory, BitAutoData] - public async Task VerifyBankAccount_MakesCorrectInvocations( - Provider provider, - SutProvider sutProvider) - { - const string descriptorCode = "SM1234"; - - var setupIntent = new SetupIntent - { - Id = "setup_intent_id", - PaymentMethodId = "payment_method_id" - }; - - sutProvider.GetDependency().GetSetupIntentIdForSubscriber(provider.Id).Returns(setupIntent.Id); - - var stripeAdapter = sutProvider.GetDependency(); - - stripeAdapter.SetupIntentGet(setupIntent.Id).Returns(setupIntent); - - await sutProvider.Sut.VerifyBankAccount(provider, descriptorCode); - - await stripeAdapter.Received(1).SetupIntentVerifyMicroDeposit(setupIntent.Id, - Arg.Is( - options => options.DescriptorCode == descriptorCode)); - - await stripeAdapter.Received(1).PaymentMethodAttachAsync(setupIntent.PaymentMethodId, - Arg.Is( - options => options.Customer == provider.GatewayCustomerId)); - - await stripeAdapter.Received(1).CustomerUpdateAsync(provider.GatewayCustomerId, Arg.Is( - options => options.InvoiceSettings.DefaultPaymentMethod == setupIntent.PaymentMethodId)); - } - - #endregion - #region IsValidGatewayCustomerIdAsync [Theory, BitAutoData] @@ -1907,7 +1655,7 @@ public class SubscriberServiceTests Assert.True(result); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CustomerGetAsync(Arg.Any()); + .GetCustomerAsync(Arg.Any()); } [Theory, BitAutoData] @@ -1921,7 +1669,7 @@ public class SubscriberServiceTests Assert.True(result); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .CustomerGetAsync(Arg.Any()); + .GetCustomerAsync(Arg.Any()); } [Theory, BitAutoData] @@ -1930,12 +1678,12 @@ public class SubscriberServiceTests SutProvider sutProvider) { var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId).Returns(new Customer()); + stripeAdapter.GetCustomerAsync(organization.GatewayCustomerId).Returns(new Customer()); var result = await sutProvider.Sut.IsValidGatewayCustomerIdAsync(organization); Assert.True(result); - await stripeAdapter.Received(1).CustomerGetAsync(organization.GatewayCustomerId); + await stripeAdapter.Received(1).GetCustomerAsync(organization.GatewayCustomerId); } [Theory, BitAutoData] @@ -1945,12 +1693,12 @@ public class SubscriberServiceTests { var stripeAdapter = sutProvider.GetDependency(); var stripeException = new StripeException { StripeError = new StripeError { Code = "resource_missing" } }; - stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId).Throws(stripeException); + stripeAdapter.GetCustomerAsync(organization.GatewayCustomerId).Throws(stripeException); var result = await sutProvider.Sut.IsValidGatewayCustomerIdAsync(organization); Assert.False(result); - await stripeAdapter.Received(1).CustomerGetAsync(organization.GatewayCustomerId); + await stripeAdapter.Received(1).GetCustomerAsync(organization.GatewayCustomerId); } #endregion @@ -1976,7 +1724,7 @@ public class SubscriberServiceTests Assert.True(result); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .SubscriptionGetAsync(Arg.Any()); + .GetSubscriptionAsync(Arg.Any()); } [Theory, BitAutoData] @@ -1990,7 +1738,7 @@ public class SubscriberServiceTests Assert.True(result); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .SubscriptionGetAsync(Arg.Any()); + .GetSubscriptionAsync(Arg.Any()); } [Theory, BitAutoData] @@ -1999,12 +1747,12 @@ public class SubscriberServiceTests SutProvider sutProvider) { var stripeAdapter = sutProvider.GetDependency(); - stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId).Returns(new Subscription()); + stripeAdapter.GetSubscriptionAsync(organization.GatewaySubscriptionId).Returns(new Subscription()); var result = await sutProvider.Sut.IsValidGatewaySubscriptionIdAsync(organization); Assert.True(result); - await stripeAdapter.Received(1).SubscriptionGetAsync(organization.GatewaySubscriptionId); + await stripeAdapter.Received(1).GetSubscriptionAsync(organization.GatewaySubscriptionId); } [Theory, BitAutoData] @@ -2014,12 +1762,12 @@ public class SubscriberServiceTests { var stripeAdapter = sutProvider.GetDependency(); var stripeException = new StripeException { StripeError = new StripeError { Code = "resource_missing" } }; - stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId).Throws(stripeException); + stripeAdapter.GetSubscriptionAsync(organization.GatewaySubscriptionId).Throws(stripeException); var result = await sutProvider.Sut.IsValidGatewaySubscriptionIdAsync(organization); Assert.False(result); - await stripeAdapter.Received(1).SubscriptionGetAsync(organization.GatewaySubscriptionId); + await stripeAdapter.Received(1).GetSubscriptionAsync(organization.GatewaySubscriptionId); } #endregion diff --git a/test/Core.Test/Billing/Subscriptions/RestartSubscriptionCommandTests.cs b/test/Core.Test/Billing/Subscriptions/RestartSubscriptionCommandTests.cs index a5970c79ab..9f34c37b3c 100644 --- a/test/Core.Test/Billing/Subscriptions/RestartSubscriptionCommandTests.cs +++ b/test/Core.Test/Billing/Subscriptions/RestartSubscriptionCommandTests.cs @@ -1,12 +1,14 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Billing.Subscriptions.Commands; using Bit.Core.Entities; +using Bit.Core.Exceptions; using Bit.Core.Repositories; -using Bit.Core.Services; +using Bit.Core.Test.Billing.Mocks; using NSubstitute; using Stripe; using Xunit; @@ -18,20 +20,19 @@ using static StripeConstants; public class RestartSubscriptionCommandTests { private readonly IOrganizationRepository _organizationRepository = Substitute.For(); - private readonly IProviderRepository _providerRepository = Substitute.For(); + private readonly IPricingClient _pricingClient = Substitute.For(); private readonly IStripeAdapter _stripeAdapter = Substitute.For(); private readonly ISubscriberService _subscriberService = Substitute.For(); - private readonly IUserRepository _userRepository = Substitute.For(); private readonly RestartSubscriptionCommand _command; public RestartSubscriptionCommandTests() { _command = new RestartSubscriptionCommand( + Substitute.For>(), _organizationRepository, - _providerRepository, + _pricingClient, _stripeAdapter, - _subscriberService, - _userRepository); + _subscriberService); } [Fact] @@ -64,11 +65,56 @@ public class RestartSubscriptionCommandTests } [Fact] - public async Task Run_Organization_Success_ReturnsNone() + public async Task Run_Provider_ReturnsUnhandledWithNotSupportedException() + { + var provider = new Provider { Id = Guid.NewGuid() }; + + var existingSubscription = new Subscription + { + Status = SubscriptionStatus.Canceled, + CustomerId = "cus_123" + }; + + _subscriberService.GetSubscription(provider).Returns(existingSubscription); + + var result = await _command.Run(provider); + + Assert.True(result.IsT3); + var unhandled = result.AsT3; + Assert.IsType(unhandled.Exception); + } + + [Fact] + public async Task Run_User_ReturnsUnhandledWithNotSupportedException() + { + var user = new User { Id = Guid.NewGuid() }; + + var existingSubscription = new Subscription + { + Status = SubscriptionStatus.Canceled, + CustomerId = "cus_123" + }; + + _subscriberService.GetSubscription(user).Returns(existingSubscription); + + var result = await _command.Run(user); + + Assert.True(result.IsT3); + var unhandled = result.AsT3; + Assert.IsType(unhandled.Exception); + } + + [Fact] + public async Task Run_Organization_MissingPasswordManagerItem_ReturnsUnhandledWithConflictException() { var organizationId = Guid.NewGuid(); - var organization = new Organization { Id = organizationId }; - var currentPeriodEnd = DateTime.UtcNow.AddMonths(1); + var organization = new Organization + { + Id = organizationId, + PlanType = PlanType.EnterpriseAnnually + }; + + var plan = MockPlans.Get(PlanType.EnterpriseAnnually); var existingSubscription = new Subscription { @@ -78,36 +124,149 @@ public class RestartSubscriptionCommandTests { Data = [ - new SubscriptionItem { Price = new Price { Id = "price_1" }, Quantity = 1 }, - new SubscriptionItem { Price = new Price { Id = "price_2" }, Quantity = 2 } + new SubscriptionItem { Price = new Price { Id = "some-other-price-id" }, Quantity = 10 } ] }, - Metadata = new Dictionary { ["key"] = "value" } + Metadata = new Dictionary { ["organizationId"] = organizationId.ToString() } + }; + + _subscriberService.GetSubscription(organization).Returns(existingSubscription); + _pricingClient.ListPlans().Returns([plan]); + + var result = await _command.Run(organization); + + Assert.True(result.IsT3); + var unhandled = result.AsT3; + Assert.IsType(unhandled.Exception); + Assert.Equal("Organization's subscription does not have a Password Manager subscription item.", unhandled.Exception.Message); + } + + [Fact] + public async Task Run_Organization_PlanNotFound_ReturnsUnhandledWithConflictException() + { + var organizationId = Guid.NewGuid(); + var organization = new Organization + { + Id = organizationId, + PlanType = PlanType.EnterpriseAnnually + }; + + var existingSubscription = new Subscription + { + Status = SubscriptionStatus.Canceled, + CustomerId = "cus_123", + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Price = new Price { Id = "some-price-id" }, Quantity = 10 } + ] + }, + Metadata = new Dictionary { ["organizationId"] = organizationId.ToString() } + }; + + _subscriberService.GetSubscription(organization).Returns(existingSubscription); + // Return a plan list that doesn't contain the organization's plan type + _pricingClient.ListPlans().Returns([MockPlans.Get(PlanType.TeamsAnnually)]); + + var result = await _command.Run(organization); + + Assert.True(result.IsT3); + var unhandled = result.AsT3; + Assert.IsType(unhandled.Exception); + Assert.Equal("Could not find plan for organization's plan type", unhandled.Exception.Message); + } + + [Fact] + public async Task Run_Organization_DisabledPlanWithNoEnabledReplacement_ReturnsUnhandledWithConflictException() + { + var organizationId = Guid.NewGuid(); + var organization = new Organization + { + Id = organizationId, + PlanType = PlanType.EnterpriseAnnually2023 + }; + + var oldPlan = new DisabledEnterprisePlan2023(true); + + var existingSubscription = new Subscription + { + Status = SubscriptionStatus.Canceled, + CustomerId = "cus_old", + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Price = new Price { Id = oldPlan.PasswordManager.StripeSeatPlanId }, Quantity = 20 } + ] + }, + Metadata = new Dictionary { ["organizationId"] = organizationId.ToString() } + }; + + _subscriberService.GetSubscription(organization).Returns(existingSubscription); + // Return only the disabled plan, with no enabled replacement + _pricingClient.ListPlans().Returns([oldPlan]); + + var result = await _command.Run(organization); + + Assert.True(result.IsT3); + var unhandled = result.AsT3; + Assert.IsType(unhandled.Exception); + Assert.Equal("Could not find the current, enabled plan for organization's tier and cadence", unhandled.Exception.Message); + } + + [Fact] + public async Task Run_Organization_WithNonDisabledPlan_PasswordManagerOnly_Success() + { + var organizationId = Guid.NewGuid(); + var currentPeriodEnd = DateTime.UtcNow.AddMonths(1); + var organization = new Organization + { + Id = organizationId, + PlanType = PlanType.EnterpriseAnnually + }; + + var plan = MockPlans.Get(PlanType.EnterpriseAnnually); + + var existingSubscription = new Subscription + { + Status = SubscriptionStatus.Canceled, + CustomerId = "cus_123", + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Price = new Price { Id = plan.PasswordManager.StripeSeatPlanId }, Quantity = 10 } + ] + }, + Metadata = new Dictionary { ["organizationId"] = organizationId.ToString() } }; var newSubscription = new Subscription { Id = "sub_new", - CurrentPeriodEnd = currentPeriodEnd + Items = new StripeList + { + Data = [new SubscriptionItem { CurrentPeriodEnd = currentPeriodEnd }] + } }; _subscriberService.GetSubscription(organization).Returns(existingSubscription); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(newSubscription); + _pricingClient.ListPlans().Returns([plan]); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(newSubscription); var result = await _command.Run(organization); Assert.True(result.IsT0); - await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Is((SubscriptionCreateOptions options) => + await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Is(options => options.AutomaticTax.Enabled == true && options.CollectionMethod == CollectionMethod.ChargeAutomatically && options.Customer == "cus_123" && - options.Items.Count == 2 && - options.Items[0].Price == "price_1" && - options.Items[0].Quantity == 1 && - options.Items[1].Price == "price_2" && - options.Items[1].Quantity == 2 && - options.Metadata["key"] == "value" && + options.Items.Count == 1 && + options.Items[0].Price == plan.PasswordManager.StripeSeatPlanId && + options.Items[0].Quantity == 10 && + options.Metadata["organizationId"] == organizationId.ToString() && options.OffSession == true && options.TrialPeriodDays == 0)); @@ -115,84 +274,417 @@ public class RestartSubscriptionCommandTests org.Id == organizationId && org.GatewaySubscriptionId == "sub_new" && org.Enabled == true && - org.ExpirationDate == currentPeriodEnd)); + org.ExpirationDate == currentPeriodEnd && + org.PlanType == PlanType.EnterpriseAnnually)); } [Fact] - public async Task Run_Provider_Success_ReturnsNone() + public async Task Run_Organization_WithNonDisabledPlan_WithStorage_Success() { - var providerId = Guid.NewGuid(); - var provider = new Provider { Id = providerId }; - - var existingSubscription = new Subscription - { - Status = SubscriptionStatus.Canceled, - CustomerId = "cus_123", - Items = new StripeList - { - Data = [new SubscriptionItem { Price = new Price { Id = "price_1" }, Quantity = 1 }] - }, - Metadata = new Dictionary() - }; - - var newSubscription = new Subscription - { - Id = "sub_new", - CurrentPeriodEnd = DateTime.UtcNow.AddMonths(1) - }; - - _subscriberService.GetSubscription(provider).Returns(existingSubscription); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(newSubscription); - - var result = await _command.Run(provider); - - Assert.True(result.IsT0); - - await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any()); - - await _providerRepository.Received(1).ReplaceAsync(Arg.Is(prov => - prov.Id == providerId && - prov.GatewaySubscriptionId == "sub_new" && - prov.Enabled == true)); - } - - [Fact] - public async Task Run_User_Success_ReturnsNone() - { - var userId = Guid.NewGuid(); - var user = new User { Id = userId }; + var organizationId = Guid.NewGuid(); var currentPeriodEnd = DateTime.UtcNow.AddMonths(1); + var organization = new Organization + { + Id = organizationId, + PlanType = PlanType.TeamsAnnually + }; + + var plan = MockPlans.Get(PlanType.TeamsAnnually); var existingSubscription = new Subscription { Status = SubscriptionStatus.Canceled, - CustomerId = "cus_123", + CustomerId = "cus_456", Items = new StripeList { - Data = [new SubscriptionItem { Price = new Price { Id = "price_1" }, Quantity = 1 }] + Data = + [ + new SubscriptionItem { Price = new Price { Id = plan.PasswordManager.StripeSeatPlanId }, Quantity = 5 }, + new SubscriptionItem { Price = new Price { Id = plan.PasswordManager.StripeStoragePlanId }, Quantity = 3 } + ] }, - Metadata = new Dictionary() + Metadata = new Dictionary { ["organizationId"] = organizationId.ToString() } }; var newSubscription = new Subscription { - Id = "sub_new", - CurrentPeriodEnd = currentPeriodEnd + Id = "sub_new_2", + Items = new StripeList + { + Data = [new SubscriptionItem { CurrentPeriodEnd = currentPeriodEnd }] + } }; - _subscriberService.GetSubscription(user).Returns(existingSubscription); - _stripeAdapter.SubscriptionCreateAsync(Arg.Any()).Returns(newSubscription); + _subscriberService.GetSubscription(organization).Returns(existingSubscription); + _pricingClient.ListPlans().Returns([plan]); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(newSubscription); - var result = await _command.Run(user); + var result = await _command.Run(organization); Assert.True(result.IsT0); - await _stripeAdapter.Received(1).SubscriptionCreateAsync(Arg.Any()); + await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Is(options => + options.Items.Count == 2 && + options.Items[0].Price == plan.PasswordManager.StripeSeatPlanId && + options.Items[0].Quantity == 5 && + options.Items[1].Price == plan.PasswordManager.StripeStoragePlanId && + options.Items[1].Quantity == 3)); - await _userRepository.Received(1).ReplaceAsync(Arg.Is(u => - u.Id == userId && - u.GatewaySubscriptionId == "sub_new" && - u.Premium == true && - u.PremiumExpirationDate == currentPeriodEnd)); + await _organizationRepository.Received(1).ReplaceAsync(Arg.Is(org => + org.Id == organizationId && + org.GatewaySubscriptionId == "sub_new_2" && + org.Enabled == true)); + } + + [Fact] + public async Task Run_Organization_WithSecretsManager_Success() + { + var organizationId = Guid.NewGuid(); + var currentPeriodEnd = DateTime.UtcNow.AddMonths(1); + var organization = new Organization + { + Id = organizationId, + PlanType = PlanType.EnterpriseMonthly + }; + + var plan = MockPlans.Get(PlanType.EnterpriseMonthly); + + var existingSubscription = new Subscription + { + Status = SubscriptionStatus.Canceled, + CustomerId = "cus_789", + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Price = new Price { Id = plan.PasswordManager.StripeSeatPlanId }, Quantity = 15 }, + new SubscriptionItem { Price = new Price { Id = plan.PasswordManager.StripeStoragePlanId }, Quantity = 2 }, + new SubscriptionItem { Price = new Price { Id = plan.SecretsManager.StripeSeatPlanId }, Quantity = 10 }, + new SubscriptionItem { Price = new Price { Id = plan.SecretsManager.StripeServiceAccountPlanId }, Quantity = 100 } + ] + }, + Metadata = new Dictionary { ["organizationId"] = organizationId.ToString() } + }; + + var newSubscription = new Subscription + { + Id = "sub_new_3", + Items = new StripeList + { + Data = [new SubscriptionItem { CurrentPeriodEnd = currentPeriodEnd }] + } + }; + + _subscriberService.GetSubscription(organization).Returns(existingSubscription); + _pricingClient.ListPlans().Returns([plan]); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(newSubscription); + + var result = await _command.Run(organization); + + Assert.True(result.IsT0); + + await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Is(options => + options.Items.Count == 4 && + options.Items[0].Price == plan.PasswordManager.StripeSeatPlanId && + options.Items[0].Quantity == 15 && + options.Items[1].Price == plan.PasswordManager.StripeStoragePlanId && + options.Items[1].Quantity == 2 && + options.Items[2].Price == plan.SecretsManager.StripeSeatPlanId && + options.Items[2].Quantity == 10 && + options.Items[3].Price == plan.SecretsManager.StripeServiceAccountPlanId && + options.Items[3].Quantity == 100)); + + await _organizationRepository.Received(1).ReplaceAsync(Arg.Is(org => + org.Id == organizationId && + org.GatewaySubscriptionId == "sub_new_3" && + org.Enabled == true)); + } + + [Fact] + public async Task Run_Organization_WithDisabledPlan_UpgradesToNewPlan_Success() + { + var organizationId = Guid.NewGuid(); + var currentPeriodEnd = DateTime.UtcNow.AddMonths(1); + var organization = new Organization + { + Id = organizationId, + PlanType = PlanType.EnterpriseAnnually2023 + }; + + var oldPlan = new DisabledEnterprisePlan2023(true); + var newPlan = MockPlans.Get(PlanType.EnterpriseAnnually); + + var existingSubscription = new Subscription + { + Status = SubscriptionStatus.Canceled, + CustomerId = "cus_old", + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Price = new Price { Id = oldPlan.PasswordManager.StripeSeatPlanId }, Quantity = 20 }, + new SubscriptionItem { Price = new Price { Id = oldPlan.PasswordManager.StripeStoragePlanId }, Quantity = 5 } + ] + }, + Metadata = new Dictionary { ["organizationId"] = organizationId.ToString() } + }; + + var newSubscription = new Subscription + { + Id = "sub_upgraded", + Items = new StripeList + { + Data = [new SubscriptionItem { CurrentPeriodEnd = currentPeriodEnd }] + } + }; + + _subscriberService.GetSubscription(organization).Returns(existingSubscription); + _pricingClient.ListPlans().Returns([oldPlan, newPlan]); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(newSubscription); + + var result = await _command.Run(organization); + + Assert.True(result.IsT0); + + await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Is(options => + options.Items.Count == 2 && + options.Items[0].Price == newPlan.PasswordManager.StripeSeatPlanId && + options.Items[0].Quantity == 20 && + options.Items[1].Price == newPlan.PasswordManager.StripeStoragePlanId && + options.Items[1].Quantity == 5)); + + await _organizationRepository.Received(1).ReplaceAsync(Arg.Is(org => + org.Id == organizationId && + org.GatewaySubscriptionId == "sub_upgraded" && + org.Enabled == true && + org.PlanType == PlanType.EnterpriseAnnually && + org.Plan == newPlan.Name && + org.SelfHost == newPlan.HasSelfHost && + org.UsePolicies == newPlan.HasPolicies && + org.UseGroups == newPlan.HasGroups && + org.UseDirectory == newPlan.HasDirectory && + org.UseEvents == newPlan.HasEvents && + org.UseTotp == newPlan.HasTotp && + org.Use2fa == newPlan.Has2fa && + org.UseApi == newPlan.HasApi && + org.UseSso == newPlan.HasSso && + org.UseOrganizationDomains == newPlan.HasOrganizationDomains && + org.UseKeyConnector == newPlan.HasKeyConnector && + org.UseScim == newPlan.HasScim && + org.UseResetPassword == newPlan.HasResetPassword && + org.UsersGetPremium == newPlan.UsersGetPremium && + org.UseCustomPermissions == newPlan.HasCustomPermissions)); + } + + [Fact] + public async Task Run_Organization_WithStorageAndSecretManagerButNoServiceAccounts_Success() + { + var organizationId = Guid.NewGuid(); + var currentPeriodEnd = DateTime.UtcNow.AddMonths(1); + var organization = new Organization + { + Id = organizationId, + PlanType = PlanType.TeamsAnnually + }; + + var plan = MockPlans.Get(PlanType.TeamsAnnually); + + var existingSubscription = new Subscription + { + Status = SubscriptionStatus.Canceled, + CustomerId = "cus_complex", + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Price = new Price { Id = plan.PasswordManager.StripeSeatPlanId }, Quantity = 12 }, + new SubscriptionItem { Price = new Price { Id = plan.PasswordManager.StripeStoragePlanId }, Quantity = 8 }, + new SubscriptionItem { Price = new Price { Id = plan.SecretsManager.StripeSeatPlanId }, Quantity = 6 } + ] + }, + Metadata = new Dictionary { ["organizationId"] = organizationId.ToString() } + }; + + var newSubscription = new Subscription + { + Id = "sub_complex", + Items = new StripeList + { + Data = [new SubscriptionItem { CurrentPeriodEnd = currentPeriodEnd }] + } + }; + + _subscriberService.GetSubscription(organization).Returns(existingSubscription); + _pricingClient.ListPlans().Returns([plan]); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(newSubscription); + + var result = await _command.Run(organization); + + Assert.True(result.IsT0); + + await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Is(options => + options.Items.Count == 3 && + options.Items[0].Price == plan.PasswordManager.StripeSeatPlanId && + options.Items[0].Quantity == 12 && + options.Items[1].Price == plan.PasswordManager.StripeStoragePlanId && + options.Items[1].Quantity == 8 && + options.Items[2].Price == plan.SecretsManager.StripeSeatPlanId && + options.Items[2].Quantity == 6)); + + await _organizationRepository.Received(1).ReplaceAsync(Arg.Is(org => + org.Id == organizationId && + org.GatewaySubscriptionId == "sub_complex" && + org.Enabled == true)); + } + + [Fact] + public async Task Run_Organization_WithSecretsManagerOnly_NoServiceAccounts_Success() + { + var organizationId = Guid.NewGuid(); + var currentPeriodEnd = DateTime.UtcNow.AddMonths(1); + var organization = new Organization + { + Id = organizationId, + PlanType = PlanType.TeamsMonthly + }; + + var plan = MockPlans.Get(PlanType.TeamsMonthly); + + var existingSubscription = new Subscription + { + Status = SubscriptionStatus.Canceled, + CustomerId = "cus_sm", + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Price = new Price { Id = plan.PasswordManager.StripeSeatPlanId }, Quantity = 8 }, + new SubscriptionItem { Price = new Price { Id = plan.SecretsManager.StripeSeatPlanId }, Quantity = 5 } + ] + }, + Metadata = new Dictionary { ["organizationId"] = organizationId.ToString() } + }; + + var newSubscription = new Subscription + { + Id = "sub_sm", + Items = new StripeList + { + Data = [new SubscriptionItem { CurrentPeriodEnd = currentPeriodEnd }] + } + }; + + _subscriberService.GetSubscription(organization).Returns(existingSubscription); + _pricingClient.ListPlans().Returns([plan]); + _stripeAdapter.CreateSubscriptionAsync(Arg.Any()).Returns(newSubscription); + + var result = await _command.Run(organization); + + Assert.True(result.IsT0); + + await _stripeAdapter.Received(1).CreateSubscriptionAsync(Arg.Is(options => + options.Items.Count == 2 && + options.Items[0].Price == plan.PasswordManager.StripeSeatPlanId && + options.Items[0].Quantity == 8 && + options.Items[1].Price == plan.SecretsManager.StripeSeatPlanId && + options.Items[1].Quantity == 5)); + + await _organizationRepository.Received(1).ReplaceAsync(Arg.Is(org => + org.Id == organizationId && + org.GatewaySubscriptionId == "sub_sm" && + org.Enabled == true)); + } + + private record DisabledEnterprisePlan2023 : Bit.Core.Models.StaticStore.Plan + { + public DisabledEnterprisePlan2023(bool isAnnual) + { + Type = PlanType.EnterpriseAnnually2023; + ProductTier = ProductTierType.Enterprise; + Name = "Enterprise (Annually) 2023"; + IsAnnual = isAnnual; + NameLocalizationKey = "planNameEnterprise"; + DescriptionLocalizationKey = "planDescEnterprise"; + CanBeUsedByBusiness = true; + TrialPeriodDays = 7; + HasPolicies = true; + HasSelfHost = true; + HasGroups = true; + HasDirectory = true; + HasEvents = true; + HasTotp = true; + Has2fa = true; + HasApi = true; + HasSso = true; + HasOrganizationDomains = true; + HasKeyConnector = true; + HasScim = true; + HasResetPassword = true; + UsersGetPremium = true; + HasCustomPermissions = true; + UpgradeSortOrder = 4; + DisplaySortOrder = 4; + LegacyYear = 2024; + Disabled = true; + + PasswordManager = new PasswordManagerFeatures(isAnnual); + SecretsManager = new SecretsManagerFeatures(isAnnual); + } + + private record SecretsManagerFeatures : SecretsManagerPlanFeatures + { + public SecretsManagerFeatures(bool isAnnual) + { + BaseSeats = 0; + BasePrice = 0; + BaseServiceAccount = 200; + HasAdditionalSeatsOption = true; + HasAdditionalServiceAccountOption = true; + AllowSeatAutoscale = true; + AllowServiceAccountsAutoscale = true; + + if (isAnnual) + { + StripeSeatPlanId = "secrets-manager-enterprise-seat-annually-2023"; + StripeServiceAccountPlanId = "secrets-manager-service-account-2023-annually"; + SeatPrice = 144; + AdditionalPricePerServiceAccount = 12; + } + else + { + StripeSeatPlanId = "secrets-manager-enterprise-seat-monthly-2023"; + StripeServiceAccountPlanId = "secrets-manager-service-account-2023-monthly"; + SeatPrice = 13; + AdditionalPricePerServiceAccount = 1; + } + } + } + + private record PasswordManagerFeatures : PasswordManagerPlanFeatures + { + public PasswordManagerFeatures(bool isAnnual) + { + BaseSeats = 0; + BaseStorageGb = 1; + HasAdditionalStorageOption = true; + HasAdditionalSeatsOption = true; + AllowSeatAutoscale = true; + + if (isAnnual) + { + AdditionalStoragePricePerGb = 4; + StripeStoragePlanId = "storage-gb-annually"; + StripeSeatPlanId = "2023-enterprise-org-seat-annually-old"; + SeatPrice = 72; + } + else + { + StripeSeatPlanId = "2023-enterprise-seat-monthly-old"; + StripeStoragePlanId = "storage-gb-monthly"; + SeatPrice = 7; + AdditionalStoragePricePerGb = 0.5M; + } + } + } } } diff --git a/test/Core.Test/Context/CurrentContextTests.cs b/test/Core.Test/Context/CurrentContextTests.cs index b868d6ceaa..41a54a5b22 100644 --- a/test/Core.Test/Context/CurrentContextTests.cs +++ b/test/Core.Test/Context/CurrentContextTests.cs @@ -107,30 +107,6 @@ public class CurrentContextTests Assert.Equal(deviceType, sutProvider.Sut.DeviceType); } - [Theory, BitAutoData] - public async Task BuildAsync_HttpContext_SetsCloudflareFlags( - SutProvider sutProvider) - { - var httpContext = new DefaultHttpContext(); - var globalSettings = new Core.Settings.GlobalSettings(); - sutProvider.Sut.BotScore = null; - // Arrange - var botScore = 85; - httpContext.Request.Headers["X-Cf-Bot-Score"] = botScore.ToString(); - httpContext.Request.Headers["X-Cf-Worked-Proxied"] = "1"; - httpContext.Request.Headers["X-Cf-Is-Bot"] = "1"; - httpContext.Request.Headers["X-Cf-Maybe-Bot"] = "1"; - - // Act - await sutProvider.Sut.BuildAsync(httpContext, globalSettings); - - // Assert - Assert.True(sutProvider.Sut.CloudflareWorkerProxied); - Assert.True(sutProvider.Sut.IsBot); - Assert.True(sutProvider.Sut.MaybeBot); - Assert.Equal(botScore, sutProvider.Sut.BotScore); - } - [Theory, BitAutoData] public async Task BuildAsync_HttpContext_SetsClientVersion( SutProvider sutProvider) diff --git a/test/Core.Test/Core.Test.csproj b/test/Core.Test/Core.Test.csproj index c0f91a7bd3..b9e218205c 100644 --- a/test/Core.Test/Core.Test.csproj +++ b/test/Core.Test/Core.Test.csproj @@ -28,6 +28,9 @@ + + + diff --git a/test/Core.Test/Dirt/EventIntegrations/EventIntegrationServiceCollectionExtensionsTests.cs b/test/Core.Test/Dirt/EventIntegrations/EventIntegrationServiceCollectionExtensionsTests.cs new file mode 100644 index 0000000000..37b303b735 --- /dev/null +++ b/test/Core.Test/Dirt/EventIntegrations/EventIntegrationServiceCollectionExtensionsTests.cs @@ -0,0 +1,915 @@ +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrationConfigurations.Interfaces; +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrations.Interfaces; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Repositories; +using Bit.Core.Dirt.Services; +using Bit.Core.Dirt.Services.Implementations; +using Bit.Core.Dirt.Services.NoopImplementations; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Settings; +using Bit.Core.Test.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Utilities; +using Microsoft.Bot.Builder; +using Microsoft.Bot.Builder.Integration.AspNet.Core; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Hosting; +using NSubstitute; +using StackExchange.Redis; +using Xunit; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Test.Dirt.EventIntegrations; + +public class EventIntegrationServiceCollectionExtensionsTests +{ + private readonly IServiceCollection _services; + private readonly GlobalSettings _globalSettings; + + public EventIntegrationServiceCollectionExtensionsTests() + { + _services = new ServiceCollection(); + _globalSettings = CreateGlobalSettings([]); + + // Add required infrastructure services + _services.TryAddSingleton(_globalSettings); + _services.TryAddSingleton(_globalSettings); + _services.AddLogging(); + + // Mock Redis connection for cache + _services.AddSingleton(Substitute.For()); + + // Mock required repository dependencies for commands + _services.TryAddScoped(_ => Substitute.For()); + _services.TryAddScoped(_ => Substitute.For()); + _services.TryAddScoped(_ => Substitute.For()); + } + + [Fact] + public void AddEventIntegrationsCommandsQueries_RegistersAllServices() + { + _services.AddEventIntegrationsCommandsQueries(_globalSettings); + + using var provider = _services.BuildServiceProvider(); + + var cache = provider.GetRequiredKeyedService(EventIntegrationsCacheConstants.CacheName); + Assert.NotNull(cache); + + var validator = provider.GetRequiredService(); + Assert.NotNull(validator); + + using var scope = provider.CreateScope(); + var sp = scope.ServiceProvider; + + Assert.NotNull(sp.GetService()); + Assert.NotNull(sp.GetService()); + Assert.NotNull(sp.GetService()); + Assert.NotNull(sp.GetService()); + + Assert.NotNull(sp.GetService()); + Assert.NotNull(sp.GetService()); + Assert.NotNull(sp.GetService()); + Assert.NotNull(sp.GetService()); + } + + [Fact] + public void AddEventIntegrationsCommandsQueries_CommandsQueries_AreRegisteredAsScoped() + { + _services.AddEventIntegrationsCommandsQueries(_globalSettings); + + var createIntegrationDescriptor = _services.First(s => + s.ServiceType == typeof(ICreateOrganizationIntegrationCommand)); + var createConfigDescriptor = _services.First(s => + s.ServiceType == typeof(ICreateOrganizationIntegrationConfigurationCommand)); + + Assert.Equal(ServiceLifetime.Scoped, createIntegrationDescriptor.Lifetime); + Assert.Equal(ServiceLifetime.Scoped, createConfigDescriptor.Lifetime); + } + + [Fact] + public void AddEventIntegrationsCommandsQueries_CommandsQueries_DifferentInstancesPerScope() + { + _services.AddEventIntegrationsCommandsQueries(_globalSettings); + + var provider = _services.BuildServiceProvider(); + + ICreateOrganizationIntegrationCommand? instance1, instance2, instance3; + using (var scope1 = provider.CreateScope()) + { + instance1 = scope1.ServiceProvider.GetService(); + } + using (var scope2 = provider.CreateScope()) + { + instance2 = scope2.ServiceProvider.GetService(); + } + using (var scope3 = provider.CreateScope()) + { + instance3 = scope3.ServiceProvider.GetService(); + } + + Assert.NotNull(instance1); + Assert.NotNull(instance2); + Assert.NotNull(instance3); + Assert.NotSame(instance1, instance2); + Assert.NotSame(instance2, instance3); + Assert.NotSame(instance1, instance3); + } + + [Fact] + public void AddEventIntegrationsCommandsQueries_CommandsQueries__SameInstanceWithinScope() + { + _services.AddEventIntegrationsCommandsQueries(_globalSettings); + var provider = _services.BuildServiceProvider(); + + using var scope = provider.CreateScope(); + var instance1 = scope.ServiceProvider.GetService(); + var instance2 = scope.ServiceProvider.GetService(); + + Assert.NotNull(instance1); + Assert.NotNull(instance2); + Assert.Same(instance1, instance2); + } + + [Fact] + public void AddEventIntegrationsCommandsQueries_MultipleCalls_IsIdempotent() + { + _services.AddEventIntegrationsCommandsQueries(_globalSettings); + _services.AddEventIntegrationsCommandsQueries(_globalSettings); + _services.AddEventIntegrationsCommandsQueries(_globalSettings); + + var createConfigCmdDescriptors = _services.Where(s => + s.ServiceType == typeof(ICreateOrganizationIntegrationConfigurationCommand)).ToList(); + Assert.Single(createConfigCmdDescriptors); + + var updateIntegrationCmdDescriptors = _services.Where(s => + s.ServiceType == typeof(IUpdateOrganizationIntegrationCommand)).ToList(); + Assert.Single(updateIntegrationCmdDescriptors); + } + + [Fact] + public void AddOrganizationIntegrationCommandsQueries_RegistersAllIntegrationServices() + { + _services.AddOrganizationIntegrationCommandsQueries(); + + Assert.Contains(_services, s => s.ServiceType == typeof(ICreateOrganizationIntegrationCommand)); + Assert.Contains(_services, s => s.ServiceType == typeof(IUpdateOrganizationIntegrationCommand)); + Assert.Contains(_services, s => s.ServiceType == typeof(IDeleteOrganizationIntegrationCommand)); + Assert.Contains(_services, s => s.ServiceType == typeof(IGetOrganizationIntegrationsQuery)); + } + + [Fact] + public void AddOrganizationIntegrationCommandsQueries_MultipleCalls_IsIdempotent() + { + _services.AddOrganizationIntegrationCommandsQueries(); + _services.AddOrganizationIntegrationCommandsQueries(); + _services.AddOrganizationIntegrationCommandsQueries(); + + var createCmdDescriptors = _services.Where(s => + s.ServiceType == typeof(ICreateOrganizationIntegrationCommand)).ToList(); + Assert.Single(createCmdDescriptors); + } + + [Fact] + public void AddOrganizationIntegrationConfigurationCommandsQueries_RegistersAllConfigurationServices() + { + _services.AddOrganizationIntegrationConfigurationCommandsQueries(); + + Assert.Contains(_services, s => s.ServiceType == typeof(ICreateOrganizationIntegrationConfigurationCommand)); + Assert.Contains(_services, s => s.ServiceType == typeof(IUpdateOrganizationIntegrationConfigurationCommand)); + Assert.Contains(_services, s => s.ServiceType == typeof(IDeleteOrganizationIntegrationConfigurationCommand)); + Assert.Contains(_services, s => s.ServiceType == typeof(IGetOrganizationIntegrationConfigurationsQuery)); + } + + [Fact] + public void AddOrganizationIntegrationConfigurationCommandsQueries_MultipleCalls_IsIdempotent() + { + _services.AddOrganizationIntegrationConfigurationCommandsQueries(); + _services.AddOrganizationIntegrationConfigurationCommandsQueries(); + _services.AddOrganizationIntegrationConfigurationCommandsQueries(); + + var createCmdDescriptors = _services.Where(s => + s.ServiceType == typeof(ICreateOrganizationIntegrationConfigurationCommand)).ToList(); + Assert.Single(createCmdDescriptors); + } + + [Fact] + public void IsRabbitMqEnabled_AllSettingsPresent_ReturnsTrue() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange", + ["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration" + }); + + Assert.True(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings)); + } + + [Fact] + public void IsRabbitMqEnabled_MissingHostName_ReturnsFalse() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = null, + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange", + ["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration" + }); + + Assert.False(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings)); + } + + [Fact] + public void IsRabbitMqEnabled_MissingUsername_ReturnsFalse() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = null, + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange", + ["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration" + }); + + Assert.False(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings)); + } + + [Fact] + public void IsRabbitMqEnabled_MissingPassword_ReturnsFalse() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = null, + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange", + ["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration" + }); + + Assert.False(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings)); + } + + [Fact] + public void IsRabbitMqEnabled_MissingEventExchangeName_ReturnsFalse() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = null, + ["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration" + }); + + Assert.False(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings)); + } + + [Fact] + public void IsRabbitMqEnabled_MissingIntegrationExchangeName_ReturnsFalse() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange", + ["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = null + }); + + Assert.False(EventIntegrationsServiceCollectionExtensions.IsRabbitMqEnabled(globalSettings)); + } + + [Fact] + public void IsAzureServiceBusEnabled_AllSettingsPresent_ReturnsTrue() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test", + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events", + ["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = "integration" + }); + + Assert.True(EventIntegrationsServiceCollectionExtensions.IsAzureServiceBusEnabled(globalSettings)); + } + + [Fact] + public void IsAzureServiceBusEnabled_MissingConnectionString_ReturnsFalse() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = null, + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events", + ["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = "integration" + }); + + Assert.False(EventIntegrationsServiceCollectionExtensions.IsAzureServiceBusEnabled(globalSettings)); + } + + [Fact] + public void IsAzureServiceBusEnabled_MissingEventTopicName_ReturnsFalse() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test", + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = null, + ["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = "integration" + }); + + Assert.False(EventIntegrationsServiceCollectionExtensions.IsAzureServiceBusEnabled(globalSettings)); + } + + [Fact] + public void IsAzureServiceBusEnabled_MissingIntegrationTopicName_ReturnsFalse() + { + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test", + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events", + ["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = null + }); + + Assert.False(EventIntegrationsServiceCollectionExtensions.IsAzureServiceBusEnabled(globalSettings)); + } + + [Fact] + public void AddSlackService_AllSettingsPresent_RegistersSlackService() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:Slack:ClientId"] = "test-client-id", + ["GlobalSettings:Slack:ClientSecret"] = "test-client-secret", + ["GlobalSettings:Slack:Scopes"] = "test-scopes" + }); + + services.TryAddSingleton(globalSettings); + services.AddLogging(); + services.AddSlackService(globalSettings); + + var provider = services.BuildServiceProvider(); + var slackService = provider.GetService(); + + Assert.NotNull(slackService); + Assert.IsType(slackService); + + var httpClientDescriptor = services.FirstOrDefault(s => + s.ServiceType == typeof(IHttpClientFactory)); + Assert.NotNull(httpClientDescriptor); + } + + [Fact] + public void AddSlackService_SettingsMissing_RegistersNoopService() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:Slack:ClientId"] = null, + ["GlobalSettings:Slack:ClientSecret"] = null, + ["GlobalSettings:Slack:Scopes"] = null + }); + + services.AddSlackService(globalSettings); + + var provider = services.BuildServiceProvider(); + var slackService = provider.GetService(); + + Assert.NotNull(slackService); + Assert.IsType(slackService); + } + + [Fact] + public void AddTeamsService_AllSettingsPresent_RegistersTeamsServices() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:Teams:ClientId"] = "test-client-id", + ["GlobalSettings:Teams:ClientSecret"] = "test-client-secret", + ["GlobalSettings:Teams:Scopes"] = "test-scopes" + }); + + services.TryAddSingleton(globalSettings); + services.AddLogging(); + services.TryAddScoped(_ => Substitute.For()); + services.AddTeamsService(globalSettings); + + var provider = services.BuildServiceProvider(); + + var teamsService = provider.GetService(); + Assert.NotNull(teamsService); + Assert.IsType(teamsService); + + var bot = provider.GetService(); + Assert.NotNull(bot); + Assert.IsType(bot); + + var adapter = provider.GetService(); + Assert.NotNull(adapter); + Assert.IsType(adapter); + + var httpClientDescriptor = services.FirstOrDefault(s => + s.ServiceType == typeof(IHttpClientFactory)); + Assert.NotNull(httpClientDescriptor); + } + + [Fact] + public void AddTeamsService_SettingsMissing_RegistersNoopService() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:Teams:ClientId"] = null, + ["GlobalSettings:Teams:ClientSecret"] = null, + ["GlobalSettings:Teams:Scopes"] = null + }); + + services.AddTeamsService(globalSettings); + + var provider = services.BuildServiceProvider(); + var teamsService = provider.GetService(); + + Assert.NotNull(teamsService); + Assert.IsType(teamsService); + } + + [Fact] + public void AddRabbitMqIntegration_RegistersEventIntegrationHandler() + { + var services = new ServiceCollection(); + var listenerConfig = new TestListenerConfiguration(); + + // Add required dependencies + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddKeyedSingleton(EventIntegrationsCacheConstants.CacheName, Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + services.AddRabbitMqIntegration(listenerConfig); + + var provider = services.BuildServiceProvider(); + var handler = provider.GetRequiredKeyedService(listenerConfig.RoutingKey); + + Assert.NotNull(handler); + } + + [Fact] + public void AddRabbitMqIntegration_RegistersEventListenerService() + { + var services = new ServiceCollection(); + var listenerConfig = new TestListenerConfiguration(); + + // Add required dependencies + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddKeyedSingleton(EventIntegrationsCacheConstants.CacheName, Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + var beforeCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + services.AddRabbitMqIntegration(listenerConfig); + var afterCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + + // AddRabbitMqIntegration should register 2 hosted services (Event + Integration listeners) + Assert.Equal(2, afterCount - beforeCount); + } + + [Fact] + public void AddRabbitMqIntegration_RegistersIntegrationListenerService() + { + var services = new ServiceCollection(); + var listenerConfig = new TestListenerConfiguration(); + + // Add required dependencies + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddKeyedSingleton(EventIntegrationsCacheConstants.CacheName, Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For>()); + services.TryAddSingleton(TimeProvider.System); + services.AddLogging(); + + var beforeCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + services.AddRabbitMqIntegration(listenerConfig); + var afterCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + + // AddRabbitMqIntegration should register 2 hosted services (Event + Integration listeners) + Assert.Equal(2, afterCount - beforeCount); + } + + [Fact] + public void AddAzureServiceBusIntegration_RegistersEventIntegrationHandler() + { + var services = new ServiceCollection(); + var listenerConfig = new TestListenerConfiguration(); + + // Add required dependencies + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddKeyedSingleton(EventIntegrationsCacheConstants.CacheName, Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + services.AddAzureServiceBusIntegration(listenerConfig); + + var provider = services.BuildServiceProvider(); + var handler = provider.GetRequiredKeyedService(listenerConfig.RoutingKey); + + Assert.NotNull(handler); + } + + [Fact] + public void AddAzureServiceBusIntegration_RegistersEventListenerService() + { + var services = new ServiceCollection(); + var listenerConfig = new TestListenerConfiguration(); + + // Add required dependencies + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddKeyedSingleton(EventIntegrationsCacheConstants.CacheName, Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + var beforeCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + services.AddAzureServiceBusIntegration(listenerConfig); + var afterCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + + // AddAzureServiceBusIntegration should register 2 hosted services (Event + Integration listeners) + Assert.Equal(2, afterCount - beforeCount); + } + + [Fact] + public void AddAzureServiceBusIntegration_RegistersIntegrationListenerService() + { + var services = new ServiceCollection(); + var listenerConfig = new TestListenerConfiguration(); + + // Add required dependencies + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddKeyedSingleton(EventIntegrationsCacheConstants.CacheName, Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For()); + services.TryAddSingleton(Substitute.For>()); + services.AddLogging(); + + var beforeCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + services.AddAzureServiceBusIntegration(listenerConfig); + var afterCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + + // AddAzureServiceBusIntegration should register 2 hosted services (Event + Integration listeners) + Assert.Equal(2, afterCount - beforeCount); + } + + [Fact] + public void AddEventIntegrationServices_RegistersCommonServices() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings([]); + + // Add prerequisites + services.TryAddSingleton(globalSettings); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + services.AddEventIntegrationServices(globalSettings); + + // Verify common services are registered + Assert.Contains(services, s => s.ServiceType == typeof(IIntegrationFilterService)); + Assert.Contains(services, s => s.ServiceType == typeof(TimeProvider)); + + // Verify HttpClients for handlers are registered + var httpClientDescriptors = services.Where(s => s.ServiceType == typeof(IHttpClientFactory)).ToList(); + Assert.NotEmpty(httpClientDescriptors); + } + + [Fact] + public void AddEventIntegrationServices_RegistersIntegrationHandlers() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings([]); + + // Add prerequisites + services.TryAddSingleton(globalSettings); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + services.AddEventIntegrationServices(globalSettings); + + // Verify integration handlers are registered + Assert.Contains(services, s => s.ServiceType == typeof(IIntegrationHandler)); + Assert.Contains(services, s => s.ServiceType == typeof(IIntegrationHandler)); + Assert.Contains(services, s => s.ServiceType == typeof(IIntegrationHandler)); + Assert.Contains(services, s => s.ServiceType == typeof(IIntegrationHandler)); + } + + [Fact] + public void AddEventIntegrationServices_RabbitMqEnabled_RegistersRabbitMqListeners() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange", + ["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration" + }); + + // Add prerequisites + services.TryAddSingleton(globalSettings); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + var beforeCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + services.AddEventIntegrationServices(globalSettings); + var afterCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + + // Should register 11 hosted services for RabbitMQ: 1 repository + 5*2 integration listeners (event+integration) + Assert.Equal(11, afterCount - beforeCount); + } + + [Fact] + public void AddEventIntegrationServices_AzureServiceBusEnabled_RegistersAzureListeners() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test", + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events", + ["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = "integration" + }); + + // Add prerequisites + services.TryAddSingleton(globalSettings); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + var beforeCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + services.AddEventIntegrationServices(globalSettings); + var afterCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + + // Should register 11 hosted services for Azure Service Bus: 1 repository + 5*2 integration listeners (event+integration) + Assert.Equal(11, afterCount - beforeCount); + } + + [Fact] + public void AddEventIntegrationServices_BothEnabled_AzureServiceBusTakesPrecedence() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange", + ["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration", + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test", + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events", + ["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = "integration" + }); + + // Add prerequisites + services.TryAddSingleton(globalSettings); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + var beforeCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + services.AddEventIntegrationServices(globalSettings); + var afterCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + + // Should register 11 hosted services for Azure Service Bus: 1 repository + 5*2 integration listeners (event+integration) + // NO RabbitMQ services should be enabled because ASB takes precedence + Assert.Equal(11, afterCount - beforeCount); + } + + [Fact] + public void AddEventIntegrationServices_NeitherEnabled_RegistersNoListeners() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings([]); + + // Add prerequisites + services.TryAddSingleton(globalSettings); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + var beforeCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + services.AddEventIntegrationServices(globalSettings); + var afterCount = services.Count(s => s.ServiceType == typeof(IHostedService)); + + // Should register no hosted services when neither RabbitMQ nor Azure Service Bus is enabled + Assert.Equal(0, afterCount - beforeCount); + } + + [Fact] + public void AddEventWriteServices_AzureServiceBusEnabled_RegistersAzureServices() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test", + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events", + ["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = "integration" + }); + + services.AddEventWriteServices(globalSettings); + + Assert.Contains(services, s => s.ServiceType == typeof(IEventIntegrationPublisher) && s.ImplementationType == typeof(AzureServiceBusService)); + Assert.Contains(services, s => s.ServiceType == typeof(IEventWriteService) && s.ImplementationType == typeof(EventIntegrationEventWriteService)); + } + + [Fact] + public void AddEventWriteServices_RabbitMqEnabled_RegistersRabbitMqServices() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange", + ["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration" + }); + + services.AddEventWriteServices(globalSettings); + + Assert.Contains(services, s => s.ServiceType == typeof(IEventIntegrationPublisher) && s.ImplementationType == typeof(RabbitMqService)); + Assert.Contains(services, s => s.ServiceType == typeof(IEventWriteService) && s.ImplementationType == typeof(EventIntegrationEventWriteService)); + } + + [Fact] + public void AddEventWriteServices_EventsConnectionStringPresent_RegistersAzureQueueService() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:Events:ConnectionString"] = "DefaultEndpointsProtocol=https;AccountName=test;AccountKey=test;EndpointSuffix=core.windows.net", + ["GlobalSettings:Events:QueueName"] = "event" + }); + + services.AddEventWriteServices(globalSettings); + + Assert.Contains(services, s => s.ServiceType == typeof(IEventWriteService) && s.ImplementationType == typeof(AzureQueueEventWriteService)); + } + + [Fact] + public void AddEventWriteServices_SelfHosted_RegistersRepositoryService() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:SelfHosted"] = "true" + }); + + services.AddEventWriteServices(globalSettings); + + Assert.Contains(services, s => s.ServiceType == typeof(IEventWriteService) && s.ImplementationType == typeof(RepositoryEventWriteService)); + } + + [Fact] + public void AddEventWriteServices_NothingEnabled_RegistersNoopService() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings([]); + + services.AddEventWriteServices(globalSettings); + + Assert.Contains(services, s => s.ServiceType == typeof(IEventWriteService) && s.ImplementationType == typeof(NoopEventWriteService)); + } + + [Fact] + public void AddEventWriteServices_AzureTakesPrecedenceOverRabbitMq() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test", + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events", + ["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = "integration", + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange", + ["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration" + }); + + services.AddEventWriteServices(globalSettings); + + // Should use Azure Service Bus, not RabbitMQ + Assert.Contains(services, s => s.ServiceType == typeof(IEventIntegrationPublisher) && s.ImplementationType == typeof(AzureServiceBusService)); + Assert.DoesNotContain(services, s => s.ServiceType == typeof(IEventIntegrationPublisher) && s.ImplementationType == typeof(RabbitMqService)); + } + + [Fact] + public void AddAzureServiceBusListeners_AzureServiceBusEnabled_RegistersAzureServiceBusServices() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:AzureServiceBus:ConnectionString"] = "Endpoint=sb://test.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=test", + ["GlobalSettings:EventLogging:AzureServiceBus:EventTopicName"] = "events", + ["GlobalSettings:EventLogging:AzureServiceBus:IntegrationTopicName"] = "integration" + }); + + // Add prerequisites + services.TryAddSingleton(globalSettings); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + services.AddAzureServiceBusListeners(globalSettings); + + Assert.Contains(services, s => s.ServiceType == typeof(IAzureServiceBusService)); + Assert.Contains(services, s => s.ServiceType == typeof(IEventRepository)); + Assert.Contains(services, s => s.ServiceType == typeof(AzureTableStorageEventHandler)); + } + + [Fact] + public void AddAzureServiceBusListeners_AzureServiceBusDisabled_ReturnsEarly() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings([]); + + var initialCount = services.Count; + services.AddAzureServiceBusListeners(globalSettings); + var finalCount = services.Count; + + Assert.Equal(initialCount, finalCount); + } + + [Fact] + public void AddRabbitMqListeners_RabbitMqEnabled_RegistersRabbitMqServices() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings(new Dictionary + { + ["GlobalSettings:EventLogging:RabbitMq:HostName"] = "localhost", + ["GlobalSettings:EventLogging:RabbitMq:Username"] = "user", + ["GlobalSettings:EventLogging:RabbitMq:Password"] = "pass", + ["GlobalSettings:EventLogging:RabbitMq:EventExchangeName"] = "exchange", + ["GlobalSettings:EventLogging:RabbitMq:IntegrationExchangeName"] = "integration" + }); + + // Add prerequisites + services.TryAddSingleton(globalSettings); + services.TryAddSingleton(Substitute.For()); + services.AddLogging(); + + services.AddRabbitMqListeners(globalSettings); + + Assert.Contains(services, s => s.ServiceType == typeof(IRabbitMqService)); + Assert.Contains(services, s => s.ServiceType == typeof(EventRepositoryHandler)); + } + + [Fact] + public void AddRabbitMqListeners_RabbitMqDisabled_ReturnsEarly() + { + var services = new ServiceCollection(); + var globalSettings = CreateGlobalSettings([]); + + var initialCount = services.Count; + services.AddRabbitMqListeners(globalSettings); + var finalCount = services.Count; + + Assert.Equal(initialCount, finalCount); + } + + private static GlobalSettings CreateGlobalSettings(Dictionary data) + { + var config = new ConfigurationBuilder() + .AddInMemoryCollection(data) + .Build(); + + var settings = new GlobalSettings(); + config.GetSection("GlobalSettings").Bind(settings); + return settings; + } +} diff --git a/test/Core.Test/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/CreateOrganizationIntegrationConfigurationCommandTests.cs b/test/Core.Test/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/CreateOrganizationIntegrationConfigurationCommandTests.cs new file mode 100644 index 0000000000..3ad3569c07 --- /dev/null +++ b/test/Core.Test/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/CreateOrganizationIntegrationConfigurationCommandTests.cs @@ -0,0 +1,180 @@ +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrationConfigurations; +using Bit.Core.Dirt.Repositories; +using Bit.Core.Dirt.Services; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Utilities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Test.Dirt.EventIntegrations.OrganizationIntegrationConfigurations; + +[SutProviderCustomize] +public class CreateOrganizationIntegrationConfigurationCommandTests +{ + [Theory, BitAutoData] + public async Task CreateAsync_Success_CreatesConfigurationAndInvalidatesCache( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration configuration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + integration.Type = IntegrationType.Webhook; + configuration.OrganizationIntegrationId = integrationId; + configuration.EventType = EventType.User_LoggedIn; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .CreateAsync(configuration) + .Returns(configuration); + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(true); + + var result = await sutProvider.Sut.CreateAsync(organizationId, integrationId, configuration); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .CreateAsync(configuration); + await sutProvider.GetDependency().Received(1) + .RemoveAsync(EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId, + integration.Type, + configuration.EventType.Value)); + // Also verify RemoveByTagAsync was NOT called + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + Assert.Equal(configuration, result); + } + + [Theory, BitAutoData] + public async Task CreateAsync_WildcardSuccess_CreatesConfigurationAndInvalidatesCache( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration configuration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + integration.Type = IntegrationType.Webhook; + configuration.OrganizationIntegrationId = integrationId; + configuration.EventType = null; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .CreateAsync(configuration) + .Returns(configuration); + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(true); + + var result = await sutProvider.Sut.CreateAsync(organizationId, integrationId, configuration); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .CreateAsync(configuration); + await sutProvider.GetDependency().Received(1) + .RemoveByTagAsync(EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId, + integration.Type)); + // Also verify RemoveAsync was NOT called + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + Assert.Equal(configuration, result); + } + + [Theory, BitAutoData] + public async Task CreateAsync_IntegrationDoesNotExist_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegrationConfiguration configuration) + { + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns((OrganizationIntegration)null); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.CreateAsync(organizationId, integrationId, configuration)); + + await sutProvider.GetDependency().DidNotReceive() + .CreateAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task CreateAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration configuration) + { + integration.Id = integrationId; + integration.OrganizationId = Guid.NewGuid(); // Different organization + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.CreateAsync(organizationId, integrationId, configuration)); + + await sutProvider.GetDependency().DidNotReceive() + .CreateAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task CreateAsync_ValidationFails_ThrowsBadRequest( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration configuration) + { + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(false); + + integration.Id = integrationId; + integration.OrganizationId = organizationId; + configuration.OrganizationIntegrationId = integrationId; + configuration.Template = "template"; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.CreateAsync(organizationId, integrationId, configuration)); + + await sutProvider.GetDependency().DidNotReceive() + .CreateAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } +} diff --git a/test/Core.Test/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/DeleteOrganizationIntegrationConfigurationCommandTests.cs b/test/Core.Test/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/DeleteOrganizationIntegrationConfigurationCommandTests.cs new file mode 100644 index 0000000000..c053a761bb --- /dev/null +++ b/test/Core.Test/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/DeleteOrganizationIntegrationConfigurationCommandTests.cs @@ -0,0 +1,212 @@ +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrationConfigurations; +using Bit.Core.Dirt.Repositories; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Utilities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Test.Dirt.EventIntegrations.OrganizationIntegrationConfigurations; + +[SutProviderCustomize] +public class DeleteOrganizationIntegrationConfigurationCommandTests +{ + [Theory, BitAutoData] + public async Task DeleteAsync_Success_DeletesConfigurationAndInvalidatesCache( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration configuration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + integration.Type = IntegrationType.Webhook; + configuration.Id = configurationId; + configuration.OrganizationIntegrationId = integrationId; + configuration.EventType = EventType.User_LoggedIn; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns(configuration); + + await sutProvider.Sut.DeleteAsync(organizationId, integrationId, configurationId); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(configurationId); + await sutProvider.GetDependency().Received(1) + .DeleteAsync(configuration); + await sutProvider.GetDependency().Received(1) + .RemoveAsync(EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId, + integration.Type, + configuration.EventType.Value)); + // Also verify RemoveByTagAsync was NOT called + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_WildcardSuccess_DeletesConfigurationAndInvalidatesCache( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration configuration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + integration.Type = IntegrationType.Webhook; + configuration.Id = configurationId; + configuration.OrganizationIntegrationId = integrationId; + configuration.EventType = null; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns(configuration); + + await sutProvider.Sut.DeleteAsync(organizationId, integrationId, configurationId); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(configurationId); + await sutProvider.GetDependency().Received(1) + .DeleteAsync(configuration); + await sutProvider.GetDependency().Received(1) + .RemoveByTagAsync(EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId, + integration.Type)); + // Also verify RemoveAsync was NOT called + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_IntegrationDoesNotExist_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId) + { + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns((OrganizationIntegration)null); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteAsync(organizationId, integrationId, configurationId)); + + await sutProvider.GetDependency().DidNotReceive() + .GetByIdAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .DeleteAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration) + { + integration.Id = integrationId; + integration.OrganizationId = Guid.NewGuid(); // Different organization + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteAsync(organizationId, integrationId, configurationId)); + + await sutProvider.GetDependency().DidNotReceive() + .GetByIdAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .DeleteAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_ConfigurationDoesNotExist_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns((OrganizationIntegrationConfiguration)null); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteAsync(organizationId, integrationId, configurationId)); + + await sutProvider.GetDependency().DidNotReceive() + .DeleteAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_ConfigurationDoesNotBelongToIntegration_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration configuration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + configuration.Id = configurationId; + configuration.OrganizationIntegrationId = Guid.NewGuid(); // Different integration + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns(configuration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteAsync(organizationId, integrationId, configurationId)); + + await sutProvider.GetDependency().DidNotReceive() + .DeleteAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } +} diff --git a/test/Core.Test/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/GetOrganizationIntegrationConfigurationsQueryTests.cs b/test/Core.Test/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/GetOrganizationIntegrationConfigurationsQueryTests.cs new file mode 100644 index 0000000000..780467a91a --- /dev/null +++ b/test/Core.Test/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/GetOrganizationIntegrationConfigurationsQueryTests.cs @@ -0,0 +1,101 @@ +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrationConfigurations; +using Bit.Core.Dirt.Repositories; +using Bit.Core.Exceptions; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.Dirt.EventIntegrations.OrganizationIntegrationConfigurations; + +[SutProviderCustomize] +public class GetOrganizationIntegrationConfigurationsQueryTests +{ + [Theory, BitAutoData] + public async Task GetManyByIntegrationAsync_Success_ReturnsConfigurations( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration, + List configurations) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetManyByIntegrationAsync(integrationId) + .Returns(configurations); + + var result = await sutProvider.Sut.GetManyByIntegrationAsync(organizationId, integrationId); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .GetManyByIntegrationAsync(integrationId); + Assert.Equal(configurations.Count, result.Count); + } + + [Theory, BitAutoData] + public async Task GetManyByIntegrationAsync_NoConfigurations_ReturnsEmptyList( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetManyByIntegrationAsync(integrationId) + .Returns([]); + + var result = await sutProvider.Sut.GetManyByIntegrationAsync(organizationId, integrationId); + + Assert.Empty(result); + } + + [Theory, BitAutoData] + public async Task GetManyByIntegrationAsync_IntegrationDoesNotExist_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId) + { + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns((OrganizationIntegration)null); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.GetManyByIntegrationAsync(organizationId, integrationId)); + + await sutProvider.GetDependency().DidNotReceive() + .GetManyByIntegrationAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task GetManyByIntegrationAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration) + { + integration.Id = integrationId; + integration.OrganizationId = Guid.NewGuid(); // Different organization + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.GetManyByIntegrationAsync(organizationId, integrationId)); + + await sutProvider.GetDependency().DidNotReceive() + .GetManyByIntegrationAsync(Arg.Any()); + } +} diff --git a/test/Core.Test/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/UpdateOrganizationIntegrationConfigurationCommandTests.cs b/test/Core.Test/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/UpdateOrganizationIntegrationConfigurationCommandTests.cs new file mode 100644 index 0000000000..42ea278aa6 --- /dev/null +++ b/test/Core.Test/Dirt/EventIntegrations/OrganizationIntegrationConfigurations/UpdateOrganizationIntegrationConfigurationCommandTests.cs @@ -0,0 +1,391 @@ +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrationConfigurations; +using Bit.Core.Dirt.Repositories; +using Bit.Core.Dirt.Services; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Utilities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Test.Dirt.EventIntegrations.OrganizationIntegrationConfigurations; + +[SutProviderCustomize] +public class UpdateOrganizationIntegrationConfigurationCommandTests +{ + [Theory, BitAutoData] + public async Task UpdateAsync_Success_UpdatesConfigurationAndInvalidatesCache( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration existingConfiguration, + OrganizationIntegrationConfiguration updatedConfiguration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + integration.Type = IntegrationType.Webhook; + existingConfiguration.Id = configurationId; + existingConfiguration.OrganizationIntegrationId = integrationId; + existingConfiguration.EventType = EventType.User_LoggedIn; + updatedConfiguration.Id = configurationId; + updatedConfiguration.OrganizationIntegrationId = integrationId; + existingConfiguration.EventType = EventType.User_LoggedIn; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns(existingConfiguration); + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(true); + + var result = await sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, updatedConfiguration); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(configurationId); + await sutProvider.GetDependency().Received(1) + .ReplaceAsync(updatedConfiguration); + await sutProvider.GetDependency().Received(1) + .RemoveAsync(EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId, + integration.Type, + existingConfiguration.EventType.Value)); + // Also verify RemoveByTagAsync was NOT called + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + Assert.Equal(updatedConfiguration, result); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_WildcardSuccess_UpdatesConfigurationAndInvalidatesCache( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration existingConfiguration, + OrganizationIntegrationConfiguration updatedConfiguration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + integration.Type = IntegrationType.Webhook; + existingConfiguration.Id = configurationId; + existingConfiguration.OrganizationIntegrationId = integrationId; + existingConfiguration.EventType = null; + updatedConfiguration.Id = configurationId; + updatedConfiguration.OrganizationIntegrationId = integrationId; + updatedConfiguration.EventType = null; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns(existingConfiguration); + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(true); + + var result = await sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, updatedConfiguration); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(configurationId); + await sutProvider.GetDependency().Received(1) + .ReplaceAsync(updatedConfiguration); + await sutProvider.GetDependency().Received(1) + .RemoveByTagAsync(EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId, + integration.Type)); + // Also verify RemoveAsync was NOT called + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + Assert.Equal(updatedConfiguration, result); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_ChangedEventType_UpdatesConfigurationAndInvalidatesCacheForBothTypes( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration existingConfiguration, + OrganizationIntegrationConfiguration updatedConfiguration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + integration.Type = IntegrationType.Webhook; + existingConfiguration.Id = configurationId; + existingConfiguration.OrganizationIntegrationId = integrationId; + existingConfiguration.EventType = EventType.User_LoggedIn; + updatedConfiguration.Id = configurationId; + updatedConfiguration.OrganizationIntegrationId = integrationId; + updatedConfiguration.EventType = EventType.Cipher_Created; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns(existingConfiguration); + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(true); + + var result = await sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, updatedConfiguration); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(configurationId); + await sutProvider.GetDependency().Received(1) + .ReplaceAsync(updatedConfiguration); + await sutProvider.GetDependency().Received(1) + .RemoveAsync(EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId, + integration.Type, + existingConfiguration.EventType.Value)); + await sutProvider.GetDependency().Received(1) + .RemoveAsync(EventIntegrationsCacheConstants.BuildCacheKeyForOrganizationIntegrationConfigurationDetails( + organizationId, + integration.Type, + updatedConfiguration.EventType.Value)); + // Verify RemoveByTagAsync was NOT called since both are specific event types + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + Assert.Equal(updatedConfiguration, result); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_IntegrationDoesNotExist_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegrationConfiguration updatedConfiguration) + { + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns((OrganizationIntegration)null); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, updatedConfiguration)); + + await sutProvider.GetDependency().DidNotReceive() + .GetByIdAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .ReplaceAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration updatedConfiguration) + { + integration.Id = integrationId; + integration.OrganizationId = Guid.NewGuid(); // Different organization + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, updatedConfiguration)); + + await sutProvider.GetDependency().DidNotReceive() + .GetByIdAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .ReplaceAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_ConfigurationDoesNotExist_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration updatedConfiguration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns((OrganizationIntegrationConfiguration)null); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, updatedConfiguration)); + + await sutProvider.GetDependency().DidNotReceive() + .ReplaceAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_ConfigurationDoesNotBelongToIntegration_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration existingConfiguration, + OrganizationIntegrationConfiguration updatedConfiguration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + existingConfiguration.Id = configurationId; + existingConfiguration.OrganizationIntegrationId = Guid.NewGuid(); // Different integration + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns(existingConfiguration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, updatedConfiguration)); + + await sutProvider.GetDependency().DidNotReceive() + .ReplaceAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_ValidationFails_ThrowsBadRequest( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + Guid configurationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration existingConfiguration, + OrganizationIntegrationConfiguration updatedConfiguration) + { + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(false); + + integration.Id = integrationId; + integration.OrganizationId = organizationId; + existingConfiguration.Id = configurationId; + existingConfiguration.OrganizationIntegrationId = integrationId; + updatedConfiguration.Id = configurationId; + updatedConfiguration.OrganizationIntegrationId = integrationId; + updatedConfiguration.Template = "template"; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(configurationId) + .Returns(existingConfiguration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateAsync(organizationId, integrationId, configurationId, updatedConfiguration)); + + await sutProvider.GetDependency().DidNotReceive() + .ReplaceAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_ChangedFromWildcardToSpecific_InvalidatesAllCaches( + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration existingConfiguration, + OrganizationIntegrationConfiguration updatedConfiguration, + SutProvider sutProvider) + { + integration.OrganizationId = organizationId; + existingConfiguration.OrganizationIntegrationId = integrationId; + existingConfiguration.EventType = null; // Wildcard + updatedConfiguration.EventType = EventType.User_LoggedIn; // Specific + + sutProvider.GetDependency() + .GetByIdAsync(integrationId).Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(existingConfiguration.Id).Returns(existingConfiguration); + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(true); + + await sutProvider.Sut.UpdateAsync(organizationId, integrationId, existingConfiguration.Id, updatedConfiguration); + + await sutProvider.GetDependency().Received(1) + .RemoveByTagAsync(EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId, + integration.Type)); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_ChangedFromSpecificToWildcard_InvalidatesAllCaches( + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration, + OrganizationIntegrationConfiguration existingConfiguration, + OrganizationIntegrationConfiguration updatedConfiguration, + SutProvider sutProvider) + { + integration.OrganizationId = organizationId; + existingConfiguration.OrganizationIntegrationId = integrationId; + existingConfiguration.EventType = EventType.User_LoggedIn; // Specific + updatedConfiguration.EventType = null; // Wildcard + + sutProvider.GetDependency() + .GetByIdAsync(integrationId).Returns(integration); + sutProvider.GetDependency() + .GetByIdAsync(existingConfiguration.Id).Returns(existingConfiguration); + sutProvider.GetDependency() + .ValidateConfiguration(Arg.Any(), Arg.Any()) + .Returns(true); + + await sutProvider.Sut.UpdateAsync(organizationId, integrationId, existingConfiguration.Id, updatedConfiguration); + + await sutProvider.GetDependency().Received(1) + .RemoveByTagAsync(EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId, + integration.Type)); + await sutProvider.GetDependency().DidNotReceive() + .RemoveAsync(Arg.Any()); + } +} diff --git a/test/Core.Test/Dirt/EventIntegrations/OrganizationIntegrations/CreateOrganizationIntegrationCommandTests.cs b/test/Core.Test/Dirt/EventIntegrations/OrganizationIntegrations/CreateOrganizationIntegrationCommandTests.cs new file mode 100644 index 0000000000..4933656eb3 --- /dev/null +++ b/test/Core.Test/Dirt/EventIntegrations/OrganizationIntegrations/CreateOrganizationIntegrationCommandTests.cs @@ -0,0 +1,92 @@ +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrations; +using Bit.Core.Dirt.Repositories; +using Bit.Core.Exceptions; +using Bit.Core.Utilities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Test.Dirt.EventIntegrations.OrganizationIntegrations; + +[SutProviderCustomize] +public class CreateOrganizationIntegrationCommandTests +{ + [Theory, BitAutoData] + public async Task CreateAsync_Success_CreatesIntegrationAndInvalidatesCache( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Type = IntegrationType.Webhook; + + sutProvider.GetDependency() + .GetManyByOrganizationAsync(integration.OrganizationId) + .Returns([]); + sutProvider.GetDependency() + .CreateAsync(integration) + .Returns(integration); + + var result = await sutProvider.Sut.CreateAsync(integration); + + await sutProvider.GetDependency().Received(1) + .GetManyByOrganizationAsync(integration.OrganizationId); + await sutProvider.GetDependency().Received(1) + .CreateAsync(integration); + await sutProvider.GetDependency().Received(1) + .RemoveByTagAsync(EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + integration.OrganizationId, + integration.Type)); + Assert.Equal(integration, result); + } + + [Theory, BitAutoData] + public async Task CreateAsync_DuplicateType_ThrowsBadRequest( + SutProvider sutProvider, + OrganizationIntegration integration, + OrganizationIntegration existingIntegration) + { + integration.Type = IntegrationType.Webhook; + existingIntegration.Type = IntegrationType.Webhook; + existingIntegration.OrganizationId = integration.OrganizationId; + + sutProvider.GetDependency() + .GetManyByOrganizationAsync(integration.OrganizationId) + .Returns([existingIntegration]); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.CreateAsync(integration)); + + Assert.Contains("An integration of this type already exists", exception.Message); + await sutProvider.GetDependency().DidNotReceive() + .CreateAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task CreateAsync_DifferentType_Success( + SutProvider sutProvider, + OrganizationIntegration integration, + OrganizationIntegration existingIntegration) + { + integration.Type = IntegrationType.Webhook; + existingIntegration.Type = IntegrationType.Slack; + existingIntegration.OrganizationId = integration.OrganizationId; + + sutProvider.GetDependency() + .GetManyByOrganizationAsync(integration.OrganizationId) + .Returns([existingIntegration]); + sutProvider.GetDependency() + .CreateAsync(integration) + .Returns(integration); + + var result = await sutProvider.Sut.CreateAsync(integration); + + await sutProvider.GetDependency().Received(1) + .CreateAsync(integration); + Assert.Equal(integration, result); + } +} diff --git a/test/Core.Test/Dirt/EventIntegrations/OrganizationIntegrations/DeleteOrganizationIntegrationCommandTests.cs b/test/Core.Test/Dirt/EventIntegrations/OrganizationIntegrations/DeleteOrganizationIntegrationCommandTests.cs new file mode 100644 index 0000000000..15a3b44bcf --- /dev/null +++ b/test/Core.Test/Dirt/EventIntegrations/OrganizationIntegrations/DeleteOrganizationIntegrationCommandTests.cs @@ -0,0 +1,86 @@ +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrations; +using Bit.Core.Dirt.Repositories; +using Bit.Core.Exceptions; +using Bit.Core.Utilities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Test.Dirt.EventIntegrations.OrganizationIntegrations; + +[SutProviderCustomize] +public class DeleteOrganizationIntegrationCommandTests +{ + [Theory, BitAutoData] + public async Task DeleteAsync_Success_DeletesIntegrationAndInvalidatesCache( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration) + { + integration.Id = integrationId; + integration.OrganizationId = organizationId; + integration.Type = IntegrationType.Webhook; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + + await sutProvider.Sut.DeleteAsync(organizationId, integrationId); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .DeleteAsync(integration); + await sutProvider.GetDependency().Received(1) + .RemoveByTagAsync(EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId, + integration.Type)); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_IntegrationDoesNotExist_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId) + { + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns((OrganizationIntegration)null); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteAsync(organizationId, integrationId)); + + await sutProvider.GetDependency().DidNotReceive() + .DeleteAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task DeleteAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration integration) + { + integration.Id = integrationId; + integration.OrganizationId = Guid.NewGuid(); // Different organization + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(integration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteAsync(organizationId, integrationId)); + + await sutProvider.GetDependency().DidNotReceive() + .DeleteAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } +} diff --git a/test/Core.Test/Dirt/EventIntegrations/OrganizationIntegrations/GetOrganizationIntegrationsQueryTests.cs b/test/Core.Test/Dirt/EventIntegrations/OrganizationIntegrations/GetOrganizationIntegrationsQueryTests.cs new file mode 100644 index 0000000000..19b35ac340 --- /dev/null +++ b/test/Core.Test/Dirt/EventIntegrations/OrganizationIntegrations/GetOrganizationIntegrationsQueryTests.cs @@ -0,0 +1,44 @@ +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrations; +using Bit.Core.Dirt.Repositories; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.Dirt.EventIntegrations.OrganizationIntegrations; + +[SutProviderCustomize] +public class GetOrganizationIntegrationsQueryTests +{ + [Theory, BitAutoData] + public async Task GetManyByOrganizationAsync_CallsRepository( + SutProvider sutProvider, + Guid organizationId, + List integrations) + { + sutProvider.GetDependency() + .GetManyByOrganizationAsync(organizationId) + .Returns(integrations); + + var result = await sutProvider.Sut.GetManyByOrganizationAsync(organizationId); + + await sutProvider.GetDependency().Received(1) + .GetManyByOrganizationAsync(organizationId); + Assert.Equal(integrations.Count, result.Count); + } + + [Theory, BitAutoData] + public async Task GetManyByOrganizationAsync_NoIntegrations_ReturnsEmptyList( + SutProvider sutProvider, + Guid organizationId) + { + sutProvider.GetDependency() + .GetManyByOrganizationAsync(organizationId) + .Returns([]); + + var result = await sutProvider.Sut.GetManyByOrganizationAsync(organizationId); + + Assert.Empty(result); + } +} diff --git a/test/Core.Test/Dirt/EventIntegrations/OrganizationIntegrations/UpdateOrganizationIntegrationCommandTests.cs b/test/Core.Test/Dirt/EventIntegrations/OrganizationIntegrations/UpdateOrganizationIntegrationCommandTests.cs new file mode 100644 index 0000000000..34bf02c34b --- /dev/null +++ b/test/Core.Test/Dirt/EventIntegrations/OrganizationIntegrations/UpdateOrganizationIntegrationCommandTests.cs @@ -0,0 +1,121 @@ +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.EventIntegrations.OrganizationIntegrations; +using Bit.Core.Dirt.Repositories; +using Bit.Core.Exceptions; +using Bit.Core.Utilities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Test.Dirt.EventIntegrations.OrganizationIntegrations; + +[SutProviderCustomize] +public class UpdateOrganizationIntegrationCommandTests +{ + [Theory, BitAutoData] + public async Task UpdateAsync_Success_UpdatesIntegrationAndInvalidatesCache( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration existingIntegration, + OrganizationIntegration updatedIntegration) + { + existingIntegration.Id = integrationId; + existingIntegration.OrganizationId = organizationId; + existingIntegration.Type = IntegrationType.Webhook; + updatedIntegration.Id = integrationId; + updatedIntegration.OrganizationId = organizationId; + updatedIntegration.Type = IntegrationType.Webhook; + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(existingIntegration); + + var result = await sutProvider.Sut.UpdateAsync(organizationId, integrationId, updatedIntegration); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(integrationId); + await sutProvider.GetDependency().Received(1) + .ReplaceAsync(updatedIntegration); + await sutProvider.GetDependency().Received(1) + .RemoveByTagAsync(EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + organizationId, + existingIntegration.Type)); + Assert.Equal(updatedIntegration, result); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_IntegrationDoesNotExist_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration updatedIntegration) + { + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns((OrganizationIntegration)null); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateAsync(organizationId, integrationId, updatedIntegration)); + + await sutProvider.GetDependency().DidNotReceive() + .ReplaceAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration existingIntegration, + OrganizationIntegration updatedIntegration) + { + existingIntegration.Id = integrationId; + existingIntegration.OrganizationId = Guid.NewGuid(); // Different organization + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(existingIntegration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateAsync(organizationId, integrationId, updatedIntegration)); + + await sutProvider.GetDependency().DidNotReceive() + .ReplaceAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task UpdateAsync_IntegrationIsDifferentType_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + Guid integrationId, + OrganizationIntegration existingIntegration, + OrganizationIntegration updatedIntegration) + { + existingIntegration.Id = integrationId; + existingIntegration.OrganizationId = organizationId; + existingIntegration.Type = IntegrationType.Webhook; + updatedIntegration.Id = integrationId; + updatedIntegration.OrganizationId = organizationId; + updatedIntegration.Type = IntegrationType.Hec; // Different Type + + sutProvider.GetDependency() + .GetByIdAsync(integrationId) + .Returns(existingIntegration); + + await Assert.ThrowsAsync( + () => sutProvider.Sut.UpdateAsync(organizationId, integrationId, updatedIntegration)); + + await sutProvider.GetDependency().DidNotReceive() + .ReplaceAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceive() + .RemoveByTagAsync(Arg.Any()); + } +} diff --git a/test/Core.Test/Dirt/Models/Data/EventIntegrations/IntegrationHandlerResultTests.cs b/test/Core.Test/Dirt/Models/Data/EventIntegrations/IntegrationHandlerResultTests.cs new file mode 100644 index 0000000000..4b6292b7c4 --- /dev/null +++ b/test/Core.Test/Dirt/Models/Data/EventIntegrations/IntegrationHandlerResultTests.cs @@ -0,0 +1,128 @@ +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Test.Common.AutoFixture.Attributes; +using Xunit; + +namespace Bit.Core.Test.Dirt.Models.Data.EventIntegrations; + +public class IntegrationHandlerResultTests +{ + [Theory, BitAutoData] + public void Succeed_SetsSuccessTrue_CategoryNull(IntegrationMessage message) + { + var result = IntegrationHandlerResult.Succeed(message); + + Assert.True(result.Success); + Assert.Null(result.Category); + Assert.Equal(message, result.Message); + Assert.Null(result.FailureReason); + } + + [Theory, BitAutoData] + public void Fail_WithCategory_SetsSuccessFalse_CategorySet(IntegrationMessage message) + { + var category = IntegrationFailureCategory.AuthenticationFailed; + var failureReason = "Invalid credentials"; + + var result = IntegrationHandlerResult.Fail(message, category, failureReason); + + Assert.False(result.Success); + Assert.Equal(category, result.Category); + Assert.Equal(failureReason, result.FailureReason); + Assert.Equal(message, result.Message); + } + + [Theory, BitAutoData] + public void Fail_WithDelayUntil_SetsDelayUntilDate(IntegrationMessage message) + { + var delayUntil = DateTime.UtcNow.AddMinutes(5); + + var result = IntegrationHandlerResult.Fail( + message, + IntegrationFailureCategory.RateLimited, + "Rate limited", + delayUntil + ); + + Assert.Equal(delayUntil, result.DelayUntilDate); + } + + [Theory, BitAutoData] + public void Retryable_RateLimited_ReturnsTrue(IntegrationMessage message) + { + var result = IntegrationHandlerResult.Fail( + message, + IntegrationFailureCategory.RateLimited, + "Rate limited" + ); + + Assert.True(result.Retryable); + } + + [Theory, BitAutoData] + public void Retryable_TransientError_ReturnsTrue(IntegrationMessage message) + { + var result = IntegrationHandlerResult.Fail( + message, + IntegrationFailureCategory.TransientError, + "Temporary network issue" + ); + + Assert.True(result.Retryable); + } + + [Theory, BitAutoData] + public void Retryable_AuthenticationFailed_ReturnsFalse(IntegrationMessage message) + { + var result = IntegrationHandlerResult.Fail( + message, + IntegrationFailureCategory.AuthenticationFailed, + "Invalid token" + ); + + Assert.False(result.Retryable); + } + + [Theory, BitAutoData] + public void Retryable_ConfigurationError_ReturnsFalse(IntegrationMessage message) + { + var result = IntegrationHandlerResult.Fail( + message, + IntegrationFailureCategory.ConfigurationError, + "Channel not found" + ); + + Assert.False(result.Retryable); + } + + [Theory, BitAutoData] + public void Retryable_ServiceUnavailable_ReturnsTrue(IntegrationMessage message) + { + var result = IntegrationHandlerResult.Fail( + message, + IntegrationFailureCategory.ServiceUnavailable, + "Service is down" + ); + + Assert.True(result.Retryable); + } + + [Theory, BitAutoData] + public void Retryable_PermanentFailure_ReturnsFalse(IntegrationMessage message) + { + var result = IntegrationHandlerResult.Fail( + message, + IntegrationFailureCategory.PermanentFailure, + "Permanent failure" + ); + + Assert.False(result.Retryable); + } + + [Theory, BitAutoData] + public void Retryable_SuccessCase_ReturnsFalse(IntegrationMessage message) + { + var result = IntegrationHandlerResult.Succeed(message); + + Assert.False(result.Retryable); + } +} diff --git a/test/Core.Test/AdminConsole/Models/Data/EventIntegrations/IntegrationMessageTests.cs b/test/Core.Test/Dirt/Models/Data/EventIntegrations/IntegrationMessageTests.cs similarity index 86% rename from test/Core.Test/AdminConsole/Models/Data/EventIntegrations/IntegrationMessageTests.cs rename to test/Core.Test/Dirt/Models/Data/EventIntegrations/IntegrationMessageTests.cs index edd5cd488f..6f0ce11db8 100644 --- a/test/Core.Test/AdminConsole/Models/Data/EventIntegrations/IntegrationMessageTests.cs +++ b/test/Core.Test/Dirt/Models/Data/EventIntegrations/IntegrationMessageTests.cs @@ -1,13 +1,14 @@ using System.Text.Json; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; -using Bit.Core.Enums; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.Models.Data.EventIntegrations; using Xunit; -namespace Bit.Core.Test.Models.Data.EventIntegrations; +namespace Bit.Core.Test.Dirt.Models.Data.EventIntegrations; public class IntegrationMessageTests { private const string _messageId = "TestMessageId"; + private const string _organizationId = "TestOrganizationId"; [Fact] public void ApplyRetry_IncrementsRetryCountAndSetsDelayUntilDate() @@ -16,6 +17,7 @@ public class IntegrationMessageTests { Configuration = new WebhookIntegrationConfigurationDetails(new Uri("https://localhost"), "Bearer", "AUTH-TOKEN"), MessageId = _messageId, + OrganizationId = _organizationId, RetryCount = 2, RenderedTemplate = string.Empty, DelayUntilDate = null @@ -36,6 +38,7 @@ public class IntegrationMessageTests { Configuration = new WebhookIntegrationConfigurationDetails(new Uri("https://localhost"), "Bearer", "AUTH-TOKEN"), MessageId = _messageId, + OrganizationId = _organizationId, RenderedTemplate = "This is the message", IntegrationType = IntegrationType.Webhook, RetryCount = 2, @@ -48,6 +51,7 @@ public class IntegrationMessageTests Assert.NotNull(result); Assert.Equal(message.Configuration, result.Configuration); Assert.Equal(message.MessageId, result.MessageId); + Assert.Equal(message.OrganizationId, result.OrganizationId); Assert.Equal(message.RenderedTemplate, result.RenderedTemplate); Assert.Equal(message.IntegrationType, result.IntegrationType); Assert.Equal(message.RetryCount, result.RetryCount); @@ -67,6 +71,7 @@ public class IntegrationMessageTests var message = new IntegrationMessage { MessageId = _messageId, + OrganizationId = _organizationId, RenderedTemplate = "This is the message", IntegrationType = IntegrationType.Webhook, RetryCount = 2, @@ -77,6 +82,7 @@ public class IntegrationMessageTests var result = JsonSerializer.Deserialize(json); Assert.Equal(message.MessageId, result.MessageId); + Assert.Equal(message.OrganizationId, result.OrganizationId); Assert.Equal(message.RenderedTemplate, result.RenderedTemplate); Assert.Equal(message.IntegrationType, result.IntegrationType); Assert.Equal(message.RetryCount, result.RetryCount); diff --git a/test/Core.Test/AdminConsole/Models/Data/EventIntegrations/IntegrationOAuthStateTests.cs b/test/Core.Test/Dirt/Models/Data/EventIntegrations/IntegrationOAuthStateTests.cs similarity index 94% rename from test/Core.Test/AdminConsole/Models/Data/EventIntegrations/IntegrationOAuthStateTests.cs rename to test/Core.Test/Dirt/Models/Data/EventIntegrations/IntegrationOAuthStateTests.cs index 8605a3dcab..a3e05ffe37 100644 --- a/test/Core.Test/AdminConsole/Models/Data/EventIntegrations/IntegrationOAuthStateTests.cs +++ b/test/Core.Test/Dirt/Models/Data/EventIntegrations/IntegrationOAuthStateTests.cs @@ -1,12 +1,12 @@ #nullable enable -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Models.Data.EventIntegrations; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.Extensions.Time.Testing; using Xunit; -namespace Bit.Core.Test.AdminConsole.Models.Data.EventIntegrations; +namespace Bit.Core.Test.Dirt.Models.Data.EventIntegrations; public class IntegrationOAuthStateTests { diff --git a/test/Core.Test/AdminConsole/Models/Data/EventIntegrations/IntegrationTemplateContextTests.cs b/test/Core.Test/Dirt/Models/Data/EventIntegrations/IntegrationTemplateContextTests.cs similarity index 55% rename from test/Core.Test/AdminConsole/Models/Data/EventIntegrations/IntegrationTemplateContextTests.cs rename to test/Core.Test/Dirt/Models/Data/EventIntegrations/IntegrationTemplateContextTests.cs index 930b04121c..7bacb4046b 100644 --- a/test/Core.Test/AdminConsole/Models/Data/EventIntegrations/IntegrationTemplateContextTests.cs +++ b/test/Core.Test/Dirt/Models/Data/EventIntegrations/IntegrationTemplateContextTests.cs @@ -1,13 +1,13 @@ #nullable enable using System.Text.Json; using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; -using Bit.Core.Entities; +using Bit.Core.Dirt.Models.Data.EventIntegrations; using Bit.Core.Models.Data; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; -namespace Bit.Core.Test.AdminConsole.Models.Data.EventIntegrations; +namespace Bit.Core.Test.Dirt.Models.Data.EventIntegrations; public class IntegrationTemplateContextTests { @@ -21,7 +21,21 @@ public class IntegrationTemplateContextTests } [Theory, BitAutoData] - public void UserName_WhenUserIsSet_ReturnsName(EventMessage eventMessage, User user) + public void DateIso8601_ReturnsIso8601FormattedDate(EventMessage eventMessage) + { + var testDate = new DateTime(2025, 10, 27, 13, 30, 0, DateTimeKind.Utc); + eventMessage.Date = testDate; + var sut = new IntegrationTemplateContext(eventMessage); + + var result = sut.DateIso8601; + + Assert.Equal("2025-10-27T13:30:00.0000000Z", result); + // Verify it's valid ISO 8601 + Assert.True(DateTime.TryParse(result, out _)); + } + + [Theory, BitAutoData] + public void UserName_WhenUserIsSet_ReturnsName(EventMessage eventMessage, OrganizationUserUserDetails user) { var sut = new IntegrationTemplateContext(eventMessage) { User = user }; @@ -37,7 +51,7 @@ public class IntegrationTemplateContextTests } [Theory, BitAutoData] - public void UserEmail_WhenUserIsSet_ReturnsEmail(EventMessage eventMessage, User user) + public void UserEmail_WhenUserIsSet_ReturnsEmail(EventMessage eventMessage, OrganizationUserUserDetails user) { var sut = new IntegrationTemplateContext(eventMessage) { User = user }; @@ -53,7 +67,23 @@ public class IntegrationTemplateContextTests } [Theory, BitAutoData] - public void ActingUserName_WhenActingUserIsSet_ReturnsName(EventMessage eventMessage, User actingUser) + public void UserType_WhenUserIsSet_ReturnsType(EventMessage eventMessage, OrganizationUserUserDetails user) + { + var sut = new IntegrationTemplateContext(eventMessage) { User = user }; + + Assert.Equal(user.Type, sut.UserType); + } + + [Theory, BitAutoData] + public void UserType_WhenUserIsNull_ReturnsNull(EventMessage eventMessage) + { + var sut = new IntegrationTemplateContext(eventMessage) { User = null }; + + Assert.Null(sut.UserType); + } + + [Theory, BitAutoData] + public void ActingUserName_WhenActingUserIsSet_ReturnsName(EventMessage eventMessage, OrganizationUserUserDetails actingUser) { var sut = new IntegrationTemplateContext(eventMessage) { ActingUser = actingUser }; @@ -69,7 +99,7 @@ public class IntegrationTemplateContextTests } [Theory, BitAutoData] - public void ActingUserEmail_WhenActingUserIsSet_ReturnsEmail(EventMessage eventMessage, User actingUser) + public void ActingUserEmail_WhenActingUserIsSet_ReturnsEmail(EventMessage eventMessage, OrganizationUserUserDetails actingUser) { var sut = new IntegrationTemplateContext(eventMessage) { ActingUser = actingUser }; @@ -84,6 +114,22 @@ public class IntegrationTemplateContextTests Assert.Null(sut.ActingUserEmail); } + [Theory, BitAutoData] + public void ActingUserType_WhenActingUserIsSet_ReturnsType(EventMessage eventMessage, OrganizationUserUserDetails actingUser) + { + var sut = new IntegrationTemplateContext(eventMessage) { ActingUser = actingUser }; + + Assert.Equal(actingUser.Type, sut.ActingUserType); + } + + [Theory, BitAutoData] + public void ActingUserType_WhenActingUserIsNull_ReturnsNull(EventMessage eventMessage) + { + var sut = new IntegrationTemplateContext(eventMessage) { ActingUser = null }; + + Assert.Null(sut.ActingUserType); + } + [Theory, BitAutoData] public void OrganizationName_WhenOrganizationIsSet_ReturnsDisplayName(EventMessage eventMessage, Organization organization) { @@ -99,4 +145,20 @@ public class IntegrationTemplateContextTests Assert.Null(sut.OrganizationName); } + + [Theory, BitAutoData] + public void GroupName_WhenGroupIsSet_ReturnsName(EventMessage eventMessage, Group group) + { + var sut = new IntegrationTemplateContext(eventMessage) { Group = group }; + + Assert.Equal(group.Name, sut.GroupName); + } + + [Theory, BitAutoData] + public void GroupName_WhenGroupIsNull_ReturnsNull(EventMessage eventMessage) + { + var sut = new IntegrationTemplateContext(eventMessage) { Group = null }; + + Assert.Null(sut.GroupName); + } } diff --git a/test/Core.Test/AdminConsole/Models/Data/Organizations/OrganizationIntegrationConfigurationDetailsTests.cs b/test/Core.Test/Dirt/Models/Data/EventIntegrations/OrganizationIntegrationConfigurationDetailsTests.cs similarity index 97% rename from test/Core.Test/AdminConsole/Models/Data/Organizations/OrganizationIntegrationConfigurationDetailsTests.cs rename to test/Core.Test/Dirt/Models/Data/EventIntegrations/OrganizationIntegrationConfigurationDetailsTests.cs index 4b8cd4f47c..ae574d7ee6 100644 --- a/test/Core.Test/AdminConsole/Models/Data/Organizations/OrganizationIntegrationConfigurationDetailsTests.cs +++ b/test/Core.Test/Dirt/Models/Data/EventIntegrations/OrganizationIntegrationConfigurationDetailsTests.cs @@ -1,8 +1,8 @@ using System.Text.Json; -using Bit.Core.Models.Data.Organizations; +using Bit.Core.Dirt.Models.Data.EventIntegrations; using Xunit; -namespace Bit.Core.Test.Models.Data.Organizations; +namespace Bit.Core.Test.Dirt.Models.Data.EventIntegrations; public class OrganizationIntegrationConfigurationDetailsTests { diff --git a/test/Core.Test/AdminConsole/Models/Data/EventIntegrations/TestListenerConfiguration.cs b/test/Core.Test/Dirt/Models/Data/EventIntegrations/TestListenerConfiguration.cs similarity index 79% rename from test/Core.Test/AdminConsole/Models/Data/EventIntegrations/TestListenerConfiguration.cs rename to test/Core.Test/Dirt/Models/Data/EventIntegrations/TestListenerConfiguration.cs index 916fe981de..2c811e06f5 100644 --- a/test/Core.Test/AdminConsole/Models/Data/EventIntegrations/TestListenerConfiguration.cs +++ b/test/Core.Test/Dirt/Models/Data/EventIntegrations/TestListenerConfiguration.cs @@ -1,6 +1,7 @@ -using Bit.Core.Enums; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.Models.Data.EventIntegrations; -namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; +namespace Bit.Core.Test.Dirt.Models.Data.EventIntegrations; public class TestListenerConfiguration : IIntegrationListenerConfiguration { @@ -17,4 +18,5 @@ public class TestListenerConfiguration : IIntegrationListenerConfiguration public int EventPrefetchCount => 0; public int IntegrationMaxConcurrentCalls => 1; public int IntegrationPrefetchCount => 0; + public string RoutingKey => IntegrationType.ToRoutingKey(); } diff --git a/test/Core.Test/Dirt/Models/Data/Teams/TeamsBotCredentialProviderTests.cs b/test/Core.Test/Dirt/Models/Data/Teams/TeamsBotCredentialProviderTests.cs new file mode 100644 index 0000000000..24576899d5 --- /dev/null +++ b/test/Core.Test/Dirt/Models/Data/Teams/TeamsBotCredentialProviderTests.cs @@ -0,0 +1,56 @@ +using Bit.Core.Dirt.Models.Data.Teams; +using Microsoft.Bot.Connector.Authentication; +using Xunit; + +namespace Bit.Core.Test.Dirt.Models.Data.Teams; + +public class TeamsBotCredentialProviderTests +{ + private string _clientId = "client id"; + private string _clientSecret = "client secret"; + + [Fact] + public async Task IsValidAppId_MustMatchClientId() + { + var sut = new TeamsBotCredentialProvider(_clientId, _clientSecret); + + Assert.True(await sut.IsValidAppIdAsync(_clientId)); + Assert.False(await sut.IsValidAppIdAsync("Different id")); + } + + [Fact] + public async Task GetAppPasswordAsync_MatchingClientId_ReturnsClientSecret() + { + var sut = new TeamsBotCredentialProvider(_clientId, _clientSecret); + var password = await sut.GetAppPasswordAsync(_clientId); + Assert.Equal(_clientSecret, password); + } + + [Fact] + public async Task GetAppPasswordAsync_NotMatchingClientId_ReturnsNull() + { + var sut = new TeamsBotCredentialProvider(_clientId, _clientSecret); + Assert.Null(await sut.GetAppPasswordAsync("Different id")); + } + + [Fact] + public async Task IsAuthenticationDisabledAsync_ReturnsFalse() + { + var sut = new TeamsBotCredentialProvider(_clientId, _clientSecret); + Assert.False(await sut.IsAuthenticationDisabledAsync()); + } + + [Fact] + public async Task ValidateIssuerAsync_ExpectedIssuer_ReturnsTrue() + { + var sut = new TeamsBotCredentialProvider(_clientId, _clientSecret); + Assert.True(await sut.ValidateIssuerAsync(AuthenticationConstants.ToBotFromChannelTokenIssuer)); + } + + [Fact] + public async Task ValidateIssuerAsync_UnexpectedIssuer_ReturnsFalse() + { + var sut = new TeamsBotCredentialProvider(_clientId, _clientSecret); + Assert.False(await sut.ValidateIssuerAsync("unexpected issuer")); + } +} diff --git a/test/Core.Test/AdminConsole/Services/AzureQueueEventWriteServiceTests.cs b/test/Core.Test/Dirt/Services/AzureQueueEventWriteServiceTests.cs similarity index 100% rename from test/Core.Test/AdminConsole/Services/AzureQueueEventWriteServiceTests.cs rename to test/Core.Test/Dirt/Services/AzureQueueEventWriteServiceTests.cs diff --git a/test/Core.Test/AdminConsole/Services/AzureServiceBusEventListenerServiceTests.cs b/test/Core.Test/Dirt/Services/AzureServiceBusEventListenerServiceTests.cs similarity index 96% rename from test/Core.Test/AdminConsole/Services/AzureServiceBusEventListenerServiceTests.cs rename to test/Core.Test/Dirt/Services/AzureServiceBusEventListenerServiceTests.cs index c6ef3063e2..92f0b16b3f 100644 --- a/test/Core.Test/AdminConsole/Services/AzureServiceBusEventListenerServiceTests.cs +++ b/test/Core.Test/Dirt/Services/AzureServiceBusEventListenerServiceTests.cs @@ -2,9 +2,10 @@ using System.Text.Json; using Azure.Messaging.ServiceBus; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Services; +using Bit.Core.Dirt.Services.Implementations; using Bit.Core.Models.Data; -using Bit.Core.Services; +using Bit.Core.Test.Dirt.Models.Data.EventIntegrations; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.Helpers; @@ -12,7 +13,7 @@ using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; +namespace Bit.Core.Test.Dirt.Services; [SutProviderCustomize] public class AzureServiceBusEventListenerServiceTests diff --git a/test/Core.Test/AdminConsole/Services/AzureServiceBusIntegrationListenerServiceTests.cs b/test/Core.Test/Dirt/Services/AzureServiceBusIntegrationListenerServiceTests.cs similarity index 78% rename from test/Core.Test/AdminConsole/Services/AzureServiceBusIntegrationListenerServiceTests.cs rename to test/Core.Test/Dirt/Services/AzureServiceBusIntegrationListenerServiceTests.cs index 23627f3962..88688f49ff 100644 --- a/test/Core.Test/AdminConsole/Services/AzureServiceBusIntegrationListenerServiceTests.cs +++ b/test/Core.Test/Dirt/Services/AzureServiceBusIntegrationListenerServiceTests.cs @@ -2,8 +2,10 @@ using System.Text.Json; using Azure.Messaging.ServiceBus; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; -using Bit.Core.Services; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Services; +using Bit.Core.Dirt.Services.Implementations; +using Bit.Core.Test.Dirt.Models.Data.EventIntegrations; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.Extensions.Logging; @@ -11,7 +13,7 @@ using NSubstitute; using NSubstitute.ExceptionExtensions; using Xunit; -namespace Bit.Core.Test.Services; +namespace Bit.Core.Test.Dirt.Services; [SutProviderCustomize] public class AzureServiceBusIntegrationListenerServiceTests @@ -78,8 +80,10 @@ public class AzureServiceBusIntegrationListenerServiceTests var sutProvider = GetSutProvider(); message.RetryCount = 0; - var result = new IntegrationHandlerResult(false, message); - result.Retryable = false; + var result = IntegrationHandlerResult.Fail( + message: message, + category: IntegrationFailureCategory.AuthenticationFailed, // NOT retryable + failureReason: "403"); _handler.HandleAsync(Arg.Any()).Returns(result); var expected = IntegrationMessage.FromJson(message.ToJson()); @@ -89,6 +93,12 @@ public class AzureServiceBusIntegrationListenerServiceTests await _handler.Received(1).HandleAsync(Arg.Is(expected.ToJson())); await _serviceBusService.DidNotReceiveWithAnyArgs().PublishToRetryAsync(Arg.Any()); + _logger.Received().Log( + LogLevel.Warning, + Arg.Any(), + Arg.Is(o => (o.ToString() ?? "").Contains("Integration failure - non-recoverable error or max retries exceeded.")), + Arg.Any(), + Arg.Any>()); } [Theory, BitAutoData] @@ -96,9 +106,10 @@ public class AzureServiceBusIntegrationListenerServiceTests { var sutProvider = GetSutProvider(); message.RetryCount = _config.MaxRetries; - var result = new IntegrationHandlerResult(false, message); - result.Retryable = true; - + var result = IntegrationHandlerResult.Fail( + message: message, + category: IntegrationFailureCategory.TransientError, // Retryable + failureReason: "403"); _handler.HandleAsync(Arg.Any()).Returns(result); var expected = IntegrationMessage.FromJson(message.ToJson()); @@ -108,6 +119,12 @@ public class AzureServiceBusIntegrationListenerServiceTests await _handler.Received(1).HandleAsync(Arg.Is(expected.ToJson())); await _serviceBusService.DidNotReceiveWithAnyArgs().PublishToRetryAsync(Arg.Any()); + _logger.Received().Log( + LogLevel.Warning, + Arg.Any(), + Arg.Is(o => (o.ToString() ?? "").Contains("Integration failure - non-recoverable error or max retries exceeded.")), + Arg.Any(), + Arg.Any>()); } [Theory, BitAutoData] @@ -116,8 +133,10 @@ public class AzureServiceBusIntegrationListenerServiceTests var sutProvider = GetSutProvider(); message.RetryCount = 0; - var result = new IntegrationHandlerResult(false, message); - result.Retryable = true; + var result = IntegrationHandlerResult.Fail( + message: message, + category: IntegrationFailureCategory.TransientError, // Retryable + failureReason: "403"); _handler.HandleAsync(Arg.Any()).Returns(result); var expected = IntegrationMessage.FromJson(message.ToJson()); @@ -133,7 +152,7 @@ public class AzureServiceBusIntegrationListenerServiceTests public async Task HandleMessageAsync_SuccessfulResult_Succeeds(IntegrationMessage message) { var sutProvider = GetSutProvider(); - var result = new IntegrationHandlerResult(true, message); + var result = IntegrationHandlerResult.Succeed(message); _handler.HandleAsync(Arg.Any()).Returns(result); var expected = IntegrationMessage.FromJson(message.ToJson()); @@ -156,7 +175,7 @@ public class AzureServiceBusIntegrationListenerServiceTests _logger.Received(1).Log( LogLevel.Error, Arg.Any(), - Arg.Any(), + Arg.Is(o => (o.ToString() ?? "").Contains("Unhandled error processing ASB message")), Arg.Any(), Arg.Any>()); diff --git a/test/Core.Test/AdminConsole/Services/DatadogIntegrationHandlerTests.cs b/test/Core.Test/Dirt/Services/DatadogIntegrationHandlerTests.cs similarity index 97% rename from test/Core.Test/AdminConsole/Services/DatadogIntegrationHandlerTests.cs rename to test/Core.Test/Dirt/Services/DatadogIntegrationHandlerTests.cs index 5f0a9915bf..a8c5d7da95 100644 --- a/test/Core.Test/AdminConsole/Services/DatadogIntegrationHandlerTests.cs +++ b/test/Core.Test/Dirt/Services/DatadogIntegrationHandlerTests.cs @@ -1,8 +1,8 @@ #nullable enable using System.Net; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; -using Bit.Core.Services; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Services.Implementations; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.Helpers; @@ -11,7 +11,7 @@ using Microsoft.Extensions.Time.Testing; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; +namespace Bit.Core.Test.Dirt.Services; [SutProviderCustomize] public class DatadogIntegrationHandlerTests @@ -51,7 +51,7 @@ public class DatadogIntegrationHandlerTests Assert.True(result.Success); Assert.Equal(result.Message, message); - Assert.Empty(result.FailureReason); + Assert.Null(result.FailureReason); sutProvider.GetDependency().Received(1).CreateClient( Arg.Is(AssertHelper.AssertPropertyEqual(DatadogIntegrationHandler.HttpClientName)) diff --git a/test/Core.Test/AdminConsole/Services/EventIntegrationEventWriteServiceTests.cs b/test/Core.Test/Dirt/Services/EventIntegrationEventWriteServiceTests.cs similarity index 64% rename from test/Core.Test/AdminConsole/Services/EventIntegrationEventWriteServiceTests.cs rename to test/Core.Test/Dirt/Services/EventIntegrationEventWriteServiceTests.cs index 9369690d86..3870601604 100644 --- a/test/Core.Test/AdminConsole/Services/EventIntegrationEventWriteServiceTests.cs +++ b/test/Core.Test/Dirt/Services/EventIntegrationEventWriteServiceTests.cs @@ -1,12 +1,13 @@ using System.Text.Json; +using Bit.Core.Dirt.Services; +using Bit.Core.Dirt.Services.Implementations; using Bit.Core.Models.Data; -using Bit.Core.Services; using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.Helpers; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; +namespace Bit.Core.Test.Dirt.Services; [SutProviderCustomize] public class EventIntegrationEventWriteServiceTests @@ -22,18 +23,34 @@ public class EventIntegrationEventWriteServiceTests [Theory, BitAutoData] public async Task CreateAsync_EventPublishedToEventQueue(EventMessage eventMessage) { - var expected = JsonSerializer.Serialize(eventMessage); await Subject.CreateAsync(eventMessage); await _eventIntegrationPublisher.Received(1).PublishEventAsync( - Arg.Is(body => AssertJsonStringsMatch(eventMessage, body))); + body: Arg.Is(body => AssertJsonStringsMatch(eventMessage, body)), + organizationId: Arg.Is(orgId => eventMessage.OrganizationId.ToString().Equals(orgId))); } [Theory, BitAutoData] public async Task CreateManyAsync_EventsPublishedToEventQueue(IEnumerable eventMessages) { + var eventMessage = eventMessages.First(); await Subject.CreateManyAsync(eventMessages); await _eventIntegrationPublisher.Received(1).PublishEventAsync( - Arg.Is(body => AssertJsonStringsMatch(eventMessages, body))); + body: Arg.Is(body => AssertJsonStringsMatch(eventMessages, body)), + organizationId: Arg.Is(orgId => eventMessage.OrganizationId.ToString().Equals(orgId))); + } + + [Fact] + public async Task CreateManyAsync_EmptyList_DoesNothing() + { + await Subject.CreateManyAsync([]); + await _eventIntegrationPublisher.DidNotReceiveWithAnyArgs().PublishEventAsync(Arg.Any(), Arg.Any()); + } + + [Fact] + public async Task DisposeAsync_DisposesEventIntegrationPublisher() + { + await Subject.DisposeAsync(); + await _eventIntegrationPublisher.Received(1).DisposeAsync(); } private static bool AssertJsonStringsMatch(EventMessage expected, string body) diff --git a/test/Core.Test/Dirt/Services/EventIntegrationHandlerTests.cs b/test/Core.Test/Dirt/Services/EventIntegrationHandlerTests.cs new file mode 100644 index 0000000000..e15a254b39 --- /dev/null +++ b/test/Core.Test/Dirt/Services/EventIntegrationHandlerTests.cs @@ -0,0 +1,711 @@ +#nullable enable + +using System.Text.Json; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Repositories; +using Bit.Core.Dirt.Services; +using Bit.Core.Dirt.Services.Implementations; +using Bit.Core.Models.Data; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; +using Bit.Core.Repositories; +using Bit.Core.Utilities; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Bit.Test.Common.Helpers; +using Microsoft.Extensions.Logging; +using NSubstitute; +using Xunit; +using ZiggyCreatures.Caching.Fusion; + +namespace Bit.Core.Test.Dirt.Services; + +[SutProviderCustomize] +public class EventIntegrationHandlerTests +{ + private const string _templateBase = "Date: #Date#, Type: #Type#, UserId: #UserId#"; + private const string _templateWithGroup = "Group: #GroupName#"; + private const string _templateWithOrganization = "Org: #OrganizationName#"; + private const string _templateWithUser = "#UserName#, #UserEmail#, #UserType#"; + private const string _templateWithActingUser = "#ActingUserName#, #ActingUserEmail#, #ActingUserType#"; + private static readonly Guid _organizationId = Guid.NewGuid(); + private static readonly Uri _uri = new Uri("https://localhost"); + private static readonly Uri _uri2 = new Uri("https://example.com"); + private readonly IEventIntegrationPublisher _eventIntegrationPublisher = Substitute.For(); + private readonly ILogger> _logger = + Substitute.For>>(); + + private SutProvider> GetSutProvider( + List configurations) + { + var cache = Substitute.For(); + cache.GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any>>>(), + options: Arg.Any(), + tags: Arg.Any>() + ).Returns(configurations); + + return new SutProvider>() + .SetDependency(cache) + .SetDependency(_eventIntegrationPublisher) + .SetDependency(IntegrationType.Webhook) + .SetDependency(_logger) + .Create(); + } + + private static IntegrationMessage ExpectedMessage(string template) + { + return new IntegrationMessage() + { + IntegrationType = IntegrationType.Webhook, + MessageId = "TestMessageId", + OrganizationId = _organizationId.ToString(), + Configuration = new WebhookIntegrationConfigurationDetails(_uri), + RenderedTemplate = template, + RetryCount = 0, + DelayUntilDate = null + }; + } + + private static List NoConfigurations() + { + return []; + } + + private static List OneConfiguration(string template) + { + var config = Substitute.For(); + config.Configuration = null; + config.IntegrationConfiguration = JsonSerializer.Serialize(new { Uri = _uri }); + config.Template = template; + + return [config]; + } + + private static List TwoConfigurations(string template) + { + var config = Substitute.For(); + config.Configuration = null; + config.IntegrationConfiguration = JsonSerializer.Serialize(new { Uri = _uri }); + config.Template = template; + var config2 = Substitute.For(); + config2.Configuration = null; + config2.IntegrationConfiguration = JsonSerializer.Serialize(new { Uri = _uri2 }); + config2.Template = template; + + return [config, config2]; + } + + private static List InvalidFilterConfiguration() + { + var config = Substitute.For(); + config.Configuration = null; + config.IntegrationConfiguration = JsonSerializer.Serialize(new { Uri = _uri }); + config.Template = _templateBase; + config.Filters = "Invalid Configuration!"; + + return [config]; + } + + private static List ValidFilterConfiguration() + { + var config = Substitute.For(); + config.Configuration = null; + config.IntegrationConfiguration = JsonSerializer.Serialize(new { Uri = _uri }); + config.Template = _templateBase; + config.Filters = JsonSerializer.Serialize(new IntegrationFilterGroup()); + + return [config]; + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_ActingUserIdPresent_UsesCache(EventMessage eventMessage, OrganizationUserUserDetails actingUser) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithActingUser)); + var cache = sutProvider.GetDependency(); + + eventMessage.OrganizationId ??= Guid.NewGuid(); + eventMessage.ActingUserId ??= Guid.NewGuid(); + + cache.GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ).Returns(actingUser); + + var context = await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithActingUser); + + await cache.Received(1).GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + + Assert.Equal(actingUser, context.ActingUser); + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_ActingUserIdNull_SkipsCache(EventMessage eventMessage) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithActingUser)); + var cache = sutProvider.GetDependency(); + + eventMessage.OrganizationId ??= Guid.NewGuid(); + eventMessage.ActingUserId = null; + + var context = await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithActingUser); + + await cache.DidNotReceive().GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + Assert.Null(context.ActingUser); + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_ActingUserOrganizationIdNull_SkipsCache(EventMessage eventMessage) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithActingUser)); + var cache = sutProvider.GetDependency(); + + eventMessage.OrganizationId = null; + eventMessage.ActingUserId ??= Guid.NewGuid(); + + var context = await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithActingUser); + + await cache.DidNotReceive().GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + Assert.Null(context.ActingUser); + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_ActingUserFactory_CallsOrganizationUserRepository(EventMessage eventMessage, OrganizationUserUserDetails actingUser) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithActingUser)); + var cache = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + + eventMessage.OrganizationId ??= Guid.NewGuid(); + eventMessage.ActingUserId ??= Guid.NewGuid(); + organizationUserRepository.GetDetailsByOrganizationIdUserIdAsync( + eventMessage.OrganizationId.Value, + eventMessage.ActingUserId.Value).Returns(actingUser); + + // Capture the factory function passed to the cache + Func, CancellationToken, Task>? capturedFactory = null; + cache.GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Do, CancellationToken, Task>>(f => capturedFactory = f) + ).Returns(actingUser); + + await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithActingUser); + + Assert.NotNull(capturedFactory); + var result = await capturedFactory(null!, CancellationToken.None); + + await organizationUserRepository.Received(1).GetDetailsByOrganizationIdUserIdAsync( + eventMessage.OrganizationId.Value, + eventMessage.ActingUserId.Value); + Assert.Equal(actingUser, result); + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_GroupIdPresent_UsesCache(EventMessage eventMessage, Group group) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithGroup)); + var cache = sutProvider.GetDependency(); + + eventMessage.GroupId ??= Guid.NewGuid(); + + cache.GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ).Returns(group); + + var context = await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithGroup); + + await cache.Received(1).GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + Assert.Equal(group, context.Group); + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_GroupIdNull_SkipsCache(EventMessage eventMessage) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithGroup)); + var cache = sutProvider.GetDependency(); + eventMessage.GroupId = null; + + var context = await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithGroup); + + await cache.DidNotReceive().GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + Assert.Null(context.Group); + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_GroupFactory_CallsGroupRepository(EventMessage eventMessage, Group group) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithGroup)); + var cache = sutProvider.GetDependency(); + var groupRepository = sutProvider.GetDependency(); + + eventMessage.GroupId ??= Guid.NewGuid(); + groupRepository.GetByIdAsync(eventMessage.GroupId.Value).Returns(group); + + // Capture the factory function passed to the cache + Func, CancellationToken, Task>? capturedFactory = null; + cache.GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Do, CancellationToken, Task>>(f => capturedFactory = f) + ).Returns(group); + + await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithGroup); + + Assert.NotNull(capturedFactory); + var result = await capturedFactory(null!, CancellationToken.None); + + await groupRepository.Received(1).GetByIdAsync(eventMessage.GroupId.Value); + Assert.Equal(group, result); + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_OrganizationIdPresent_UsesCache(EventMessage eventMessage, Organization organization) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithOrganization)); + var cache = sutProvider.GetDependency(); + + eventMessage.OrganizationId ??= Guid.NewGuid(); + + cache.GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ).Returns(organization); + + var context = await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithOrganization); + + await cache.Received(1).GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + Assert.Equal(organization, context.Organization); + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_OrganizationIdNull_SkipsCache(EventMessage eventMessage) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithOrganization)); + var cache = sutProvider.GetDependency(); + + eventMessage.OrganizationId = null; + + var context = await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithOrganization); + + await cache.DidNotReceive().GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + Assert.Null(context.Organization); + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_OrganizationFactory_CallsOrganizationRepository(EventMessage eventMessage, Organization organization) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithOrganization)); + var cache = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + + eventMessage.OrganizationId ??= Guid.NewGuid(); + organizationRepository.GetByIdAsync(eventMessage.OrganizationId.Value).Returns(organization); + + // Capture the factory function passed to the cache + Func, CancellationToken, Task>? capturedFactory = null; + cache.GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Do, CancellationToken, Task>>(f => capturedFactory = f) + ).Returns(organization); + + await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithOrganization); + + Assert.NotNull(capturedFactory); + var result = await capturedFactory(null!, CancellationToken.None); + + await organizationRepository.Received(1).GetByIdAsync(eventMessage.OrganizationId.Value); + Assert.Equal(organization, result); + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_UserIdPresent_UsesCache(EventMessage eventMessage, OrganizationUserUserDetails userDetails) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithUser)); + var cache = sutProvider.GetDependency(); + + eventMessage.OrganizationId ??= Guid.NewGuid(); + eventMessage.UserId ??= Guid.NewGuid(); + + cache.GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ).Returns(userDetails); + + var context = await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithUser); + + await cache.Received(1).GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + + Assert.Equal(userDetails, context.User); + } + + + [Theory, BitAutoData] + public async Task BuildContextAsync_UserIdNull_SkipsCache(EventMessage eventMessage) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithUser)); + var cache = sutProvider.GetDependency(); + + eventMessage.OrganizationId = null; + eventMessage.UserId ??= Guid.NewGuid(); + + var context = await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithUser); + + await cache.DidNotReceive().GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + + Assert.Null(context.User); + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_OrganizationUserIdNull_SkipsCache(EventMessage eventMessage) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithUser)); + var cache = sutProvider.GetDependency(); + + eventMessage.OrganizationId ??= Guid.NewGuid(); + eventMessage.UserId = null; + + var context = await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithUser); + + await cache.DidNotReceive().GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + + Assert.Null(context.User); + } + + + [Theory, BitAutoData] + public async Task BuildContextAsync_UserFactory_CallsOrganizationUserRepository(EventMessage eventMessage, OrganizationUserUserDetails userDetails) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithUser)); + var cache = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + + eventMessage.OrganizationId ??= Guid.NewGuid(); + eventMessage.UserId ??= Guid.NewGuid(); + organizationUserRepository.GetDetailsByOrganizationIdUserIdAsync( + eventMessage.OrganizationId.Value, + eventMessage.UserId.Value).Returns(userDetails); + + // Capture the factory function passed to the cache + Func, CancellationToken, Task>? capturedFactory = null; + cache.GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Do, CancellationToken, Task>>(f => capturedFactory = f) + ).Returns(userDetails); + + await sutProvider.Sut.BuildContextAsync(eventMessage, _templateWithUser); + + Assert.NotNull(capturedFactory); + var result = await capturedFactory(null!, CancellationToken.None); + + await organizationUserRepository.Received(1).GetDetailsByOrganizationIdUserIdAsync( + eventMessage.OrganizationId.Value, + eventMessage.UserId.Value); + Assert.Equal(userDetails, result); + } + + [Theory, BitAutoData] + public async Task BuildContextAsync_NoSpecialTokens_DoesNotCallAnyCache(EventMessage eventMessage) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateWithUser)); + var cache = sutProvider.GetDependency(); + + eventMessage.ActingUserId ??= Guid.NewGuid(); + eventMessage.GroupId ??= Guid.NewGuid(); + eventMessage.OrganizationId ??= Guid.NewGuid(); + eventMessage.UserId ??= Guid.NewGuid(); + + await sutProvider.Sut.BuildContextAsync(eventMessage, _templateBase); + + await cache.DidNotReceive().GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + await cache.DidNotReceive().GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + await cache.DidNotReceive().GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any, CancellationToken, Task>>() + ); + } + + [Theory, BitAutoData] + public async Task HandleEventAsync_BaseTemplateNoConfigurations_DoesNothing(EventMessage eventMessage) + { + var sutProvider = GetSutProvider(NoConfigurations()); + var cache = sutProvider.GetDependency(); + cache.GetOrSetAsync>( + Arg.Any(), + Arg.Any>>>(), + Arg.Any() + ).Returns(NoConfigurations()); + + await sutProvider.Sut.HandleEventAsync(eventMessage); + Assert.Empty(_eventIntegrationPublisher.ReceivedCalls()); + } + + [Theory, BitAutoData] + public async Task HandleEventAsync_NoOrganizationId_DoesNothing(EventMessage eventMessage) + { + var sutProvider = GetSutProvider(OneConfiguration(_templateBase)); + eventMessage.OrganizationId = null; + + await sutProvider.Sut.HandleEventAsync(eventMessage); + Assert.Empty(_eventIntegrationPublisher.ReceivedCalls()); + } + + [Theory, BitAutoData] + public async Task HandleEventAsync_BaseTemplateOneConfiguration_PublishesIntegrationMessage(EventMessage eventMessage) + { + eventMessage.OrganizationId = _organizationId; + var sutProvider = GetSutProvider(OneConfiguration(_templateBase)); + + await sutProvider.Sut.HandleEventAsync(eventMessage); + + var expectedMessage = EventIntegrationHandlerTests.ExpectedMessage( + $"Date: {eventMessage.Date}, Type: {eventMessage.Type}, UserId: {eventMessage.UserId}" + ); + + Assert.Single(_eventIntegrationPublisher.ReceivedCalls()); + await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( + AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" }))); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetByIdAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetByIdAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetDetailsByOrganizationIdUserIdAsync(Arg.Any(), Arg.Any()); + } + + [Theory, BitAutoData] + public async Task HandleEventAsync_BaseTemplateTwoConfigurations_PublishesIntegrationMessages(EventMessage eventMessage) + { + eventMessage.OrganizationId = _organizationId; + var sutProvider = GetSutProvider(TwoConfigurations(_templateBase)); + + await sutProvider.Sut.HandleEventAsync(eventMessage); + + var expectedMessage = EventIntegrationHandlerTests.ExpectedMessage( + $"Date: {eventMessage.Date}, Type: {eventMessage.Type}, UserId: {eventMessage.UserId}" + ); + await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( + AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" }))); + + expectedMessage.Configuration = new WebhookIntegrationConfigurationDetails(_uri2); + await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( + AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" }))); + + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetByIdAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetByIdAsync(Arg.Any()); + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetDetailsByOrganizationIdUserIdAsync(Arg.Any(), Arg.Any()); + } + + [Theory, BitAutoData] + public async Task HandleEventAsync_FilterReturnsFalse_DoesNothing(EventMessage eventMessage) + { + eventMessage.OrganizationId = _organizationId; + var sutProvider = GetSutProvider(ValidFilterConfiguration()); + sutProvider.GetDependency().EvaluateFilterGroup( + Arg.Any(), Arg.Any()).Returns(false); + + await sutProvider.Sut.HandleEventAsync(eventMessage); + Assert.Empty(_eventIntegrationPublisher.ReceivedCalls()); + } + + [Theory, BitAutoData] + public async Task HandleEventAsync_FilterReturnsTrue_PublishesIntegrationMessage(EventMessage eventMessage) + { + eventMessage.OrganizationId = _organizationId; + var sutProvider = GetSutProvider(ValidFilterConfiguration()); + sutProvider.GetDependency().EvaluateFilterGroup( + Arg.Any(), Arg.Any()).Returns(true); + + await sutProvider.Sut.HandleEventAsync(eventMessage); + + var expectedMessage = EventIntegrationHandlerTests.ExpectedMessage( + $"Date: {eventMessage.Date}, Type: {eventMessage.Type}, UserId: {eventMessage.UserId}" + ); + + Assert.Single(_eventIntegrationPublisher.ReceivedCalls()); + await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( + AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" }))); + } + + [Theory, BitAutoData] + public async Task HandleEventAsync_InvalidFilter_LogsErrorDoesNothing(EventMessage eventMessage) + { + eventMessage.OrganizationId = _organizationId; + var sutProvider = GetSutProvider(InvalidFilterConfiguration()); + + await sutProvider.Sut.HandleEventAsync(eventMessage); + Assert.Empty(_eventIntegrationPublisher.ReceivedCalls()); + _logger.Received(1).Log( + LogLevel.Error, + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any>()); + } + + [Theory, BitAutoData] + public async Task HandleManyEventsAsync_BaseTemplateNoConfigurations_DoesNothing(List eventMessages) + { + eventMessages.ForEach(e => e.OrganizationId = _organizationId); + var sutProvider = GetSutProvider(NoConfigurations()); + + await sutProvider.Sut.HandleManyEventsAsync(eventMessages); + Assert.Empty(_eventIntegrationPublisher.ReceivedCalls()); + } + + [Theory, BitAutoData] + public async Task HandleManyEventsAsync_BaseTemplateOneConfiguration_PublishesIntegrationMessages(List eventMessages) + { + eventMessages.ForEach(e => e.OrganizationId = _organizationId); + var sutProvider = GetSutProvider(OneConfiguration(_templateBase)); + + await sutProvider.Sut.HandleManyEventsAsync(eventMessages); + + foreach (var eventMessage in eventMessages) + { + var expectedMessage = ExpectedMessage( + $"Date: {eventMessage.Date}, Type: {eventMessage.Type}, UserId: {eventMessage.UserId}" + ); + await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( + AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId", "OrganizationId" }))); + } + } + + [Theory, BitAutoData] + public async Task HandleManyEventsAsync_BaseTemplateTwoConfigurations_PublishesIntegrationMessages( + List eventMessages) + { + eventMessages.ForEach(e => e.OrganizationId = _organizationId); + var sutProvider = GetSutProvider(TwoConfigurations(_templateBase)); + + await sutProvider.Sut.HandleManyEventsAsync(eventMessages); + + foreach (var eventMessage in eventMessages) + { + var expectedMessage = ExpectedMessage( + $"Date: {eventMessage.Date}, Type: {eventMessage.Type}, UserId: {eventMessage.UserId}" + ); + await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is(AssertHelper.AssertPropertyEqual( + expectedMessage, new[] { "MessageId", "OrganizationId" }))); + + expectedMessage.Configuration = new WebhookIntegrationConfigurationDetails(_uri2); + await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is(AssertHelper.AssertPropertyEqual( + expectedMessage, new[] { "MessageId", "OrganizationId" }))); + } + } + + [Theory, BitAutoData] + public async Task HandleEventAsync_CapturedFactories_CallConfigurationRepository(EventMessage eventMessage) + { + eventMessage.OrganizationId = _organizationId; + var sutProvider = GetSutProvider(NoConfigurations()); + var cache = sutProvider.GetDependency(); + var configurationRepository = sutProvider.GetDependency(); + + var configs = OneConfiguration(_templateBase); + + configurationRepository.GetManyByEventTypeOrganizationIdIntegrationType(eventType: eventMessage.Type, organizationId: _organizationId, integrationType: IntegrationType.Webhook).Returns(configs); + + // Capture the factory function - there will be 1 call that returns both specific and wildcard matches + Func>, CancellationToken, Task>>? capturedFactory = null; + cache.GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Do>, CancellationToken, Task>>>(f + => capturedFactory = f), + options: Arg.Any(), + tags: Arg.Any>() + ).Returns(new List()); + + await sutProvider.Sut.HandleEventAsync(eventMessage); + + // Verify factory was captured + Assert.NotNull(capturedFactory); + + // Execute the captured factory to trigger repository call + await capturedFactory(null!, CancellationToken.None); + + await configurationRepository.Received(1).GetManyByEventTypeOrganizationIdIntegrationType(eventType: eventMessage.Type, organizationId: _organizationId, integrationType: IntegrationType.Webhook); + } + + [Theory, BitAutoData] + public async Task HandleEventAsync_ConfigurationCacheOptions_SetsDurationToConstant(EventMessage eventMessage) + { + eventMessage.OrganizationId = _organizationId; + var sutProvider = GetSutProvider(NoConfigurations()); + var cache = sutProvider.GetDependency(); + + FusionCacheEntryOptions? capturedOption = null; + cache.GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any>, CancellationToken, Task>>>(), + options: Arg.Do(opt => capturedOption = opt), + tags: Arg.Any?>() + ).Returns(new List()); + + await sutProvider.Sut.HandleEventAsync(eventMessage); + + Assert.NotNull(capturedOption); + Assert.Equal(EventIntegrationsCacheConstants.DurationForOrganizationIntegrationConfigurationDetails, + capturedOption.Duration); + } + + [Theory, BitAutoData] + public async Task HandleEventAsync_ConfigurationCache_AddsOrganizationIntegrationTag(EventMessage eventMessage) + { + eventMessage.OrganizationId = _organizationId; + var sutProvider = GetSutProvider(NoConfigurations()); + var cache = sutProvider.GetDependency(); + + IEnumerable? capturedTags = null; + cache.GetOrSetAsync( + key: Arg.Any(), + factory: Arg.Any>, CancellationToken, Task>>>(), + options: Arg.Any(), + tags: Arg.Do>(t => capturedTags = t) + ).Returns(new List()); + + await sutProvider.Sut.HandleEventAsync(eventMessage); + + var expectedTag = EventIntegrationsCacheConstants.BuildCacheTagForOrganizationIntegration( + _organizationId, + IntegrationType.Webhook + ); + Assert.NotNull(capturedTags); + Assert.Contains(expectedTag, capturedTags); + } +} diff --git a/test/Core.Test/AdminConsole/Services/EventRepositoryHandlerTests.cs b/test/Core.Test/Dirt/Services/EventRepositoryHandlerTests.cs similarity index 90% rename from test/Core.Test/AdminConsole/Services/EventRepositoryHandlerTests.cs rename to test/Core.Test/Dirt/Services/EventRepositoryHandlerTests.cs index 48c3a143d4..6392f0138d 100644 --- a/test/Core.Test/AdminConsole/Services/EventRepositoryHandlerTests.cs +++ b/test/Core.Test/Dirt/Services/EventRepositoryHandlerTests.cs @@ -1,4 +1,5 @@ -using Bit.Core.Models.Data; +using Bit.Core.Dirt.Services.Implementations; +using Bit.Core.Models.Data; using Bit.Core.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; @@ -6,7 +7,7 @@ using Bit.Test.Common.Helpers; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; +namespace Bit.Core.Test.Dirt.Services; [SutProviderCustomize] public class EventRepositoryHandlerTests diff --git a/test/Core.Test/AdminConsole/Services/EventServiceTests.cs b/test/Core.Test/Dirt/Services/EventServiceTests.cs similarity index 100% rename from test/Core.Test/AdminConsole/Services/EventServiceTests.cs rename to test/Core.Test/Dirt/Services/EventServiceTests.cs diff --git a/test/Core.Test/AdminConsole/Services/IntegrationFilterFactoryTests.cs b/test/Core.Test/Dirt/Services/IntegrationFilterFactoryTests.cs similarity index 91% rename from test/Core.Test/AdminConsole/Services/IntegrationFilterFactoryTests.cs rename to test/Core.Test/Dirt/Services/IntegrationFilterFactoryTests.cs index b408bc1501..83780b1fe0 100644 --- a/test/Core.Test/AdminConsole/Services/IntegrationFilterFactoryTests.cs +++ b/test/Core.Test/Dirt/Services/IntegrationFilterFactoryTests.cs @@ -1,9 +1,9 @@ -using Bit.Core.Models.Data; -using Bit.Core.Services; +using Bit.Core.Dirt.Services.Implementations; +using Bit.Core.Models.Data; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; -namespace Bit.Core.Test.Services; +namespace Bit.Core.Test.Dirt.Services; public class IntegrationFilterFactoryTests { diff --git a/test/Core.Test/AdminConsole/Services/IntegrationFilterServiceTests.cs b/test/Core.Test/Dirt/Services/IntegrationFilterServiceTests.cs similarity index 84% rename from test/Core.Test/AdminConsole/Services/IntegrationFilterServiceTests.cs rename to test/Core.Test/Dirt/Services/IntegrationFilterServiceTests.cs index 4143469a4b..b7510b0e92 100644 --- a/test/Core.Test/AdminConsole/Services/IntegrationFilterServiceTests.cs +++ b/test/Core.Test/Dirt/Services/IntegrationFilterServiceTests.cs @@ -1,13 +1,13 @@ #nullable enable using System.Text.Json; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Services.Implementations; using Bit.Core.Models.Data; -using Bit.Core.Services; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; -namespace Bit.Core.Test.Services; +namespace Bit.Core.Test.Dirt.Services; public class IntegrationFilterServiceTests { @@ -42,6 +42,35 @@ public class IntegrationFilterServiceTests Assert.True(_service.EvaluateFilterGroup(roundtrippedGroup, eventMessage)); } + [Theory, BitAutoData] + public void EvaluateFilterGroup_EqualsUserIdString_Matches(EventMessage eventMessage) + { + var userId = Guid.NewGuid(); + eventMessage.UserId = userId; + + var group = new IntegrationFilterGroup + { + AndOperator = true, + Rules = + [ + new() + { + Property = "UserId", + Operation = IntegrationFilterOperation.Equals, + Value = userId.ToString() + } + ] + }; + + var result = _service.EvaluateFilterGroup(group, eventMessage); + Assert.True(result); + + var jsonGroup = JsonSerializer.Serialize(group); + var roundtrippedGroup = JsonSerializer.Deserialize(jsonGroup); + Assert.NotNull(roundtrippedGroup); + Assert.True(_service.EvaluateFilterGroup(roundtrippedGroup, eventMessage)); + } + [Theory, BitAutoData] public void EvaluateFilterGroup_EqualsUserId_DoesNotMatch(EventMessage eventMessage) { @@ -281,6 +310,45 @@ public class IntegrationFilterServiceTests Assert.True(_service.EvaluateFilterGroup(roundtrippedGroup, eventMessage)); } + + [Theory, BitAutoData] + public void EvaluateFilterGroup_NestedGroups_AnyMatch(EventMessage eventMessage) + { + var id = Guid.NewGuid(); + var collectionId = Guid.NewGuid(); + eventMessage.UserId = id; + eventMessage.CollectionId = collectionId; + + var nestedGroup = new IntegrationFilterGroup + { + AndOperator = false, + Rules = + [ + new() { Property = "UserId", Operation = IntegrationFilterOperation.Equals, Value = id }, + new() + { + Property = "CollectionId", + Operation = IntegrationFilterOperation.In, + Value = new Guid?[] { Guid.NewGuid() } + } + ] + }; + + var topGroup = new IntegrationFilterGroup + { + AndOperator = false, + Groups = [nestedGroup] + }; + + var result = _service.EvaluateFilterGroup(topGroup, eventMessage); + Assert.True(result); + + var jsonGroup = JsonSerializer.Serialize(topGroup); + var roundtrippedGroup = JsonSerializer.Deserialize(jsonGroup); + Assert.NotNull(roundtrippedGroup); + Assert.True(_service.EvaluateFilterGroup(roundtrippedGroup, eventMessage)); + } + [Theory, BitAutoData] public void EvaluateFilterGroup_UnknownProperty_ReturnsFalse(EventMessage eventMessage) { diff --git a/test/Core.Test/Dirt/Services/IntegrationHandlerTests.cs b/test/Core.Test/Dirt/Services/IntegrationHandlerTests.cs new file mode 100644 index 0000000000..096fcc11bb --- /dev/null +++ b/test/Core.Test/Dirt/Services/IntegrationHandlerTests.cs @@ -0,0 +1,145 @@ +using System.Net; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Services; +using Xunit; + +namespace Bit.Core.Test.Dirt.Services; + +public class IntegrationHandlerTests +{ + [Fact] + public async Task HandleAsync_ConvertsJsonToTypedIntegrationMessage() + { + var sut = new TestIntegrationHandler(); + var expected = new IntegrationMessage() + { + Configuration = new WebhookIntegrationConfigurationDetails(new Uri("https://localhost"), "Bearer", "AUTH-TOKEN"), + MessageId = "TestMessageId", + OrganizationId = "TestOrganizationId", + IntegrationType = IntegrationType.Webhook, + RenderedTemplate = "Template", + DelayUntilDate = null, + RetryCount = 0 + }; + + var result = await sut.HandleAsync(expected.ToJson()); + var typedResult = Assert.IsType>(result.Message); + + Assert.Equal(expected.MessageId, typedResult.MessageId); + Assert.Equal(expected.OrganizationId, typedResult.OrganizationId); + Assert.Equal(expected.Configuration, typedResult.Configuration); + Assert.Equal(expected.RenderedTemplate, typedResult.RenderedTemplate); + Assert.Equal(expected.IntegrationType, typedResult.IntegrationType); + } + + [Theory] + [InlineData(HttpStatusCode.Unauthorized)] + [InlineData(HttpStatusCode.Forbidden)] + public void ClassifyHttpStatusCode_AuthenticationFailed(HttpStatusCode code) + { + Assert.Equal( + IntegrationFailureCategory.AuthenticationFailed, + TestIntegrationHandler.Classify(code)); + } + + [Theory] + [InlineData(HttpStatusCode.NotFound)] + [InlineData(HttpStatusCode.Gone)] + [InlineData(HttpStatusCode.MovedPermanently)] + [InlineData(HttpStatusCode.TemporaryRedirect)] + [InlineData(HttpStatusCode.PermanentRedirect)] + public void ClassifyHttpStatusCode_ConfigurationError(HttpStatusCode code) + { + Assert.Equal( + IntegrationFailureCategory.ConfigurationError, + TestIntegrationHandler.Classify(code)); + } + + [Fact] + public void ClassifyHttpStatusCode_TooManyRequests_IsRateLimited() + { + Assert.Equal( + IntegrationFailureCategory.RateLimited, + TestIntegrationHandler.Classify(HttpStatusCode.TooManyRequests)); + } + + [Fact] + public void ClassifyHttpStatusCode_RequestTimeout_IsTransient() + { + Assert.Equal( + IntegrationFailureCategory.TransientError, + TestIntegrationHandler.Classify(HttpStatusCode.RequestTimeout)); + } + + [Theory] + [InlineData(HttpStatusCode.InternalServerError)] + [InlineData(HttpStatusCode.BadGateway)] + [InlineData(HttpStatusCode.GatewayTimeout)] + public void ClassifyHttpStatusCode_Common5xx_AreTransient(HttpStatusCode code) + { + Assert.Equal( + IntegrationFailureCategory.TransientError, + TestIntegrationHandler.Classify(code)); + } + + [Fact] + public void ClassifyHttpStatusCode_ServiceUnavailable_IsServiceUnavailable() + { + Assert.Equal( + IntegrationFailureCategory.ServiceUnavailable, + TestIntegrationHandler.Classify(HttpStatusCode.ServiceUnavailable)); + } + + [Fact] + public void ClassifyHttpStatusCode_NotImplemented_IsPermanentFailure() + { + Assert.Equal( + IntegrationFailureCategory.PermanentFailure, + TestIntegrationHandler.Classify(HttpStatusCode.NotImplemented)); + } + + [Fact] + public void FClassifyHttpStatusCode_Unhandled3xx_IsConfigurationError() + { + Assert.Equal( + IntegrationFailureCategory.ConfigurationError, + TestIntegrationHandler.Classify(HttpStatusCode.Found)); + } + + [Fact] + public void ClassifyHttpStatusCode_Unhandled4xx_IsConfigurationError() + { + Assert.Equal( + IntegrationFailureCategory.ConfigurationError, + TestIntegrationHandler.Classify(HttpStatusCode.BadRequest)); + } + + [Fact] + public void ClassifyHttpStatusCode_Unhandled5xx_IsServiceUnavailable() + { + Assert.Equal( + IntegrationFailureCategory.ServiceUnavailable, + TestIntegrationHandler.Classify(HttpStatusCode.HttpVersionNotSupported)); + } + + [Fact] + public void ClassifyHttpStatusCode_UnknownCode_DefaultsToServiceUnavailable() + { + // cast an out-of-range value to ensure default path is stable + Assert.Equal( + IntegrationFailureCategory.ServiceUnavailable, + TestIntegrationHandler.Classify((HttpStatusCode)799)); + } + + private class TestIntegrationHandler : IntegrationHandlerBase + { + public override Task HandleAsync( + IntegrationMessage message) + { + return Task.FromResult(IntegrationHandlerResult.Succeed(message: message)); + } + + public static IntegrationFailureCategory Classify(HttpStatusCode code) => ClassifyHttpStatusCode(code); + } +} diff --git a/test/Core.Test/Dirt/Services/OrganizationIntegrationConfigurationValidatorTests.cs b/test/Core.Test/Dirt/Services/OrganizationIntegrationConfigurationValidatorTests.cs new file mode 100644 index 0000000000..bee6a5182c --- /dev/null +++ b/test/Core.Test/Dirt/Services/OrganizationIntegrationConfigurationValidatorTests.cs @@ -0,0 +1,244 @@ +using System.Text.Json; +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Enums; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Services.Implementations; +using Xunit; + +namespace Bit.Core.Test.Dirt.Services; + +public class OrganizationIntegrationConfigurationValidatorTests +{ + private readonly OrganizationIntegrationConfigurationValidator _sut = new(); + + [Fact] + public void ValidateConfiguration_CloudBillingSyncIntegration_ReturnsFalse() + { + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = "{}", + Template = "template" + }; + + Assert.False(_sut.ValidateConfiguration(IntegrationType.CloudBillingSync, configuration)); + } + + [Theory] + [InlineData(null)] + [InlineData("")] + [InlineData(" ")] + public void ValidateConfiguration_EmptyTemplate_ReturnsFalse(string? template) + { + var config1 = new OrganizationIntegrationConfiguration + { + Configuration = JsonSerializer.Serialize(new SlackIntegrationConfiguration(ChannelId: "C12345")), + Template = template + }; + Assert.False(_sut.ValidateConfiguration(IntegrationType.Slack, config1)); + + var config2 = new OrganizationIntegrationConfiguration + { + Configuration = JsonSerializer.Serialize(new WebhookIntegrationConfiguration(Uri: new Uri("https://example.com"))), + Template = template + }; + Assert.False(_sut.ValidateConfiguration(IntegrationType.Webhook, config2)); + } + + [Theory] + [InlineData("")] + [InlineData(" ")] + public void ValidateConfiguration_EmptyNonNullConfiguration_ReturnsFalse(string? config) + { + var config1 = new OrganizationIntegrationConfiguration + { + Configuration = config, + Template = "template" + }; + Assert.False(_sut.ValidateConfiguration(IntegrationType.Hec, config1)); + + var config2 = new OrganizationIntegrationConfiguration + { + Configuration = config, + Template = "template" + }; + Assert.False(_sut.ValidateConfiguration(IntegrationType.Datadog, config2)); + + var config3 = new OrganizationIntegrationConfiguration + { + Configuration = config, + Template = "template" + }; + Assert.False(_sut.ValidateConfiguration(IntegrationType.Teams, config3)); + } + + [Fact] + public void ValidateConfiguration_NullConfiguration_ReturnsTrue() + { + var config1 = new OrganizationIntegrationConfiguration + { + Configuration = null, + Template = "template" + }; + Assert.True(_sut.ValidateConfiguration(IntegrationType.Hec, config1)); + + var config2 = new OrganizationIntegrationConfiguration + { + Configuration = null, + Template = "template" + }; + Assert.True(_sut.ValidateConfiguration(IntegrationType.Datadog, config2)); + + var config3 = new OrganizationIntegrationConfiguration + { + Configuration = null, + Template = "template" + }; + Assert.True(_sut.ValidateConfiguration(IntegrationType.Teams, config3)); + } + + [Fact] + public void ValidateConfiguration_InvalidJsonConfiguration_ReturnsFalse() + { + var config = new OrganizationIntegrationConfiguration + { + Configuration = "{not valid json}", + Template = "template" + }; + + Assert.False(_sut.ValidateConfiguration(IntegrationType.Slack, config)); + Assert.False(_sut.ValidateConfiguration(IntegrationType.Webhook, config)); + Assert.False(_sut.ValidateConfiguration(IntegrationType.Hec, config)); + Assert.False(_sut.ValidateConfiguration(IntegrationType.Datadog, config)); + Assert.False(_sut.ValidateConfiguration(IntegrationType.Teams, config)); + } + + [Fact] + public void ValidateConfiguration_InvalidJsonFilters_ReturnsFalse() + { + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = JsonSerializer.Serialize(new WebhookIntegrationConfiguration(Uri: new Uri("https://example.com"))), + Template = "template", + Filters = "{Not valid json}" + }; + + Assert.False(_sut.ValidateConfiguration(IntegrationType.Webhook, configuration)); + } + + [Fact] + public void ValidateConfiguration_ScimIntegration_ReturnsFalse() + { + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = "{}", + Template = "template" + }; + + Assert.False(_sut.ValidateConfiguration(IntegrationType.Scim, configuration)); + } + + [Fact] + public void ValidateConfiguration_ValidSlackConfiguration_ReturnsTrue() + { + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = JsonSerializer.Serialize(new SlackIntegrationConfiguration(ChannelId: "C12345")), + Template = "template" + }; + + Assert.True(_sut.ValidateConfiguration(IntegrationType.Slack, configuration)); + } + + [Fact] + public void ValidateConfiguration_ValidSlackConfigurationWithFilters_ReturnsTrue() + { + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = JsonSerializer.Serialize(new SlackIntegrationConfiguration("C12345")), + Template = "template", + Filters = JsonSerializer.Serialize(new IntegrationFilterGroup() + { + AndOperator = true, + Rules = [ + new IntegrationFilterRule() + { + Operation = IntegrationFilterOperation.Equals, + Property = "CollectionId", + Value = Guid.NewGuid() + } + ], + Groups = [] + }) + }; + + Assert.True(_sut.ValidateConfiguration(IntegrationType.Slack, configuration)); + } + + [Fact] + public void ValidateConfiguration_ValidNoAuthWebhookConfiguration_ReturnsTrue() + { + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = JsonSerializer.Serialize(new WebhookIntegrationConfiguration(Uri: new Uri("https://localhost"))), + Template = "template" + }; + + Assert.True(_sut.ValidateConfiguration(IntegrationType.Webhook, configuration)); + } + + [Fact] + public void ValidateConfiguration_ValidWebhookConfiguration_ReturnsTrue() + { + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = JsonSerializer.Serialize(new WebhookIntegrationConfiguration( + Uri: new Uri("https://localhost"), + Scheme: "Bearer", + Token: "AUTH-TOKEN")), + Template = "template" + }; + + Assert.True(_sut.ValidateConfiguration(IntegrationType.Webhook, configuration)); + } + + [Fact] + public void ValidateConfiguration_ValidWebhookConfigurationWithFilters_ReturnsTrue() + { + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = JsonSerializer.Serialize(new WebhookIntegrationConfiguration( + Uri: new Uri("https://example.com"), + Scheme: "Bearer", + Token: "AUTH-TOKEN")), + Template = "template", + Filters = JsonSerializer.Serialize(new IntegrationFilterGroup() + { + AndOperator = true, + Rules = [ + new IntegrationFilterRule() + { + Operation = IntegrationFilterOperation.Equals, + Property = "CollectionId", + Value = Guid.NewGuid() + } + ], + Groups = [] + }) + }; + + Assert.True(_sut.ValidateConfiguration(IntegrationType.Webhook, configuration)); + } + + [Fact] + public void ValidateConfiguration_UnknownIntegrationType_ReturnsFalse() + { + var unknownType = (IntegrationType)999; + var configuration = new OrganizationIntegrationConfiguration + { + Configuration = "{}", + Template = "template" + }; + + Assert.False(_sut.ValidateConfiguration(unknownType, configuration)); + } +} diff --git a/test/Core.Test/AdminConsole/Services/RabbitMqEventListenerServiceTests.cs b/test/Core.Test/Dirt/Services/RabbitMqEventListenerServiceTests.cs similarity index 97% rename from test/Core.Test/AdminConsole/Services/RabbitMqEventListenerServiceTests.cs rename to test/Core.Test/Dirt/Services/RabbitMqEventListenerServiceTests.cs index 22e297a00d..560cf589ed 100644 --- a/test/Core.Test/AdminConsole/Services/RabbitMqEventListenerServiceTests.cs +++ b/test/Core.Test/Dirt/Services/RabbitMqEventListenerServiceTests.cs @@ -1,9 +1,10 @@ #nullable enable using System.Text.Json; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Services; +using Bit.Core.Dirt.Services.Implementations; using Bit.Core.Models.Data; -using Bit.Core.Services; +using Bit.Core.Test.Dirt.Models.Data.EventIntegrations; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.Helpers; @@ -13,7 +14,7 @@ using RabbitMQ.Client; using RabbitMQ.Client.Events; using Xunit; -namespace Bit.Core.Test.Services; +namespace Bit.Core.Test.Dirt.Services; [SutProviderCustomize] public class RabbitMqEventListenerServiceTests diff --git a/test/Core.Test/AdminConsole/Services/RabbitMqIntegrationListenerServiceTests.cs b/test/Core.Test/Dirt/Services/RabbitMqIntegrationListenerServiceTests.cs similarity index 90% rename from test/Core.Test/AdminConsole/Services/RabbitMqIntegrationListenerServiceTests.cs rename to test/Core.Test/Dirt/Services/RabbitMqIntegrationListenerServiceTests.cs index 5fcd121252..453a4e6527 100644 --- a/test/Core.Test/AdminConsole/Services/RabbitMqIntegrationListenerServiceTests.cs +++ b/test/Core.Test/Dirt/Services/RabbitMqIntegrationListenerServiceTests.cs @@ -1,8 +1,10 @@ #nullable enable using System.Text; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; -using Bit.Core.Services; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Services; +using Bit.Core.Dirt.Services.Implementations; +using Bit.Core.Test.Dirt.Models.Data.EventIntegrations; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.Helpers; @@ -13,7 +15,7 @@ using RabbitMQ.Client; using RabbitMQ.Client.Events; using Xunit; -namespace Bit.Core.Test.Services; +namespace Bit.Core.Test.Dirt.Services; [SutProviderCustomize] public class RabbitMqIntegrationListenerServiceTests @@ -86,8 +88,10 @@ public class RabbitMqIntegrationListenerServiceTests new BasicProperties(), body: Encoding.UTF8.GetBytes(message.ToJson()) ); - var result = new IntegrationHandlerResult(false, message); - result.Retryable = false; + var result = IntegrationHandlerResult.Fail( + message: message, + category: IntegrationFailureCategory.AuthenticationFailed, // NOT retryable + failureReason: "403"); _handler.HandleAsync(Arg.Any()).Returns(result); var expected = IntegrationMessage.FromJson(message.ToJson()); @@ -105,7 +109,7 @@ public class RabbitMqIntegrationListenerServiceTests _logger.Received().Log( LogLevel.Warning, Arg.Any(), - Arg.Is(o => (o.ToString() ?? "").Contains("Non-retryable failure")), + Arg.Is(o => (o.ToString() ?? "").Contains("Integration failure - non-retryable.")), Arg.Any(), Arg.Any>()); @@ -133,8 +137,10 @@ public class RabbitMqIntegrationListenerServiceTests new BasicProperties(), body: Encoding.UTF8.GetBytes(message.ToJson()) ); - var result = new IntegrationHandlerResult(false, message); - result.Retryable = true; + var result = IntegrationHandlerResult.Fail( + message: message, + category: IntegrationFailureCategory.TransientError, // Retryable + failureReason: "403"); _handler.HandleAsync(Arg.Any()).Returns(result); var expected = IntegrationMessage.FromJson(message.ToJson()); @@ -151,7 +157,7 @@ public class RabbitMqIntegrationListenerServiceTests _logger.Received().Log( LogLevel.Warning, Arg.Any(), - Arg.Is(o => (o.ToString() ?? "").Contains("Max retry attempts reached")), + Arg.Is(o => (o.ToString() ?? "").Contains("Integration failure - max retries exceeded.")), Arg.Any(), Arg.Any>()); @@ -179,9 +185,10 @@ public class RabbitMqIntegrationListenerServiceTests new BasicProperties(), body: Encoding.UTF8.GetBytes(message.ToJson()) ); - var result = new IntegrationHandlerResult(false, message); - result.Retryable = true; - result.DelayUntilDate = _now.AddMinutes(1); + var result = IntegrationHandlerResult.Fail( + message: message, + category: IntegrationFailureCategory.TransientError, // Retryable + failureReason: "403"); _handler.HandleAsync(Arg.Any()).Returns(result); var expected = IntegrationMessage.FromJson(message.ToJson()); @@ -220,7 +227,7 @@ public class RabbitMqIntegrationListenerServiceTests new BasicProperties(), body: Encoding.UTF8.GetBytes(message.ToJson()) ); - var result = new IntegrationHandlerResult(true, message); + var result = IntegrationHandlerResult.Succeed(message); _handler.HandleAsync(Arg.Any()).Returns(result); await sutProvider.Sut.ProcessReceivedMessageAsync(eventArgs, cancellationToken); diff --git a/test/Core.Test/AdminConsole/Services/RepositoryEventWriteServiceTests.cs b/test/Core.Test/Dirt/Services/RepositoryEventWriteServiceTests.cs similarity index 100% rename from test/Core.Test/AdminConsole/Services/RepositoryEventWriteServiceTests.cs rename to test/Core.Test/Dirt/Services/RepositoryEventWriteServiceTests.cs diff --git a/test/Core.Test/Dirt/Services/SlackIntegrationHandlerTests.cs b/test/Core.Test/Dirt/Services/SlackIntegrationHandlerTests.cs new file mode 100644 index 0000000000..52bb7a03a4 --- /dev/null +++ b/test/Core.Test/Dirt/Services/SlackIntegrationHandlerTests.cs @@ -0,0 +1,140 @@ +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Models.Data.Slack; +using Bit.Core.Dirt.Services; +using Bit.Core.Dirt.Services.Implementations; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Bit.Test.Common.Helpers; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.Dirt.Services; + +[SutProviderCustomize] +public class SlackIntegrationHandlerTests +{ + private readonly ISlackService _slackService = Substitute.For(); + private readonly string _channelId = "C12345"; + private readonly string _token = "xoxb-test-token"; + + private SutProvider GetSutProvider() + { + return new SutProvider() + .SetDependency(_slackService) + .Create(); + } + + [Theory, BitAutoData] + public async Task HandleAsync_SuccessfulRequest_ReturnsSuccess(IntegrationMessage message) + { + var sutProvider = GetSutProvider(); + message.Configuration = new SlackIntegrationConfigurationDetails(_channelId, _token); + + _slackService.SendSlackMessageByChannelIdAsync(Arg.Any(), Arg.Any(), Arg.Any()) + .Returns(new SlackSendMessageResponse() { Ok = true, Channel = _channelId }); + + var result = await sutProvider.Sut.HandleAsync(message); + + Assert.True(result.Success); + Assert.Equal(result.Message, message); + + await sutProvider.GetDependency().Received(1).SendSlackMessageByChannelIdAsync( + Arg.Is(AssertHelper.AssertPropertyEqual(_token)), + Arg.Is(AssertHelper.AssertPropertyEqual(message.RenderedTemplate)), + Arg.Is(AssertHelper.AssertPropertyEqual(_channelId)) + ); + } + + [Theory] + [InlineData("service_unavailable")] + [InlineData("ratelimited")] + [InlineData("rate_limited")] + [InlineData("internal_error")] + [InlineData("message_limit_exceeded")] + public async Task HandleAsync_FailedRetryableRequest_ReturnsFailureWithRetryable(string error) + { + var sutProvider = GetSutProvider(); + var message = new IntegrationMessage() + { + Configuration = new SlackIntegrationConfigurationDetails(_channelId, _token), + MessageId = "MessageId", + RenderedTemplate = "Test Message" + }; + + _slackService.SendSlackMessageByChannelIdAsync(Arg.Any(), Arg.Any(), Arg.Any()) + .Returns(new SlackSendMessageResponse() { Ok = false, Channel = _channelId, Error = error }); + + var result = await sutProvider.Sut.HandleAsync(message); + + Assert.False(result.Success); + Assert.True(result.Retryable); + Assert.NotNull(result.FailureReason); + Assert.Equal(result.Message, message); + + await sutProvider.GetDependency().Received(1).SendSlackMessageByChannelIdAsync( + Arg.Is(AssertHelper.AssertPropertyEqual(_token)), + Arg.Is(AssertHelper.AssertPropertyEqual(message.RenderedTemplate)), + Arg.Is(AssertHelper.AssertPropertyEqual(_channelId)) + ); + } + + [Theory] + [InlineData("access_denied")] + [InlineData("channel_not_found")] + [InlineData("token_expired")] + [InlineData("token_revoked")] + public async Task HandleAsync_FailedNonRetryableRequest_ReturnsNonRetryableFailure(string error) + { + var sutProvider = GetSutProvider(); + var message = new IntegrationMessage() + { + Configuration = new SlackIntegrationConfigurationDetails(_channelId, _token), + MessageId = "MessageId", + RenderedTemplate = "Test Message" + }; + + _slackService.SendSlackMessageByChannelIdAsync(Arg.Any(), Arg.Any(), Arg.Any()) + .Returns(new SlackSendMessageResponse() { Ok = false, Channel = _channelId, Error = error }); + + var result = await sutProvider.Sut.HandleAsync(message); + + Assert.False(result.Success); + Assert.False(result.Retryable); + Assert.NotNull(result.FailureReason); + Assert.Equal(result.Message, message); + + await sutProvider.GetDependency().Received(1).SendSlackMessageByChannelIdAsync( + Arg.Is(AssertHelper.AssertPropertyEqual(_token)), + Arg.Is(AssertHelper.AssertPropertyEqual(message.RenderedTemplate)), + Arg.Is(AssertHelper.AssertPropertyEqual(_channelId)) + ); + } + + [Fact] + public async Task HandleAsync_NullResponse_ReturnsRetryableFailure() + { + var sutProvider = GetSutProvider(); + var message = new IntegrationMessage() + { + Configuration = new SlackIntegrationConfigurationDetails(_channelId, _token), + MessageId = "MessageId", + RenderedTemplate = "Test Message" + }; + + _slackService.SendSlackMessageByChannelIdAsync(Arg.Any(), Arg.Any(), Arg.Any()) + .Returns((SlackSendMessageResponse?)null); + + var result = await sutProvider.Sut.HandleAsync(message); + + Assert.False(result.Success); + Assert.True(result.Retryable); // Null response is classified as TransientError (retryable) + Assert.Equal("Slack response was null", result.FailureReason); + Assert.Equal(result.Message, message); + + await sutProvider.GetDependency().Received(1).SendSlackMessageByChannelIdAsync( + Arg.Is(AssertHelper.AssertPropertyEqual(_token)), + Arg.Is(AssertHelper.AssertPropertyEqual(message.RenderedTemplate)), + Arg.Is(AssertHelper.AssertPropertyEqual(_channelId)) + ); + } +} diff --git a/test/Core.Test/AdminConsole/Services/SlackServiceTests.cs b/test/Core.Test/Dirt/Services/SlackServiceTests.cs similarity index 70% rename from test/Core.Test/AdminConsole/Services/SlackServiceTests.cs rename to test/Core.Test/Dirt/Services/SlackServiceTests.cs index 2d0ca2433a..bbb505f5d3 100644 --- a/test/Core.Test/AdminConsole/Services/SlackServiceTests.cs +++ b/test/Core.Test/Dirt/Services/SlackServiceTests.cs @@ -3,7 +3,7 @@ using System.Net; using System.Text.Json; using System.Web; -using Bit.Core.Services; +using Bit.Core.Dirt.Services.Implementations; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.MockedHttpClient; @@ -11,7 +11,7 @@ using NSubstitute; using Xunit; using GlobalSettings = Bit.Core.Settings.GlobalSettings; -namespace Bit.Core.Test.Services; +namespace Bit.Core.Test.Dirt.Services; [SutProviderCustomize] public class SlackServiceTests @@ -146,6 +146,27 @@ public class SlackServiceTests Assert.Empty(result); } + [Fact] + public async Task GetChannelIdAsync_NoChannelFound_ReturnsEmptyResult() + { + var emptyResponse = JsonSerializer.Serialize( + new + { + ok = true, + channels = Array.Empty(), + response_metadata = new { next_cursor = "" } + }); + + _handler.When(HttpMethod.Get) + .RespondWith(HttpStatusCode.OK) + .WithContent(new StringContent(emptyResponse)); + + var sutProvider = GetSutProvider(); + var result = await sutProvider.Sut.GetChannelIdAsync(_token, "general"); + + Assert.Empty(result); + } + [Fact] public async Task GetChannelIdAsync_ReturnsCorrectChannelId() { @@ -235,6 +256,32 @@ public class SlackServiceTests Assert.Equal(string.Empty, result); } + [Fact] + public async Task GetDmChannelByEmailAsync_ApiErrorUnparsableDmResponse_ReturnsEmptyString() + { + var sutProvider = GetSutProvider(); + var email = "user@example.com"; + var userId = "U12345"; + + var userResponse = new + { + ok = true, + user = new { id = userId } + }; + + _handler.When($"https://slack.com/api/users.lookupByEmail?email={email}") + .RespondWith(HttpStatusCode.OK) + .WithContent(new StringContent(JsonSerializer.Serialize(userResponse))); + + _handler.When("https://slack.com/api/conversations.open") + .RespondWith(HttpStatusCode.OK) + .WithContent(new StringContent("NOT JSON")); + + var result = await sutProvider.Sut.GetDmChannelByEmailAsync(_token, email); + + Assert.Equal(string.Empty, result); + } + [Fact] public async Task GetDmChannelByEmailAsync_ApiErrorUserResponse_ReturnsEmptyString() { @@ -244,7 +291,7 @@ public class SlackServiceTests var userResponse = new { ok = false, - error = "An error occured" + error = "An error occurred" }; _handler.When($"https://slack.com/api/users.lookupByEmail?email={email}") @@ -256,6 +303,21 @@ public class SlackServiceTests Assert.Equal(string.Empty, result); } + [Fact] + public async Task GetDmChannelByEmailAsync_ApiErrorUnparsableUserResponse_ReturnsEmptyString() + { + var sutProvider = GetSutProvider(); + var email = "user@example.com"; + + _handler.When($"https://slack.com/api/users.lookupByEmail?email={email}") + .RespondWith(HttpStatusCode.OK) + .WithContent(new StringContent("Not JSON")); + + var result = await sutProvider.Sut.GetDmChannelByEmailAsync(_token, email); + + Assert.Equal(string.Empty, result); + } + [Fact] public void GetRedirectUrl_ReturnsCorrectUrl() { @@ -296,6 +358,18 @@ public class SlackServiceTests Assert.Equal("test-access-token", result); } + [Theory] + [InlineData("test-code", "")] + [InlineData("", "https://example.com/callback")] + [InlineData("", "")] + public async Task ObtainTokenViaOAuth_ReturnsEmptyString_WhenCodeOrRedirectUrlIsEmpty(string code, string redirectUrl) + { + var sutProvider = GetSutProvider(); + var result = await sutProvider.Sut.ObtainTokenViaOAuth(code, redirectUrl); + + Assert.Equal(string.Empty, result); + } + [Fact] public async Task ObtainTokenViaOAuth_ReturnsEmptyString_WhenErrorResponse() { @@ -329,18 +403,29 @@ public class SlackServiceTests } [Fact] - public async Task SendSlackMessageByChannelId_Sends_Correct_Message() + public async Task SendSlackMessageByChannelId_Success_ReturnsSuccessfulResponse() { var sutProvider = GetSutProvider(); var channelId = "C12345"; var message = "Hello, Slack!"; + var jsonResponse = JsonSerializer.Serialize(new + { + ok = true, + channel = channelId, + }); + _handler.When(HttpMethod.Post) .RespondWith(HttpStatusCode.OK) - .WithContent(new StringContent(string.Empty)); + .WithContent(new StringContent(jsonResponse)); - await sutProvider.Sut.SendSlackMessageByChannelIdAsync(_token, message, channelId); + var result = await sutProvider.Sut.SendSlackMessageByChannelIdAsync(_token, message, channelId); + // Response was parsed correctly + Assert.NotNull(result); + Assert.True(result.Ok); + + // Request was sent correctly Assert.Single(_handler.CapturedRequests); var request = _handler.CapturedRequests[0]; Assert.NotNull(request); @@ -353,4 +438,62 @@ public class SlackServiceTests Assert.Equal(message, json.RootElement.GetProperty("text").GetString() ?? string.Empty); Assert.Equal(channelId, json.RootElement.GetProperty("channel").GetString() ?? string.Empty); } + + [Fact] + public async Task SendSlackMessageByChannelId_Failure_ReturnsErrorResponse() + { + var sutProvider = GetSutProvider(); + var channelId = "C12345"; + var message = "Hello, Slack!"; + + var jsonResponse = JsonSerializer.Serialize(new + { + ok = false, + channel = channelId, + error = "error" + }); + + _handler.When(HttpMethod.Post) + .RespondWith(HttpStatusCode.OK) + .WithContent(new StringContent(jsonResponse)); + + var result = await sutProvider.Sut.SendSlackMessageByChannelIdAsync(_token, message, channelId); + + // Response was parsed correctly + Assert.NotNull(result); + Assert.False(result.Ok); + Assert.NotNull(result.Error); + } + + [Fact] + public async Task SendSlackMessageByChannelIdAsync_InvalidJson_ReturnsNull() + { + var sutProvider = GetSutProvider(); + var channelId = "C12345"; + var message = "Hello world!"; + + _handler.When(HttpMethod.Post) + .RespondWith(HttpStatusCode.OK) + .WithContent(new StringContent("Not JSON")); + + var result = await sutProvider.Sut.SendSlackMessageByChannelIdAsync(_token, message, channelId); + + Assert.Null(result); + } + + [Fact] + public async Task SendSlackMessageByChannelIdAsync_HttpServerError_ReturnsNull() + { + var sutProvider = GetSutProvider(); + var channelId = "C12345"; + var message = "Hello world!"; + + _handler.When(HttpMethod.Post) + .RespondWith(HttpStatusCode.InternalServerError) + .WithContent(new StringContent(string.Empty)); + + var result = await sutProvider.Sut.SendSlackMessageByChannelIdAsync(_token, message, channelId); + + Assert.Null(result); + } } diff --git a/test/Core.Test/Dirt/Services/TeamsIntegrationHandlerTests.cs b/test/Core.Test/Dirt/Services/TeamsIntegrationHandlerTests.cs new file mode 100644 index 0000000000..b608ed7ff8 --- /dev/null +++ b/test/Core.Test/Dirt/Services/TeamsIntegrationHandlerTests.cs @@ -0,0 +1,199 @@ +using System.Text.Json; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Services; +using Bit.Core.Dirt.Services.Implementations; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Bit.Test.Common.Helpers; +using Microsoft.Rest; +using NSubstitute; +using NSubstitute.ExceptionExtensions; +using Xunit; + +namespace Bit.Core.Test.Dirt.Services; + +[SutProviderCustomize] +public class TeamsIntegrationHandlerTests +{ + private readonly ITeamsService _teamsService = Substitute.For(); + private readonly string _channelId = "C12345"; + private readonly Uri _serviceUrl = new Uri("http://localhost"); + + private SutProvider GetSutProvider() + { + return new SutProvider() + .SetDependency(_teamsService) + .Create(); + } + + [Theory, BitAutoData] + public async Task HandleAsync_SuccessfulRequest_ReturnsSuccess(IntegrationMessage message) + { + var sutProvider = GetSutProvider(); + message.Configuration = new TeamsIntegrationConfigurationDetails(_channelId, _serviceUrl); + + var result = await sutProvider.Sut.HandleAsync(message); + + Assert.True(result.Success); + Assert.Equal(result.Message, message); + + await sutProvider.GetDependency().Received(1).SendMessageToChannelAsync( + Arg.Is(AssertHelper.AssertPropertyEqual(_serviceUrl)), + Arg.Is(AssertHelper.AssertPropertyEqual(_channelId)), + Arg.Is(AssertHelper.AssertPropertyEqual(message.RenderedTemplate)) + ); + } + + [Theory, BitAutoData] + public async Task HandleAsync_ArgumentException_ReturnsConfigurationError(IntegrationMessage message) + { + var sutProvider = GetSutProvider(); + message.Configuration = new TeamsIntegrationConfigurationDetails(_channelId, _serviceUrl); + + sutProvider.GetDependency() + .SendMessageToChannelAsync(Arg.Any(), Arg.Any(), Arg.Any()) + .ThrowsAsync(new ArgumentException("argument error")); + var result = await sutProvider.Sut.HandleAsync(message); + + Assert.False(result.Success); + Assert.Equal(IntegrationFailureCategory.ConfigurationError, result.Category); + Assert.False(result.Retryable); + Assert.Equal(result.Message, message); + + await sutProvider.GetDependency().Received(1).SendMessageToChannelAsync( + Arg.Is(AssertHelper.AssertPropertyEqual(_serviceUrl)), + Arg.Is(AssertHelper.AssertPropertyEqual(_channelId)), + Arg.Is(AssertHelper.AssertPropertyEqual(message.RenderedTemplate)) + ); + } + + [Theory, BitAutoData] + public async Task HandleAsync_JsonException_ReturnsPermanentFailure(IntegrationMessage message) + { + var sutProvider = GetSutProvider(); + message.Configuration = new TeamsIntegrationConfigurationDetails(_channelId, _serviceUrl); + + sutProvider.GetDependency() + .SendMessageToChannelAsync(Arg.Any(), Arg.Any(), Arg.Any()) + .ThrowsAsync(new JsonException("JSON error")); + var result = await sutProvider.Sut.HandleAsync(message); + + Assert.False(result.Success); + Assert.Equal(IntegrationFailureCategory.PermanentFailure, result.Category); + Assert.False(result.Retryable); + Assert.Equal(result.Message, message); + + await sutProvider.GetDependency().Received(1).SendMessageToChannelAsync( + Arg.Is(AssertHelper.AssertPropertyEqual(_serviceUrl)), + Arg.Is(AssertHelper.AssertPropertyEqual(_channelId)), + Arg.Is(AssertHelper.AssertPropertyEqual(message.RenderedTemplate)) + ); + } + + [Theory, BitAutoData] + public async Task HandleAsync_UriFormatException_ReturnsConfigurationError(IntegrationMessage message) + { + var sutProvider = GetSutProvider(); + message.Configuration = new TeamsIntegrationConfigurationDetails(_channelId, _serviceUrl); + + sutProvider.GetDependency() + .SendMessageToChannelAsync(Arg.Any(), Arg.Any(), Arg.Any()) + .ThrowsAsync(new UriFormatException("Bad URI")); + var result = await sutProvider.Sut.HandleAsync(message); + + Assert.False(result.Success); + Assert.Equal(IntegrationFailureCategory.ConfigurationError, result.Category); + Assert.False(result.Retryable); + Assert.Equal(result.Message, message); + + await sutProvider.GetDependency().Received(1).SendMessageToChannelAsync( + Arg.Is(AssertHelper.AssertPropertyEqual(_serviceUrl)), + Arg.Is(AssertHelper.AssertPropertyEqual(_channelId)), + Arg.Is(AssertHelper.AssertPropertyEqual(message.RenderedTemplate)) + ); + } + + [Theory, BitAutoData] + public async Task HandleAsync_HttpExceptionForbidden_ReturnsAuthenticationFailed(IntegrationMessage message) + { + var sutProvider = GetSutProvider(); + message.Configuration = new TeamsIntegrationConfigurationDetails(_channelId, _serviceUrl); + + sutProvider.GetDependency() + .SendMessageToChannelAsync(Arg.Any(), Arg.Any(), Arg.Any()) + .ThrowsAsync(new HttpOperationException("Server error") + { + Response = new HttpResponseMessageWrapper( + new HttpResponseMessage(System.Net.HttpStatusCode.Forbidden), + "Forbidden" + ) + } + ); + var result = await sutProvider.Sut.HandleAsync(message); + + Assert.False(result.Success); + Assert.Equal(IntegrationFailureCategory.AuthenticationFailed, result.Category); + Assert.False(result.Retryable); + Assert.Equal(result.Message, message); + + await sutProvider.GetDependency().Received(1).SendMessageToChannelAsync( + Arg.Is(AssertHelper.AssertPropertyEqual(_serviceUrl)), + Arg.Is(AssertHelper.AssertPropertyEqual(_channelId)), + Arg.Is(AssertHelper.AssertPropertyEqual(message.RenderedTemplate)) + ); + } + + [Theory, BitAutoData] + public async Task HandleAsync_HttpExceptionTooManyRequests_ReturnsRateLimited(IntegrationMessage message) + { + var sutProvider = GetSutProvider(); + message.Configuration = new TeamsIntegrationConfigurationDetails(_channelId, _serviceUrl); + + sutProvider.GetDependency() + .SendMessageToChannelAsync(Arg.Any(), Arg.Any(), Arg.Any()) + .ThrowsAsync(new HttpOperationException("Server error") + { + Response = new HttpResponseMessageWrapper( + new HttpResponseMessage(System.Net.HttpStatusCode.TooManyRequests), + "Too Many Requests" + ) + } + ); + + var result = await sutProvider.Sut.HandleAsync(message); + + Assert.False(result.Success); + Assert.Equal(IntegrationFailureCategory.RateLimited, result.Category); + Assert.True(result.Retryable); + Assert.Equal(result.Message, message); + + await sutProvider.GetDependency().Received(1).SendMessageToChannelAsync( + Arg.Is(AssertHelper.AssertPropertyEqual(_serviceUrl)), + Arg.Is(AssertHelper.AssertPropertyEqual(_channelId)), + Arg.Is(AssertHelper.AssertPropertyEqual(message.RenderedTemplate)) + ); + } + + [Theory, BitAutoData] + public async Task HandleAsync_UnknownException_ReturnsTransientError(IntegrationMessage message) + { + var sutProvider = GetSutProvider(); + message.Configuration = new TeamsIntegrationConfigurationDetails(_channelId, _serviceUrl); + + sutProvider.GetDependency() + .SendMessageToChannelAsync(Arg.Any(), Arg.Any(), Arg.Any()) + .ThrowsAsync(new Exception("Unknown error")); + var result = await sutProvider.Sut.HandleAsync(message); + + Assert.False(result.Success); + Assert.Equal(IntegrationFailureCategory.TransientError, result.Category); + Assert.True(result.Retryable); + Assert.Equal(result.Message, message); + + await sutProvider.GetDependency().Received(1).SendMessageToChannelAsync( + Arg.Is(AssertHelper.AssertPropertyEqual(_serviceUrl)), + Arg.Is(AssertHelper.AssertPropertyEqual(_channelId)), + Arg.Is(AssertHelper.AssertPropertyEqual(message.RenderedTemplate)) + ); + } +} diff --git a/test/Core.Test/Dirt/Services/TeamsServiceTests.cs b/test/Core.Test/Dirt/Services/TeamsServiceTests.cs new file mode 100644 index 0000000000..61d20cc0af --- /dev/null +++ b/test/Core.Test/Dirt/Services/TeamsServiceTests.cs @@ -0,0 +1,289 @@ +#nullable enable + +using System.Net; +using System.Text.Json; +using System.Web; +using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Models.Data.Teams; +using Bit.Core.Dirt.Repositories; +using Bit.Core.Dirt.Services.Implementations; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Bit.Test.Common.MockedHttpClient; +using NSubstitute; +using Xunit; +using GlobalSettings = Bit.Core.Settings.GlobalSettings; + +namespace Bit.Core.Test.Dirt.Services; + +[SutProviderCustomize] +public class TeamsServiceTests +{ + private readonly MockedHttpMessageHandler _handler; + private readonly HttpClient _httpClient; + + public TeamsServiceTests() + { + _handler = new MockedHttpMessageHandler(); + _httpClient = _handler.ToHttpClient(); + } + + private SutProvider GetSutProvider() + { + var clientFactory = Substitute.For(); + clientFactory.CreateClient(TeamsService.HttpClientName).Returns(_httpClient); + + var globalSettings = Substitute.For(); + globalSettings.Teams.LoginBaseUrl.Returns("https://login.example.com"); + globalSettings.Teams.GraphBaseUrl.Returns("https://graph.example.com"); + + return new SutProvider() + .SetDependency(clientFactory) + .SetDependency(globalSettings) + .Create(); + } + + [Fact] + public void GetRedirectUrl_ReturnsCorrectUrl() + { + var sutProvider = GetSutProvider(); + var clientId = sutProvider.GetDependency().Teams.ClientId; + var scopes = sutProvider.GetDependency().Teams.Scopes; + var callbackUrl = "https://example.com/callback"; + var state = Guid.NewGuid().ToString(); + var result = sutProvider.Sut.GetRedirectUrl(callbackUrl, state); + + var uri = new Uri(result); + var query = HttpUtility.ParseQueryString(uri.Query); + + Assert.Equal(clientId, query["client_id"]); + Assert.Equal(scopes, query["scope"]); + Assert.Equal(callbackUrl, query["redirect_uri"]); + Assert.Equal(state, query["state"]); + Assert.Equal("login.example.com", uri.Host); + Assert.Equal("/common/oauth2/v2.0/authorize", uri.AbsolutePath); + } + + [Fact] + public async Task ObtainTokenViaOAuth_Success_ReturnsAccessToken() + { + var sutProvider = GetSutProvider(); + var jsonResponse = JsonSerializer.Serialize(new + { + access_token = "test-access-token" + }); + + _handler.When("https://login.example.com/common/oauth2/v2.0/token") + .RespondWith(HttpStatusCode.OK) + .WithContent(new StringContent(jsonResponse)); + + var result = await sutProvider.Sut.ObtainTokenViaOAuth("test-code", "https://example.com/callback"); + + Assert.Equal("test-access-token", result); + } + + [Theory] + [InlineData("test-code", "")] + [InlineData("", "https://example.com/callback")] + [InlineData("", "")] + public async Task ObtainTokenViaOAuth_CodeOrRedirectUrlIsEmpty_ReturnsEmptyString(string code, string redirectUrl) + { + var sutProvider = GetSutProvider(); + var result = await sutProvider.Sut.ObtainTokenViaOAuth(code, redirectUrl); + + Assert.Equal(string.Empty, result); + } + + [Fact] + public async Task ObtainTokenViaOAuth_HttpFailure_ReturnsEmptyString() + { + var sutProvider = GetSutProvider(); + _handler.When("https://login.example.com/common/oauth2/v2.0/token") + .RespondWith(HttpStatusCode.InternalServerError) + .WithContent(new StringContent(string.Empty)); + + var result = await sutProvider.Sut.ObtainTokenViaOAuth("test-code", "https://example.com/callback"); + + Assert.Equal(string.Empty, result); + } + + [Fact] + public async Task ObtainTokenViaOAuth_UnknownResponse_ReturnsEmptyString() + { + var sutProvider = GetSutProvider(); + + _handler.When("https://login.example.com/common/oauth2/v2.0/token") + .RespondWith(HttpStatusCode.OK) + .WithContent(new StringContent("Not an expected response")); + + var result = await sutProvider.Sut.ObtainTokenViaOAuth("test-code", "https://example.com/callback"); + + Assert.Equal(string.Empty, result); + } + + [Fact] + public async Task GetJoinedTeamsAsync_Success_ReturnsTeams() + { + var sutProvider = GetSutProvider(); + + var jsonResponse = JsonSerializer.Serialize(new + { + value = new[] + { + new { id = "team1", displayName = "Team One" }, + new { id = "team2", displayName = "Team Two" } + } + }); + + _handler.When("https://graph.example.com/me/joinedTeams") + .RespondWith(HttpStatusCode.OK) + .WithContent(new StringContent(jsonResponse)); + + var result = await sutProvider.Sut.GetJoinedTeamsAsync("fake-access-token"); + + Assert.Equal(2, result.Count); + Assert.Contains(result, t => t is { Id: "team1", DisplayName: "Team One" }); + Assert.Contains(result, t => t is { Id: "team2", DisplayName: "Team Two" }); + } + + [Fact] + public async Task GetJoinedTeamsAsync_ServerReturnsEmpty_ReturnsEmptyList() + { + var sutProvider = GetSutProvider(); + + var jsonResponse = JsonSerializer.Serialize(new { value = (object?)null }); + + _handler.When("https://graph.example.com/me/joinedTeams") + .RespondWith(HttpStatusCode.OK) + .WithContent(new StringContent(jsonResponse)); + + var result = await sutProvider.Sut.GetJoinedTeamsAsync("fake-access-token"); + + Assert.NotNull(result); + Assert.Empty(result); + } + + [Fact] + public async Task GetJoinedTeamsAsync_ServerErrorCode_ReturnsEmptyList() + { + var sutProvider = GetSutProvider(); + + _handler.When("https://graph.example.com/me/joinedTeams") + .RespondWith(HttpStatusCode.Unauthorized) + .WithContent(new StringContent("Unauthorized")); + + var result = await sutProvider.Sut.GetJoinedTeamsAsync("fake-access-token"); + + Assert.NotNull(result); + Assert.Empty(result); + } + + [Theory, BitAutoData] + public async Task HandleIncomingAppInstall_Success_UpdatesTeamsIntegration( + OrganizationIntegration integration) + { + var sutProvider = GetSutProvider(); + var tenantId = Guid.NewGuid().ToString(); + var teamId = Guid.NewGuid().ToString(); + var conversationId = Guid.NewGuid().ToString(); + var serviceUrl = new Uri("https://localhost"); + var initiatedConfiguration = new TeamsIntegration(TenantId: tenantId, Teams: + [ + new TeamInfo() { Id = teamId, DisplayName = "test team", TenantId = tenantId }, + new TeamInfo() { Id = Guid.NewGuid().ToString(), DisplayName = "other team", TenantId = tenantId }, + new TeamInfo() { Id = Guid.NewGuid().ToString(), DisplayName = "third team", TenantId = tenantId } + ]); + integration.Configuration = JsonSerializer.Serialize(initiatedConfiguration); + + sutProvider.GetDependency() + .GetByTeamsConfigurationTenantIdTeamId(tenantId, teamId) + .Returns(integration); + + OrganizationIntegration? capturedIntegration = null; + await sutProvider.GetDependency() + .UpsertAsync(Arg.Do(x => capturedIntegration = x)); + + await sutProvider.Sut.HandleIncomingAppInstallAsync( + conversationId: conversationId, + serviceUrl: serviceUrl, + teamId: teamId, + tenantId: tenantId + ); + + await sutProvider.GetDependency().Received(1).GetByTeamsConfigurationTenantIdTeamId(tenantId, teamId); + Assert.NotNull(capturedIntegration); + var configuration = JsonSerializer.Deserialize(capturedIntegration.Configuration ?? string.Empty); + Assert.NotNull(configuration); + Assert.NotNull(configuration.ServiceUrl); + Assert.Equal(serviceUrl, configuration.ServiceUrl); + Assert.Equal(conversationId, configuration.ChannelId); + } + + [Fact] + public async Task HandleIncomingAppInstall_NoIntegrationMatched_DoesNothing() + { + var sutProvider = GetSutProvider(); + await sutProvider.Sut.HandleIncomingAppInstallAsync( + conversationId: "conversationId", + serviceUrl: new Uri("https://localhost"), + teamId: "teamId", + tenantId: "tenantId" + ); + + await sutProvider.GetDependency().Received(1).GetByTeamsConfigurationTenantIdTeamId("tenantId", "teamId"); + await sutProvider.GetDependency().DidNotReceive().UpsertAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task HandleIncomingAppInstall_MatchedIntegrationAlreadySetup_DoesNothing( + OrganizationIntegration integration) + { + var sutProvider = GetSutProvider(); + var tenantId = Guid.NewGuid().ToString(); + var teamId = Guid.NewGuid().ToString(); + var initiatedConfiguration = new TeamsIntegration( + TenantId: tenantId, + Teams: [new TeamInfo() { Id = teamId, DisplayName = "test team", TenantId = tenantId }], + ChannelId: "ChannelId", + ServiceUrl: new Uri("https://localhost") + ); + integration.Configuration = JsonSerializer.Serialize(initiatedConfiguration); + + sutProvider.GetDependency() + .GetByTeamsConfigurationTenantIdTeamId(tenantId, teamId) + .Returns(integration); + + await sutProvider.Sut.HandleIncomingAppInstallAsync( + conversationId: "conversationId", + serviceUrl: new Uri("https://localhost"), + teamId: teamId, + tenantId: tenantId + ); + + await sutProvider.GetDependency().Received(1).GetByTeamsConfigurationTenantIdTeamId(tenantId, teamId); + await sutProvider.GetDependency().DidNotReceive().UpsertAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task HandleIncomingAppInstall_MatchedIntegrationWithMissingConfiguration_DoesNothing( + OrganizationIntegration integration) + { + var sutProvider = GetSutProvider(); + integration.Configuration = null; + + sutProvider.GetDependency() + .GetByTeamsConfigurationTenantIdTeamId("tenantId", "teamId") + .Returns(integration); + + await sutProvider.Sut.HandleIncomingAppInstallAsync( + conversationId: "conversationId", + serviceUrl: new Uri("https://localhost"), + teamId: "teamId", + tenantId: "tenantId" + ); + + await sutProvider.GetDependency().Received(1).GetByTeamsConfigurationTenantIdTeamId("tenantId", "teamId"); + await sutProvider.GetDependency().DidNotReceive().UpsertAsync(Arg.Any()); + } +} diff --git a/test/Core.Test/AdminConsole/Services/WebhookIntegrationHandlerTests.cs b/test/Core.Test/Dirt/Services/WebhookIntegrationHandlerTests.cs similarity index 97% rename from test/Core.Test/AdminConsole/Services/WebhookIntegrationHandlerTests.cs rename to test/Core.Test/Dirt/Services/WebhookIntegrationHandlerTests.cs index 53a3598d47..5d8bbfe439 100644 --- a/test/Core.Test/AdminConsole/Services/WebhookIntegrationHandlerTests.cs +++ b/test/Core.Test/Dirt/Services/WebhookIntegrationHandlerTests.cs @@ -1,7 +1,7 @@ using System.Net; using System.Net.Http.Headers; -using Bit.Core.AdminConsole.Models.Data.EventIntegrations; -using Bit.Core.Services; +using Bit.Core.Dirt.Models.Data.EventIntegrations; +using Bit.Core.Dirt.Services.Implementations; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Bit.Test.Common.Helpers; @@ -10,7 +10,7 @@ using Microsoft.Extensions.Time.Testing; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Services; +namespace Bit.Core.Test.Dirt.Services; [SutProviderCustomize] public class WebhookIntegrationHandlerTests @@ -51,7 +51,7 @@ public class WebhookIntegrationHandlerTests Assert.True(result.Success); Assert.Equal(result.Message, message); - Assert.Empty(result.FailureReason); + Assert.Null(result.FailureReason); sutProvider.GetDependency().Received(1).CreateClient( Arg.Is(AssertHelper.AssertPropertyEqual(WebhookIntegrationHandler.HttpClientName)) @@ -79,7 +79,7 @@ public class WebhookIntegrationHandlerTests Assert.True(result.Success); Assert.Equal(result.Message, message); - Assert.Empty(result.FailureReason); + Assert.Null(result.FailureReason); sutProvider.GetDependency().Received(1).CreateClient( Arg.Is(AssertHelper.AssertPropertyEqual(WebhookIntegrationHandler.HttpClientName)) diff --git a/test/Core.Test/KeyManagement/Authorization/KeyConnectorAuthorizationHandlerTests.cs b/test/Core.Test/KeyManagement/Authorization/KeyConnectorAuthorizationHandlerTests.cs new file mode 100644 index 0000000000..fb774a78ac --- /dev/null +++ b/test/Core.Test/KeyManagement/Authorization/KeyConnectorAuthorizationHandlerTests.cs @@ -0,0 +1,151 @@ +using System.Security.Claims; +using Bit.Core.Context; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.KeyManagement.Authorization; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.AspNetCore.Authorization; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.KeyManagement.Authorization; + +[SutProviderCustomize] +public class KeyConnectorAuthorizationHandlerTests +{ + [Theory, BitAutoData] + public async Task HandleRequirementAsync_UserCanUseKeyConnector_Success( + User user, + ClaimsPrincipal claimsPrincipal, + SutProvider sutProvider) + { + // Arrange + user.UsesKeyConnector = false; + sutProvider.GetDependency().Organizations + .Returns(new List()); + + var requirement = KeyConnectorOperations.Use; + var context = new AuthorizationHandlerContext([requirement], claimsPrincipal, user); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + Assert.True(context.HasSucceeded); + } + + [Theory, BitAutoData] + public async Task HandleRequirementAsync_UserAlreadyUsesKeyConnector_Fails( + User user, + ClaimsPrincipal claimsPrincipal, + SutProvider sutProvider) + { + // Arrange + user.UsesKeyConnector = true; + sutProvider.GetDependency().Organizations + .Returns(new List()); + + var requirement = KeyConnectorOperations.Use; + var context = new AuthorizationHandlerContext([requirement], claimsPrincipal, user); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + Assert.False(context.HasSucceeded); + } + + [Theory, BitAutoData] + public async Task HandleRequirementAsync_UserIsOwner_Fails( + User user, + Guid organizationId, + ClaimsPrincipal claimsPrincipal, + SutProvider sutProvider) + { + // Arrange + user.UsesKeyConnector = false; + var organizations = new List + { + new() { Id = organizationId, Type = OrganizationUserType.Owner } + }; + sutProvider.GetDependency().Organizations.Returns(organizations); + + var requirement = KeyConnectorOperations.Use; + var context = new AuthorizationHandlerContext([requirement], claimsPrincipal, user); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + Assert.False(context.HasSucceeded); + } + + [Theory, BitAutoData] + public async Task HandleRequirementAsync_UserIsAdmin_Fails( + User user, + Guid organizationId, + ClaimsPrincipal claimsPrincipal, + SutProvider sutProvider) + { + // Arrange + user.UsesKeyConnector = false; + var organizations = new List + { + new() { Id = organizationId, Type = OrganizationUserType.Admin } + }; + sutProvider.GetDependency().Organizations.Returns(organizations); + + var requirement = KeyConnectorOperations.Use; + var context = new AuthorizationHandlerContext([requirement], claimsPrincipal, user); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + Assert.False(context.HasSucceeded); + } + + [Theory, BitAutoData] + public async Task HandleRequirementAsync_UserIsRegularMember_Success( + User user, + Guid organizationId, + ClaimsPrincipal claimsPrincipal, + SutProvider sutProvider) + { + // Arrange + user.UsesKeyConnector = false; + var organizations = new List + { + new() { Id = organizationId, Type = OrganizationUserType.User } + }; + sutProvider.GetDependency().Organizations.Returns(organizations); + + var requirement = KeyConnectorOperations.Use; + var context = new AuthorizationHandlerContext([requirement], claimsPrincipal, user); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + Assert.True(context.HasSucceeded); + } + + [Theory, BitAutoData] + public async Task HandleRequirementAsync_UnsupportedRequirement_ThrowsArgumentException( + User user, + ClaimsPrincipal claimsPrincipal, + SutProvider sutProvider) + { + // Arrange + user.UsesKeyConnector = false; + sutProvider.GetDependency().Organizations + .Returns(new List()); + + var unsupportedRequirement = new KeyConnectorOperationsRequirement("UnsupportedOperation"); + var context = new AuthorizationHandlerContext([unsupportedRequirement], claimsPrincipal, user); + + // Act & Assert + await Assert.ThrowsAsync(() => sutProvider.Sut.HandleAsync(context)); + } +} diff --git a/test/Core.Test/KeyManagement/Commands/SetKeyConnectorKeyCommandTests.cs b/test/Core.Test/KeyManagement/Commands/SetKeyConnectorKeyCommandTests.cs new file mode 100644 index 0000000000..74f76f368b --- /dev/null +++ b/test/Core.Test/KeyManagement/Commands/SetKeyConnectorKeyCommandTests.cs @@ -0,0 +1,125 @@ +using System.Security.Claims; +using Bit.Core.Context; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.KeyManagement.Commands; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Http; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.KeyManagement.Commands; + +[SutProviderCustomize] +public class SetKeyConnectorKeyCommandTests +{ + + [Theory, BitAutoData] + public async Task SetKeyConnectorKeyForUserAsync_Success_SetsAccountKeys( + User user, + KeyConnectorKeysData data, + SutProvider sutProvider) + { + // Set up valid V2 encryption data + if (data.AccountKeys!.SignatureKeyPair != null) + { + data.AccountKeys.SignatureKeyPair.SignatureAlgorithm = "ed25519"; + } + + var expectedAccountKeysData = data.AccountKeys.ToAccountKeysData(); + + // Arrange + user.UsesKeyConnector = false; + var currentContext = sutProvider.GetDependency(); + var httpContext = Substitute.For(); + httpContext.User.Returns(new ClaimsPrincipal()); + currentContext.HttpContext.Returns(httpContext); + + sutProvider.GetDependency() + .AuthorizeAsync(Arg.Any(), user, Arg.Any>()) + .Returns(AuthorizationResult.Success()); + + var userRepository = sutProvider.GetDependency(); + var mockUpdateUserData = Substitute.For(); + userRepository.SetKeyConnectorUserKey(user.Id, data.KeyConnectorKeyWrappedUserKey!) + .Returns(mockUpdateUserData); + + // Act + await sutProvider.Sut.SetKeyConnectorKeyForUserAsync(user, data); + + // Assert + + userRepository + .Received(1) + .SetKeyConnectorUserKey(user.Id, data.KeyConnectorKeyWrappedUserKey); + + await userRepository + .Received(1) + .SetV2AccountCryptographicStateAsync( + user.Id, + Arg.Is(data => + data.PublicKeyEncryptionKeyPairData.PublicKey == expectedAccountKeysData.PublicKeyEncryptionKeyPairData.PublicKey && + data.PublicKeyEncryptionKeyPairData.WrappedPrivateKey == expectedAccountKeysData.PublicKeyEncryptionKeyPairData.WrappedPrivateKey && + data.PublicKeyEncryptionKeyPairData.SignedPublicKey == expectedAccountKeysData.PublicKeyEncryptionKeyPairData.SignedPublicKey && + data.SignatureKeyPairData!.SignatureAlgorithm == expectedAccountKeysData.SignatureKeyPairData!.SignatureAlgorithm && + data.SignatureKeyPairData.WrappedSigningKey == expectedAccountKeysData.SignatureKeyPairData.WrappedSigningKey && + data.SignatureKeyPairData.VerifyingKey == expectedAccountKeysData.SignatureKeyPairData.VerifyingKey && + data.SecurityStateData!.SecurityState == expectedAccountKeysData.SecurityStateData!.SecurityState && + data.SecurityStateData.SecurityVersion == expectedAccountKeysData.SecurityStateData.SecurityVersion), + Arg.Is>(actions => + actions.Count() == 1 && actions.First() == mockUpdateUserData)); + + await sutProvider.GetDependency() + .Received(1) + .LogUserEventAsync(user.Id, EventType.User_MigratedKeyToKeyConnector); + + await sutProvider.GetDependency() + .Received(1) + .AcceptOrgUserByOrgSsoIdAsync(data.OrgIdentifier, user, sutProvider.GetDependency()); + } + + [Theory, BitAutoData] + public async Task SetKeyConnectorKeyForUserAsync_UserCantUseKeyConnector_ThrowsException( + User user, + KeyConnectorKeysData data, + SutProvider sutProvider) + { + // Arrange + user.UsesKeyConnector = true; + var currentContext = sutProvider.GetDependency(); + var httpContext = Substitute.For(); + httpContext.User.Returns(new ClaimsPrincipal()); + currentContext.HttpContext.Returns(httpContext); + + sutProvider.GetDependency() + .AuthorizeAsync(Arg.Any(), user, Arg.Any>()) + .Returns(AuthorizationResult.Failed()); + + // Act & Assert + await Assert.ThrowsAsync( + () => sutProvider.Sut.SetKeyConnectorKeyForUserAsync(user, data)); + + sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SetKeyConnectorUserKey(Arg.Any(), Arg.Any()); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SetV2AccountCryptographicStateAsync(Arg.Any(), Arg.Any(), Arg.Any>()); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .LogUserEventAsync(Arg.Any(), Arg.Any()); + + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .AcceptOrgUserByOrgSsoIdAsync(Arg.Any(), Arg.Any(), Arg.Any()); + } +} diff --git a/test/Core.Test/KeyManagement/Kdf/ChangeKdfCommandTests.cs b/test/Core.Test/KeyManagement/Kdf/ChangeKdfCommandTests.cs index 02e04b9ce9..991935b928 100644 --- a/test/Core.Test/KeyManagement/Kdf/ChangeKdfCommandTests.cs +++ b/test/Core.Test/KeyManagement/Kdf/ChangeKdfCommandTests.cs @@ -1,9 +1,11 @@ #nullable enable using Bit.Core.Entities; +using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.KeyManagement.Kdf.Implementations; using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.Platform.Push; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Test.Common.AutoFixture; @@ -21,16 +23,12 @@ public class ChangeKdfCommandTests [BitAutoData] public async Task ChangeKdfAsync_ChangesKdfAsync(SutProvider sutProvider, User user) { - sutProvider.GetDependency().CheckPasswordAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(true)); - sutProvider.GetDependency().UpdatePasswordHash(Arg.Any(), Arg.Any()).Returns(Task.FromResult(IdentityResult.Success)); + sutProvider.GetDependency().CheckPasswordAsync(Arg.Any(), Arg.Any()) + .Returns(Task.FromResult(true)); + sutProvider.GetDependency().UpdatePasswordHash(Arg.Any(), Arg.Any()) + .Returns(Task.FromResult(IdentityResult.Success)); - var kdf = new KdfSettings - { - KdfType = Enums.KdfType.Argon2id, - Iterations = 4, - Memory = 512, - Parallelism = 4 - }; + var kdf = new KdfSettings { KdfType = Enums.KdfType.Argon2id, Iterations = 4, Memory = 512, Parallelism = 4 }; var authenticationData = new MasterPasswordAuthenticationData { Kdf = kdf, @@ -59,13 +57,7 @@ public class ChangeKdfCommandTests [BitAutoData] public async Task ChangeKdfAsync_UserIsNull_ThrowsArgumentNullException(SutProvider sutProvider) { - var kdf = new KdfSettings - { - KdfType = Enums.KdfType.Argon2id, - Iterations = 4, - Memory = 512, - Parallelism = 4 - }; + var kdf = new KdfSettings { KdfType = Enums.KdfType.Argon2id, Iterations = 4, Memory = 512, Parallelism = 4 }; var authenticationData = new MasterPasswordAuthenticationData { Kdf = kdf, @@ -85,17 +77,13 @@ public class ChangeKdfCommandTests [Theory] [BitAutoData] - public async Task ChangeKdfAsync_WrongPassword_ReturnsPasswordMismatch(SutProvider sutProvider, User user) + public async Task ChangeKdfAsync_WrongPassword_ReturnsPasswordMismatch(SutProvider sutProvider, + User user) { - sutProvider.GetDependency().CheckPasswordAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(false)); + sutProvider.GetDependency().CheckPasswordAsync(Arg.Any(), Arg.Any()) + .Returns(Task.FromResult(false)); - var kdf = new KdfSettings - { - KdfType = Enums.KdfType.Argon2id, - Iterations = 4, - Memory = 512, - Parallelism = 4 - }; + var kdf = new KdfSettings { KdfType = Enums.KdfType.Argon2id, Iterations = 4, Memory = 512, Parallelism = 4 }; var authenticationData = new MasterPasswordAuthenticationData { Kdf = kdf, @@ -116,7 +104,9 @@ public class ChangeKdfCommandTests [Theory] [BitAutoData] - public async Task ChangeKdfAsync_WithAuthenticationAndUnlockData_UpdatesUserCorrectly(SutProvider sutProvider, User user) + public async Task + ChangeKdfAsync_WithAuthenticationAndUnlockDataAndNoLogoutOnKdfChangeFeatureFlagOff_UpdatesUserCorrectlyAndLogsOut( + SutProvider sutProvider, User user) { var constantKdf = new KdfSettings { @@ -137,8 +127,12 @@ public class ChangeKdfCommandTests MasterKeyWrappedUserKey = "new-wrapped-key", Salt = user.GetMasterPasswordSalt() }; - sutProvider.GetDependency().CheckPasswordAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(true)); - sutProvider.GetDependency().UpdatePasswordHash(Arg.Any(), Arg.Any()).Returns(Task.FromResult(IdentityResult.Success)); + sutProvider.GetDependency().CheckPasswordAsync(Arg.Any(), Arg.Any()) + .Returns(Task.FromResult(true)); + sutProvider.GetDependency() + .UpdatePasswordHash(Arg.Any(), Arg.Any(), Arg.Any(), Arg.Any()) + .Returns(Task.FromResult(IdentityResult.Success)); + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(false); await sutProvider.Sut.ChangeKdfAsync(user, "masterPassword", authenticationData, unlockData); @@ -150,17 +144,79 @@ public class ChangeKdfCommandTests && u.KdfParallelism == constantKdf.Parallelism && u.Key == "new-wrapped-key" )); + await sutProvider.GetDependency().Received(1).UpdatePasswordHash(user, + authenticationData.MasterPasswordAuthenticationHash, validatePassword: true, refreshStamp: true); + await sutProvider.GetDependency().Received(1).PushLogOutAsync(user.Id); + sutProvider.GetDependency().Received(1).IsEnabled(FeatureFlagKeys.NoLogoutOnKdfChange); } [Theory] [BitAutoData] - public async Task ChangeKdfAsync_KdfNotEqualBetweenAuthAndUnlock_ThrowsBadRequestException(SutProvider sutProvider, User user) + public async Task + ChangeKdfAsync_WithAuthenticationAndUnlockDataAndNoLogoutOnKdfChangeFeatureFlagOn_UpdatesUserCorrectlyAndDoesNotLogOut( + SutProvider sutProvider, User user) { - sutProvider.GetDependency().CheckPasswordAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(true)); + var constantKdf = new KdfSettings + { + KdfType = Enums.KdfType.Argon2id, + Iterations = 5, + Memory = 1024, + Parallelism = 4 + }; + var authenticationData = new MasterPasswordAuthenticationData + { + Kdf = constantKdf, + MasterPasswordAuthenticationHash = "new-auth-hash", + Salt = user.GetMasterPasswordSalt() + }; + var unlockData = new MasterPasswordUnlockData + { + Kdf = constantKdf, + MasterKeyWrappedUserKey = "new-wrapped-key", + Salt = user.GetMasterPasswordSalt() + }; + sutProvider.GetDependency().CheckPasswordAsync(Arg.Any(), Arg.Any()) + .Returns(Task.FromResult(true)); + sutProvider.GetDependency() + .UpdatePasswordHash(Arg.Any(), Arg.Any(), Arg.Any(), Arg.Any()) + .Returns(Task.FromResult(IdentityResult.Success)); + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(true); + + await sutProvider.Sut.ChangeKdfAsync(user, "masterPassword", authenticationData, unlockData); + + await sutProvider.GetDependency().Received(1).ReplaceAsync(Arg.Is(u => + u.Id == user.Id + && u.Kdf == constantKdf.KdfType + && u.KdfIterations == constantKdf.Iterations + && u.KdfMemory == constantKdf.Memory + && u.KdfParallelism == constantKdf.Parallelism + && u.Key == "new-wrapped-key" + )); + await sutProvider.GetDependency().Received(1).UpdatePasswordHash(user, + authenticationData.MasterPasswordAuthenticationHash, validatePassword: true, refreshStamp: false); + await sutProvider.GetDependency().Received(1) + .PushLogOutAsync(user.Id, false, PushNotificationLogOutReason.KdfChange); + await sutProvider.GetDependency().Received(1).PushSyncSettingsAsync(user.Id); + sutProvider.GetDependency().Received(1).IsEnabled(FeatureFlagKeys.NoLogoutOnKdfChange); + } + + [Theory] + [BitAutoData] + public async Task ChangeKdfAsync_KdfNotEqualBetweenAuthAndUnlock_ThrowsBadRequestException( + SutProvider sutProvider, User user) + { + sutProvider.GetDependency().CheckPasswordAsync(Arg.Any(), Arg.Any()) + .Returns(Task.FromResult(true)); var authenticationData = new MasterPasswordAuthenticationData { - Kdf = new KdfSettings { KdfType = Enums.KdfType.Argon2id, Iterations = 4, Memory = 512, Parallelism = 4 }, + Kdf = new KdfSettings + { + KdfType = Enums.KdfType.Argon2id, + Iterations = 4, + Memory = 512, + Parallelism = 4 + }, MasterPasswordAuthenticationHash = "new-auth-hash", Salt = user.GetMasterPasswordSalt() }; @@ -176,9 +232,11 @@ public class ChangeKdfCommandTests [Theory] [BitAutoData] - public async Task ChangeKdfAsync_AuthDataSaltMismatch_Throws(SutProvider sutProvider, User user, KdfSettings kdf) + public async Task ChangeKdfAsync_AuthDataSaltMismatch_Throws(SutProvider sutProvider, User user, + KdfSettings kdf) { - sutProvider.GetDependency().CheckPasswordAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(true)); + sutProvider.GetDependency().CheckPasswordAsync(Arg.Any(), Arg.Any()) + .Returns(Task.FromResult(true)); var authenticationData = new MasterPasswordAuthenticationData { @@ -192,15 +250,17 @@ public class ChangeKdfCommandTests MasterKeyWrappedUserKey = "new-wrapped-key", Salt = user.GetMasterPasswordSalt() }; - await Assert.ThrowsAsync(async () => + await Assert.ThrowsAsync(async () => await sutProvider.Sut.ChangeKdfAsync(user, "masterPassword", authenticationData, unlockData)); } [Theory] [BitAutoData] - public async Task ChangeKdfAsync_UnlockDataSaltMismatch_Throws(SutProvider sutProvider, User user, KdfSettings kdf) + public async Task ChangeKdfAsync_UnlockDataSaltMismatch_Throws(SutProvider sutProvider, User user, + KdfSettings kdf) { - sutProvider.GetDependency().CheckPasswordAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(true)); + sutProvider.GetDependency().CheckPasswordAsync(Arg.Any(), Arg.Any()) + .Returns(Task.FromResult(true)); var authenticationData = new MasterPasswordAuthenticationData { @@ -214,25 +274,22 @@ public class ChangeKdfCommandTests MasterKeyWrappedUserKey = "new-wrapped-key", Salt = "different-salt" }; - await Assert.ThrowsAsync(async () => + await Assert.ThrowsAsync(async () => await sutProvider.Sut.ChangeKdfAsync(user, "masterPassword", authenticationData, unlockData)); } [Theory] [BitAutoData] - public async Task ChangeKdfAsync_UpdatePasswordHashFails_ReturnsFailure(SutProvider sutProvider, User user) + public async Task ChangeKdfAsync_UpdatePasswordHashFails_ReturnsFailure(SutProvider sutProvider, + User user) { - sutProvider.GetDependency().CheckPasswordAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(true)); + sutProvider.GetDependency().CheckPasswordAsync(Arg.Any(), Arg.Any()) + .Returns(Task.FromResult(true)); var failedResult = IdentityResult.Failed(new IdentityError { Code = "TestFail", Description = "Test fail" }); - sutProvider.GetDependency().UpdatePasswordHash(Arg.Any(), Arg.Any()).Returns(Task.FromResult(failedResult)); + sutProvider.GetDependency().UpdatePasswordHash(Arg.Any(), Arg.Any()) + .Returns(Task.FromResult(failedResult)); - var kdf = new KdfSettings - { - KdfType = Enums.KdfType.Argon2id, - Iterations = 4, - Memory = 512, - Parallelism = 4 - }; + var kdf = new KdfSettings { KdfType = Enums.KdfType.Argon2id, Iterations = 4, Memory = 512, Parallelism = 4 }; var authenticationData = new MasterPasswordAuthenticationData { Kdf = kdf, @@ -253,9 +310,11 @@ public class ChangeKdfCommandTests [Theory] [BitAutoData] - public async Task ChangeKdfAsync_InvalidKdfSettings_ThrowsBadRequestException(SutProvider sutProvider, User user) + public async Task ChangeKdfAsync_InvalidKdfSettings_ThrowsBadRequestException( + SutProvider sutProvider, User user) { - sutProvider.GetDependency().CheckPasswordAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(true)); + sutProvider.GetDependency().CheckPasswordAsync(Arg.Any(), Arg.Any()) + .Returns(Task.FromResult(true)); // Create invalid KDF settings (iterations too low for PBKDF2) var invalidKdf = new KdfSettings @@ -287,9 +346,11 @@ public class ChangeKdfCommandTests [Theory] [BitAutoData] - public async Task ChangeKdfAsync_InvalidArgon2Settings_ThrowsBadRequestException(SutProvider sutProvider, User user) + public async Task ChangeKdfAsync_InvalidArgon2Settings_ThrowsBadRequestException( + SutProvider sutProvider, User user) { - sutProvider.GetDependency().CheckPasswordAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(true)); + sutProvider.GetDependency().CheckPasswordAsync(Arg.Any(), Arg.Any()) + .Returns(Task.FromResult(true)); // Create invalid Argon2 KDF settings (memory too high) var invalidKdf = new KdfSettings @@ -318,5 +379,4 @@ public class ChangeKdfCommandTests Assert.Equal("KDF settings are invalid.", exception.Message); } - } diff --git a/test/Core.Test/KeyManagement/Queries/KeyConnectorConfirmationDetailsQueryTests.cs b/test/Core.Test/KeyManagement/Queries/KeyConnectorConfirmationDetailsQueryTests.cs new file mode 100644 index 0000000000..612d63f289 --- /dev/null +++ b/test/Core.Test/KeyManagement/Queries/KeyConnectorConfirmationDetailsQueryTests.cs @@ -0,0 +1,86 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Entities; +using Bit.Core.Exceptions; +using Bit.Core.KeyManagement.Queries; +using Bit.Core.Repositories; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.KeyManagement.Queries; + +[SutProviderCustomize] +public class KeyConnectorConfirmationDetailsQueryTests +{ + [Theory] + [BitAutoData] + public async Task Run_OrganizationNotFound_Throws(SutProvider sutProvider, + Guid userId, string orgSsoIdentifier) + { + await Assert.ThrowsAsync(() => sutProvider.Sut.Run(orgSsoIdentifier, userId)); + + await sutProvider.GetDependency() + .ReceivedWithAnyArgs(0) + .GetByOrganizationAsync(Arg.Any(), Arg.Any()); + } + + [Theory] + [BitAutoData] + public async Task Run_OrganizationNotKeyConnector_Throws( + SutProvider sutProvider, + Guid userId, string orgSsoIdentifier, Organization org) + { + org.Identifier = orgSsoIdentifier; + org.UseKeyConnector = false; + sutProvider.GetDependency().GetByIdentifierAsync(orgSsoIdentifier).Returns(org); + + await Assert.ThrowsAsync(() => sutProvider.Sut.Run(orgSsoIdentifier, userId)); + + await sutProvider.GetDependency() + .ReceivedWithAnyArgs(0) + .GetByOrganizationAsync(Arg.Any(), Arg.Any()); + } + + [Theory] + [BitAutoData] + public async Task Run_OrganizationUserNotFound_Throws(SutProvider sutProvider, + Guid userId, string orgSsoIdentifier + , Organization org) + { + org.Identifier = orgSsoIdentifier; + org.UseKeyConnector = true; + sutProvider.GetDependency().GetByIdentifierAsync(orgSsoIdentifier).Returns(org); + sutProvider.GetDependency() + .GetByOrganizationAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(null)); + + await Assert.ThrowsAsync(() => sutProvider.Sut.Run(orgSsoIdentifier, userId)); + + await sutProvider.GetDependency() + .Received(1) + .GetByOrganizationAsync(org.Id, userId); + } + + [Theory] + [BitAutoData] + public async Task Run_Success(SutProvider sutProvider, Guid userId, + string orgSsoIdentifier + , Organization org, OrganizationUser orgUser) + { + org.Identifier = orgSsoIdentifier; + org.UseKeyConnector = true; + orgUser.OrganizationId = org.Id; + orgUser.UserId = userId; + + sutProvider.GetDependency().GetByIdentifierAsync(orgSsoIdentifier).Returns(org); + sutProvider.GetDependency().GetByOrganizationAsync(org.Id, userId) + .Returns(orgUser); + + var result = await sutProvider.Sut.Run(orgSsoIdentifier, userId); + + Assert.Equal(org.Name, result.OrganizationName); + await sutProvider.GetDependency() + .Received(1) + .GetByOrganizationAsync(org.Id, userId); + } +} diff --git a/test/Core.Test/KeyManagement/Queries/UserAccountKeysQuery.cs b/test/Core.Test/KeyManagement/Queries/UserAccountKeysQuery.cs new file mode 100644 index 0000000000..f79217acba --- /dev/null +++ b/test/Core.Test/KeyManagement/Queries/UserAccountKeysQuery.cs @@ -0,0 +1,43 @@ +using Bit.Core.Entities; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Queries; +using Bit.Core.KeyManagement.Repositories; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.KeyManagement.Queries; + +[SutProviderCustomize] +public class UserAccountKeysQueryTests +{ + [Theory, BitAutoData] + public async Task V1User_Success(SutProvider sutProvider, User user) + { + var result = await sutProvider.Sut.Run(user); + Assert.Equal(user.GetPublicKeyEncryptionKeyPair().PublicKey, result.PublicKeyEncryptionKeyPairData.PublicKey); + Assert.Equal(user.GetPublicKeyEncryptionKeyPair().WrappedPrivateKey, result.PublicKeyEncryptionKeyPairData.WrappedPrivateKey); + } + + [Theory, BitAutoData] + public async Task V2User_Success(SutProvider sutProvider, User user) + { + user.SecurityState = "v2"; + user.SecurityVersion = 2; + var signatureKeyPairRepository = sutProvider.GetDependency(); + signatureKeyPairRepository.GetByUserIdAsync(user.Id).Returns(new SignatureKeyPairData(Core.KeyManagement.Enums.SignatureAlgorithm.Ed25519, "wrappedSigningKey", "verifyingKey")); + var result = await sutProvider.Sut.Run(user); + Assert.Equal(user.GetPublicKeyEncryptionKeyPair().PublicKey, result.PublicKeyEncryptionKeyPairData.PublicKey); + Assert.Equal(user.GetPublicKeyEncryptionKeyPair().WrappedPrivateKey, result.PublicKeyEncryptionKeyPairData.WrappedPrivateKey); + Assert.Equal(user.GetPublicKeyEncryptionKeyPair().SignedPublicKey, result.PublicKeyEncryptionKeyPairData.SignedPublicKey); + + Assert.NotNull(result.SignatureKeyPairData); + Assert.Equal("wrappedSigningKey", result.SignatureKeyPairData.WrappedSigningKey); + Assert.Equal("verifyingKey", result.SignatureKeyPairData.VerifyingKey); + + Assert.Equal(user.SecurityState, result.SecurityStateData.SecurityState); + Assert.Equal(user.GetSecurityVersion(), result.SecurityStateData.SecurityVersion); + } + +} diff --git a/test/Core.Test/KeyManagement/UserKey/RotateUserAccountKeysCommandTests.cs b/test/Core.Test/KeyManagement/UserKey/RotateUserAccountKeysCommandTests.cs index e677814fc1..f4d1fc5c94 100644 --- a/test/Core.Test/KeyManagement/UserKey/RotateUserAccountKeysCommandTests.cs +++ b/test/Core.Test/KeyManagement/UserKey/RotateUserAccountKeysCommandTests.cs @@ -1,11 +1,18 @@ using Bit.Core.Entities; +using Bit.Core.KeyManagement.Enums; using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Repositories; using Bit.Core.KeyManagement.UserKey.Implementations; using Bit.Core.Services; +using Bit.Core.Tools.Entities; +using Bit.Core.Tools.Repositories; +using Bit.Core.Vault.Entities; +using Bit.Core.Vault.Repositories; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.AspNetCore.Identity; using NSubstitute; +using NSubstitute.ReturnsExtensions; using Xunit; namespace Bit.Core.Test.KeyManagement.UserKey; @@ -14,7 +21,7 @@ namespace Bit.Core.Test.KeyManagement.UserKey; public class RotateUserAccountKeysCommandTests { [Theory, BitAutoData] - public async Task RejectsWrongOldMasterPassword(SutProvider sutProvider, User user, + public async Task RotateUserAccountKeysAsync_WrongOldMasterPassword_Rejects(SutProvider sutProvider, User user, RotateUserAccountKeysData model) { user.Email = model.MasterPasswordUnlockData.Email; @@ -25,41 +32,38 @@ public class RotateUserAccountKeysCommandTests Assert.NotEqual(IdentityResult.Success, result); } + [Theory, BitAutoData] - public async Task ThrowsWhenUserIsNull(SutProvider sutProvider, + public async Task RotateUserAccountKeysAsync_UserIsNull_Rejects(SutProvider sutProvider, RotateUserAccountKeysData model) { await Assert.ThrowsAsync(async () => await sutProvider.Sut.RotateUserAccountKeysAsync(null, model)); } + [Theory, BitAutoData] - public async Task RejectsEmailChange(SutProvider sutProvider, User user, + public async Task RotateUserAccountKeysAsync_EmailChange_Rejects(SutProvider sutProvider, User user, RotateUserAccountKeysData model) { - user.Kdf = Enums.KdfType.Argon2id; - user.KdfIterations = 3; - user.KdfMemory = 64; - user.KdfParallelism = 4; + SetTestKdfAndSaltForUserAndModel(user, model); + var signatureRepository = sutProvider.GetDependency(); + SetV1ExistingUser(user, signatureRepository); + SetV1ModelUser(model); model.MasterPasswordUnlockData.Email = user.Email + ".different-domain"; - model.MasterPasswordUnlockData.KdfType = Enums.KdfType.Argon2id; - model.MasterPasswordUnlockData.KdfIterations = 3; - model.MasterPasswordUnlockData.KdfMemory = 64; - model.MasterPasswordUnlockData.KdfParallelism = 4; sutProvider.GetDependency().CheckPasswordAsync(user, model.OldMasterKeyAuthenticationHash) .Returns(true); await Assert.ThrowsAsync(async () => await sutProvider.Sut.RotateUserAccountKeysAsync(user, model)); } [Theory, BitAutoData] - public async Task RejectsKdfChange(SutProvider sutProvider, User user, + public async Task RotateUserAccountKeysAsync_KdfChange_Rejects(SutProvider sutProvider, User user, RotateUserAccountKeysData model) { - user.Kdf = Enums.KdfType.Argon2id; - user.KdfIterations = 3; - user.KdfMemory = 64; - user.KdfParallelism = 4; + SetTestKdfAndSaltForUserAndModel(user, model); + var signatureRepository = sutProvider.GetDependency(); + SetV1ExistingUser(user, signatureRepository); + SetV1ModelUser(model); - model.MasterPasswordUnlockData.Email = user.Email; model.MasterPasswordUnlockData.KdfType = Enums.KdfType.PBKDF2_SHA256; model.MasterPasswordUnlockData.KdfIterations = 600000; model.MasterPasswordUnlockData.KdfMemory = null; @@ -71,22 +75,15 @@ public class RotateUserAccountKeysCommandTests [Theory, BitAutoData] - public async Task RejectsPublicKeyChange(SutProvider sutProvider, User user, + public async Task RotateUserAccountKeysAsync_PublicKeyChange_Rejects(SutProvider sutProvider, User user, RotateUserAccountKeysData model) { - user.PublicKey = "old-public"; - user.Kdf = Enums.KdfType.Argon2id; - user.KdfIterations = 3; - user.KdfMemory = 64; - user.KdfParallelism = 4; - - model.AccountPublicKey = "new-public"; - model.MasterPasswordUnlockData.Email = user.Email; - model.MasterPasswordUnlockData.KdfType = Enums.KdfType.Argon2id; - model.MasterPasswordUnlockData.KdfIterations = 3; - model.MasterPasswordUnlockData.KdfMemory = 64; - model.MasterPasswordUnlockData.KdfParallelism = 4; + SetTestKdfAndSaltForUserAndModel(user, model); + var signatureRepository = sutProvider.GetDependency(); + SetV1ExistingUser(user, signatureRepository); + SetV1ModelUser(model); + model.AccountKeys.PublicKeyEncryptionKeyPairData.PublicKey = "new-public"; sutProvider.GetDependency().CheckPasswordAsync(user, model.OldMasterKeyAuthenticationHash) .Returns(true); @@ -94,27 +91,350 @@ public class RotateUserAccountKeysCommandTests } [Theory, BitAutoData] - public async Task RotatesCorrectly(SutProvider sutProvider, User user, + public async Task RotateUserAccountKeysAsync_V1_Success(SutProvider sutProvider, User user, RotateUserAccountKeysData model) { - user.Kdf = Enums.KdfType.Argon2id; - user.KdfIterations = 3; - user.KdfMemory = 64; - user.KdfParallelism = 4; - - model.MasterPasswordUnlockData.Email = user.Email; - model.MasterPasswordUnlockData.KdfType = Enums.KdfType.Argon2id; - model.MasterPasswordUnlockData.KdfIterations = 3; - model.MasterPasswordUnlockData.KdfMemory = 64; - model.MasterPasswordUnlockData.KdfParallelism = 4; - - model.AccountPublicKey = user.PublicKey; + SetTestKdfAndSaltForUserAndModel(user, model); + var signatureRepository = sutProvider.GetDependency(); + SetV1ExistingUser(user, signatureRepository); + SetV1ModelUser(model); sutProvider.GetDependency().CheckPasswordAsync(user, model.OldMasterKeyAuthenticationHash) .Returns(true); var result = await sutProvider.Sut.RotateUserAccountKeysAsync(user, model); - Assert.Equal(IdentityResult.Success, result); } + + [Theory, BitAutoData] + public async Task RotateUserAccountKeysAsync_UpgradeV1ToV2_Success(SutProvider sutProvider, User user, + RotateUserAccountKeysData model) + { + SetTestKdfAndSaltForUserAndModel(user, model); + var signatureRepository = sutProvider.GetDependency(); + SetV1ExistingUser(user, signatureRepository); + SetV2ModelUser(model); + + sutProvider.GetDependency().CheckPasswordAsync(user, model.OldMasterKeyAuthenticationHash) + .Returns(true); + + var result = await sutProvider.Sut.RotateUserAccountKeysAsync(user, model); + Assert.Equal(IdentityResult.Success, result); + Assert.Equal(user.SecurityState, model.AccountKeys.SecurityStateData!.SecurityState); + } + + + [Theory, BitAutoData] + public async Task UpdateAccountKeysAsync_PublicKeyChange_Rejects(SutProvider sutProvider, User user, RotateUserAccountKeysData model) + { + SetTestKdfAndSaltForUserAndModel(user, model); + var signatureRepository = sutProvider.GetDependency(); + SetV1ExistingUser(user, signatureRepository); + SetV1ModelUser(model); + + model.AccountKeys.PublicKeyEncryptionKeyPairData.PublicKey = "new-public"; + var saveEncryptedDataActions = new List(); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAccountKeysAsync(model, user, saveEncryptedDataActions)); + } + + [Theory, BitAutoData] + public async Task UpdateAccountKeysAsync_V2User_PrivateKeyNotXChaCha20_Rejects(SutProvider sutProvider, User user, RotateUserAccountKeysData model) + { + SetTestKdfAndSaltForUserAndModel(user, model); + var signatureRepository = sutProvider.GetDependency(); + SetV2ExistingUser(user, signatureRepository); + SetV2ModelUser(model); + model.AccountKeys.PublicKeyEncryptionKeyPairData.WrappedPrivateKey = "2.xxx"; + + var saveEncryptedDataActions = new List(); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAccountKeysAsync(model, user, saveEncryptedDataActions)); + } + + [Theory, BitAutoData] + public async Task UpdateAccountKeysAsync_V1User_PrivateKeyNotAesCbcHmac_Rejects(SutProvider sutProvider, User user, RotateUserAccountKeysData model) + { + SetTestKdfAndSaltForUserAndModel(user, model); + var signatureRepository = sutProvider.GetDependency(); + SetV1ExistingUser(user, signatureRepository); + SetV1ModelUser(model); + model.AccountKeys.PublicKeyEncryptionKeyPairData.WrappedPrivateKey = "7.xxx"; + + var saveEncryptedDataActions = new List(); + var ex = await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAccountKeysAsync(model, user, saveEncryptedDataActions)); + Assert.Equal("The provided account private key was not wrapped with AES-256-CBC-HMAC", ex.Message); + } + + [Theory, BitAutoData] + public async Task UpdateAccountKeysAsync_V1_Success(SutProvider sutProvider, User user, RotateUserAccountKeysData model) + { + SetTestKdfAndSaltForUserAndModel(user, model); + var signatureRepository = sutProvider.GetDependency(); + SetV1ExistingUser(user, signatureRepository); + SetV1ModelUser(model); + + var saveEncryptedDataActions = new List(); + await sutProvider.Sut.UpdateAccountKeysAsync(model, user, saveEncryptedDataActions); + Assert.Empty(saveEncryptedDataActions); + } + + [Theory, BitAutoData] + public async Task UpdateAccountKeysAsync_V2_Success(SutProvider sutProvider, User user, RotateUserAccountKeysData model) + { + SetTestKdfAndSaltForUserAndModel(user, model); + var signatureRepository = sutProvider.GetDependency(); + SetV2ExistingUser(user, signatureRepository); + SetV2ModelUser(model); + + var saveEncryptedDataActions = new List(); + await sutProvider.Sut.UpdateAccountKeysAsync(model, user, saveEncryptedDataActions); + Assert.NotEmpty(saveEncryptedDataActions); + Assert.Equal(user.SecurityState, model.AccountKeys.SecurityStateData!.SecurityState); + } + + + + [Theory, BitAutoData] + public async Task UpdateAccountKeysAsync_V2User_VerifyingKeyMismatch_Rejects(SutProvider sutProvider, User user, RotateUserAccountKeysData model) + { + SetTestKdfAndSaltForUserAndModel(user, model); + var signatureRepository = sutProvider.GetDependency(); + SetV2ExistingUser(user, signatureRepository); + SetV2ModelUser(model); + model.AccountKeys.SignatureKeyPairData.VerifyingKey = "different-verifying-key"; + + var saveEncryptedDataActions = new List(); + var ex = await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAccountKeysAsync(model, user, saveEncryptedDataActions)); + Assert.Equal("The provided verifying key does not match the user's current verifying key.", ex.Message); + } + + [Theory, BitAutoData] + public async Task UpdateAccountKeysAsync_V2User_SignedPublicKeyNullOrEmpty_Rejects(SutProvider sutProvider, User user, RotateUserAccountKeysData model) + { + SetTestKdfAndSaltForUserAndModel(user, model); + var signatureRepository = sutProvider.GetDependency(); + SetV2ExistingUser(user, signatureRepository); + SetV2ModelUser(model); + model.AccountKeys.PublicKeyEncryptionKeyPairData.SignedPublicKey = null; + + var saveEncryptedDataActions = new List(); + var ex = await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAccountKeysAsync(model, user, saveEncryptedDataActions)); + Assert.Equal("No signed public key provided, but the user already has a signature key pair.", ex.Message); + } + + [Theory, BitAutoData] + public async Task UpdateAccountKeysAsync_V2User_WrappedSigningKeyNotXChaCha20_Rejects(SutProvider sutProvider, User user, RotateUserAccountKeysData model) + { + SetTestKdfAndSaltForUserAndModel(user, model); + var signatureRepository = sutProvider.GetDependency(); + SetV2ExistingUser(user, signatureRepository); + SetV2ModelUser(model); + model.AccountKeys.SignatureKeyPairData.WrappedSigningKey = "2.xxx"; + + var saveEncryptedDataActions = new List(); + var ex = await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAccountKeysAsync(model, user, saveEncryptedDataActions)); + Assert.Equal("The provided signing key data is not wrapped with XChaCha20-Poly1305.", ex.Message); + } + + [Theory, BitAutoData] + public async Task UpdateAccountKeys_UpgradeToV2_InvalidVerifyingKey_Rejects(SutProvider sutProvider, User user, RotateUserAccountKeysData model) + { + SetTestKdfAndSaltForUserAndModel(user, model); + var signatureRepository = sutProvider.GetDependency(); + SetV1ExistingUser(user, signatureRepository); + SetV2ModelUser(model); + model.AccountKeys.SignatureKeyPairData.VerifyingKey = ""; + + var saveEncryptedDataActions = new List(); + var ex = await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAccountKeysAsync(model, user, saveEncryptedDataActions)); + Assert.Equal("The provided signature key pair data does not contain a valid verifying key.", ex.Message); + } + + [Theory, BitAutoData] + public async Task UpdateAccountKeysAsync_UpgradeToV2_IncorrectlyWrappedPrivateKey_Rejects(SutProvider sutProvider, User user, RotateUserAccountKeysData model) + { + SetTestKdfAndSaltForUserAndModel(user, model); + var signatureRepository = sutProvider.GetDependency(); + SetV1ExistingUser(user, signatureRepository); + SetV2ModelUser(model); + model.AccountKeys.PublicKeyEncryptionKeyPairData.WrappedPrivateKey = "2.abc"; + + var saveEncryptedDataActions = new List(); + var ex = await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAccountKeysAsync(model, user, saveEncryptedDataActions)); + Assert.Equal("The provided private key encryption key is not wrapped with XChaCha20-Poly1305.", ex.Message); + } + + [Theory, BitAutoData] + public async Task UpdateAccountKeysAsync_UpgradeToV2_NoSignedPublicKey_Rejects(SutProvider sutProvider, User user, RotateUserAccountKeysData model) + { + SetTestKdfAndSaltForUserAndModel(user, model); + var signatureRepository = sutProvider.GetDependency(); + SetV1ExistingUser(user, signatureRepository); + SetV2ModelUser(model); + model.AccountKeys.PublicKeyEncryptionKeyPairData.SignedPublicKey = null; + + var saveEncryptedDataActions = new List(); + var ex = await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAccountKeysAsync(model, user, saveEncryptedDataActions)); + Assert.Equal("No signed public key provided, but the user already has a signature key pair.", ex.Message); + } + + [Theory, BitAutoData] + public async Task UpdateAccountKeysAsync_UpgradeToV2_NoSecurityState_Rejects(SutProvider sutProvider, User user, RotateUserAccountKeysData model) + { + SetTestKdfAndSaltForUserAndModel(user, model); + var signatureRepository = sutProvider.GetDependency(); + SetV1ExistingUser(user, signatureRepository); + SetV2ModelUser(model); + model.AccountKeys.SecurityStateData = null; + + var saveEncryptedDataActions = new List(); + var ex = await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAccountKeysAsync(model, user, saveEncryptedDataActions)); + Assert.Equal("No signed security state provider for V2 user", ex.Message); + } + + [Theory, BitAutoData] + public async Task UpdateAccountKeysAsync_RotateV2_NoSignatureKeyPair_Rejects(SutProvider sutProvider, User user, RotateUserAccountKeysData model) + { + SetTestKdfAndSaltForUserAndModel(user, model); + var signatureRepository = sutProvider.GetDependency(); + SetV2ExistingUser(user, signatureRepository); + SetV2ModelUser(model); + model.AccountKeys.SignatureKeyPairData = null; + + var saveEncryptedDataActions = new List(); + var ex = await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAccountKeysAsync(model, user, saveEncryptedDataActions)); + Assert.Equal("Signature key pair data is required for V2 encryption.", ex.Message); + } + + [Theory, BitAutoData] + public async Task UpdateAccountKeysAsync_GetEncryptionType_EmptyString_Rejects(SutProvider sutProvider, User user, RotateUserAccountKeysData model) + { + SetTestKdfAndSaltForUserAndModel(user, model); + var signatureRepository = sutProvider.GetDependency(); + SetV1ExistingUser(user, signatureRepository); + SetV1ModelUser(model); + model.AccountKeys.PublicKeyEncryptionKeyPairData.WrappedPrivateKey = ""; + + var saveEncryptedDataActions = new List(); + var ex = await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAccountKeysAsync(model, user, saveEncryptedDataActions)); + Assert.Equal("Invalid encryption type string.", ex.Message); + } + + [Theory, BitAutoData] + public async Task UpdateAccountKeysAsync_GetEncryptionType_InvalidString_Rejects(SutProvider sutProvider, User user, RotateUserAccountKeysData model) + { + SetTestKdfAndSaltForUserAndModel(user, model); + var signatureRepository = sutProvider.GetDependency(); + SetV1ExistingUser(user, signatureRepository); + SetV1ModelUser(model); + model.AccountKeys.PublicKeyEncryptionKeyPairData.WrappedPrivateKey = "9.xxx"; + + var saveEncryptedDataActions = new List(); + var ex = await Assert.ThrowsAsync(async () => await sutProvider.Sut.UpdateAccountKeysAsync(model, user, saveEncryptedDataActions)); + Assert.Equal("Invalid encryption type string.", ex.Message); + } + + [Theory, BitAutoData] + public async Task UpdateUserData_RevisionDateChanged_Success(SutProvider sutProvider, User user, RotateUserAccountKeysData model) + { + var oldDate = new DateTime(2017, 1, 1); + + var cipher = Substitute.For(); + cipher.RevisionDate = oldDate; + model.Ciphers = [cipher]; + + var folder = Substitute.For(); + folder.RevisionDate = oldDate; + model.Folders = [folder]; + + var send = Substitute.For(); + send.RevisionDate = oldDate; + model.Sends = [send]; + + var saveEncryptedDataActions = new List(); + + sutProvider.Sut.UpdateUserData(model, user, saveEncryptedDataActions); + foreach (var dataAction in saveEncryptedDataActions) + { + await dataAction.Invoke(); + } + + var updatedCiphers = sutProvider.GetDependency() + .ReceivedCalls() + .FirstOrDefault(call => call.GetMethodInfo().Name == "UpdateForKeyRotation")? + .GetArguments()[1] as IEnumerable; + foreach (var updatedCipher in updatedCiphers!) + { + var oldCipher = model.Ciphers.FirstOrDefault(c => c.Id == updatedCipher.Id); + Assert.NotEqual(oldDate, updatedCipher.RevisionDate); + } + + var updatedFolders = sutProvider.GetDependency() + .ReceivedCalls() + .FirstOrDefault(call => call.GetMethodInfo().Name == "UpdateForKeyRotation")? + .GetArguments()[1] as IEnumerable; + foreach (var updatedFolder in updatedFolders!) + { + var oldFolder = model.Folders.FirstOrDefault(f => f.Id == updatedFolder.Id); + Assert.NotEqual(oldDate, updatedFolder.RevisionDate); + } + + var updatedSends = sutProvider.GetDependency() + .ReceivedCalls() + .FirstOrDefault(call => call.GetMethodInfo().Name == "UpdateForKeyRotation")? + .GetArguments()[1] as IEnumerable; + foreach (var updatedSend in updatedSends!) + { + var oldSend = model.Sends.FirstOrDefault(s => s.Id == updatedSend.Id); + Assert.NotEqual(oldDate, updatedSend.RevisionDate); + } + } + + // Helper functions to set valid test parameters that match each other to the model and user. + private static void SetTestKdfAndSaltForUserAndModel(User user, RotateUserAccountKeysData model) + { + user.Kdf = Enums.KdfType.Argon2id; + user.KdfIterations = 3; + user.KdfMemory = 64; + user.KdfParallelism = 4; + model.MasterPasswordUnlockData.KdfType = Enums.KdfType.Argon2id; + model.MasterPasswordUnlockData.KdfIterations = 3; + model.MasterPasswordUnlockData.KdfMemory = 64; + model.MasterPasswordUnlockData.KdfParallelism = 4; + // The email is the salt for the KDF and is validated currently. + user.Email = model.MasterPasswordUnlockData.Email; + } + + private static void SetV1ExistingUser(User user, IUserSignatureKeyPairRepository userSignatureKeyPairRepository) + { + user.PrivateKey = "2.abc"; + user.PublicKey = "public"; + user.SignedPublicKey = null; + userSignatureKeyPairRepository.GetByUserIdAsync(user.Id).ReturnsNull(); + } + + private static void SetV2ExistingUser(User user, IUserSignatureKeyPairRepository userSignatureKeyPairRepository) + { + user.PrivateKey = "7.abc"; + user.PublicKey = "public"; + user.SignedPublicKey = "signed-public"; + userSignatureKeyPairRepository.GetByUserIdAsync(user.Id).Returns(new SignatureKeyPairData(SignatureAlgorithm.Ed25519, "7.abc", "verifying-key")); + } + + private static void SetV1ModelUser(RotateUserAccountKeysData model) + { + model.AccountKeys.PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData("2.abc", "public", null); + model.AccountKeys.SignatureKeyPairData = null; + model.AccountKeys.SecurityStateData = null; + } + + private static void SetV2ModelUser(RotateUserAccountKeysData model) + { + model.AccountKeys.PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData("7.abc", "public", "signed-public"); + model.AccountKeys.SignatureKeyPairData = new SignatureKeyPairData(SignatureAlgorithm.Ed25519, "7.abc", "verifying-key"); + model.AccountKeys.SecurityStateData = new SecurityStateData + { + SecurityState = "abc", + SecurityVersion = 2, + }; + } } diff --git a/test/Core.Test/Models/Business/BillingCustomerDiscountTests.cs b/test/Core.Test/Models/Business/BillingCustomerDiscountTests.cs new file mode 100644 index 0000000000..6dbe829da5 --- /dev/null +++ b/test/Core.Test/Models/Business/BillingCustomerDiscountTests.cs @@ -0,0 +1,497 @@ +using Bit.Core.Models.Business; +using Bit.Test.Common.AutoFixture.Attributes; +using Stripe; +using Xunit; + +namespace Bit.Core.Test.Models.Business; + +public class BillingCustomerDiscountTests +{ + [Theory] + [BitAutoData] + public void Constructor_PercentageDiscount_SetsIdActivePercentOffAndAppliesTo(string couponId) + { + // Arrange + var discount = new Discount + { + Coupon = new Coupon + { + Id = couponId, + PercentOff = 25.5m, + AmountOff = null, + AppliesTo = new CouponAppliesTo + { + Products = new List { "product1", "product2" } + } + }, + End = null // Active discount + }; + + // Act + var result = new SubscriptionInfo.BillingCustomerDiscount(discount); + + // Assert + Assert.Equal(couponId, result.Id); + Assert.True(result.Active); + Assert.Equal(25.5m, result.PercentOff); + Assert.Null(result.AmountOff); + Assert.NotNull(result.AppliesTo); + Assert.Equal(2, result.AppliesTo.Count); + Assert.Contains("product1", result.AppliesTo); + Assert.Contains("product2", result.AppliesTo); + } + + [Theory] + [BitAutoData] + public void Constructor_AmountDiscount_ConvertsFromCentsToDollars(string couponId) + { + // Arrange - Stripe sends 1400 cents for $14.00 + var discount = new Discount + { + Coupon = new Coupon + { + Id = couponId, + PercentOff = null, + AmountOff = 1400, // 1400 cents + AppliesTo = new CouponAppliesTo + { + Products = new List() + } + }, + End = null + }; + + // Act + var result = new SubscriptionInfo.BillingCustomerDiscount(discount); + + // Assert + Assert.Equal(couponId, result.Id); + Assert.True(result.Active); + Assert.Null(result.PercentOff); + Assert.Equal(14.00m, result.AmountOff); // Converted to dollars + Assert.NotNull(result.AppliesTo); + Assert.Empty(result.AppliesTo); + } + + [Theory] + [BitAutoData] + public void Constructor_InactiveDiscount_SetsActiveToFalse(string couponId) + { + // Arrange + var discount = new Discount + { + Coupon = new Coupon + { + Id = couponId, + PercentOff = 15m + }, + End = DateTime.UtcNow.AddDays(-1) // Expired discount + }; + + // Act + var result = new SubscriptionInfo.BillingCustomerDiscount(discount); + + // Assert + Assert.Equal(couponId, result.Id); + Assert.False(result.Active); + Assert.Equal(15m, result.PercentOff); + } + + [Fact] + public void Constructor_NullCoupon_SetsDiscountPropertiesToNull() + { + // Arrange + var discount = new Discount + { + Coupon = null, + End = null + }; + + // Act + var result = new SubscriptionInfo.BillingCustomerDiscount(discount); + + // Assert + Assert.Null(result.Id); + Assert.True(result.Active); + Assert.Null(result.PercentOff); + Assert.Null(result.AmountOff); + Assert.Null(result.AppliesTo); + } + + [Theory] + [BitAutoData] + public void Constructor_NullAmountOff_SetsAmountOffToNull(string couponId) + { + // Arrange + var discount = new Discount + { + Coupon = new Coupon + { + Id = couponId, + PercentOff = 10m, + AmountOff = null + }, + End = null + }; + + // Act + var result = new SubscriptionInfo.BillingCustomerDiscount(discount); + + // Assert + Assert.Null(result.AmountOff); + } + + [Theory] + [BitAutoData] + public void Constructor_ZeroAmountOff_ConvertsCorrectly(string couponId) + { + // Arrange + var discount = new Discount + { + Coupon = new Coupon + { + Id = couponId, + AmountOff = 0 + }, + End = null + }; + + // Act + var result = new SubscriptionInfo.BillingCustomerDiscount(discount); + + // Assert + Assert.Equal(0m, result.AmountOff); + } + + [Theory] + [BitAutoData] + public void Constructor_LargeAmountOff_ConvertsCorrectly(string couponId) + { + // Arrange - $100.00 discount + var discount = new Discount + { + Coupon = new Coupon + { + Id = couponId, + AmountOff = 10000 // 10000 cents = $100.00 + }, + End = null + }; + + // Act + var result = new SubscriptionInfo.BillingCustomerDiscount(discount); + + // Assert + Assert.Equal(100.00m, result.AmountOff); + } + + [Theory] + [BitAutoData] + public void Constructor_SmallAmountOff_ConvertsCorrectly(string couponId) + { + // Arrange - $0.50 discount + var discount = new Discount + { + Coupon = new Coupon + { + Id = couponId, + AmountOff = 50 // 50 cents = $0.50 + }, + End = null + }; + + // Act + var result = new SubscriptionInfo.BillingCustomerDiscount(discount); + + // Assert + Assert.Equal(0.50m, result.AmountOff); + } + + [Theory] + [BitAutoData] + public void Constructor_BothDiscountTypes_SetsPercentOffAndAmountOff(string couponId) + { + // Arrange - Coupon with both percentage and amount (edge case) + var discount = new Discount + { + Coupon = new Coupon + { + Id = couponId, + PercentOff = 20m, + AmountOff = 500 // $5.00 + }, + End = null + }; + + // Act + var result = new SubscriptionInfo.BillingCustomerDiscount(discount); + + // Assert + Assert.Equal(20m, result.PercentOff); + Assert.Equal(5.00m, result.AmountOff); + } + + [Theory] + [BitAutoData] + public void Constructor_WithNullAppliesTo_SetsAppliesToNull(string couponId) + { + // Arrange + var discount = new Discount + { + Coupon = new Coupon + { + Id = couponId, + PercentOff = 10m, + AppliesTo = null + }, + End = null + }; + + // Act + var result = new SubscriptionInfo.BillingCustomerDiscount(discount); + + // Assert + Assert.Null(result.AppliesTo); + } + + [Theory] + [BitAutoData] + public void Constructor_WithNullProductsList_SetsAppliesToNull(string couponId) + { + // Arrange + var discount = new Discount + { + Coupon = new Coupon + { + Id = couponId, + PercentOff = 10m, + AppliesTo = new CouponAppliesTo + { + Products = null + } + }, + End = null + }; + + // Act + var result = new SubscriptionInfo.BillingCustomerDiscount(discount); + + // Assert + Assert.Null(result.AppliesTo); + } + + [Theory] + [BitAutoData] + public void Constructor_WithDecimalAmountOff_RoundsCorrectly(string couponId) + { + // Arrange - 1425 cents = $14.25 + var discount = new Discount + { + Coupon = new Coupon + { + Id = couponId, + AmountOff = 1425 + }, + End = null + }; + + // Act + var result = new SubscriptionInfo.BillingCustomerDiscount(discount); + + // Assert + Assert.Equal(14.25m, result.AmountOff); + } + + [Fact] + public void Constructor_DefaultConstructor_InitializesAllPropertiesToNullOrFalse() + { + // Act + var result = new SubscriptionInfo.BillingCustomerDiscount(); + + // Assert + Assert.Null(result.Id); + Assert.False(result.Active); + Assert.Null(result.PercentOff); + Assert.Null(result.AmountOff); + Assert.Null(result.AppliesTo); + } + + [Theory] + [BitAutoData] + public void Constructor_WithFutureEndDate_SetsActiveToFalse(string couponId) + { + // Arrange - Discount expires in the future + var discount = new Discount + { + Coupon = new Coupon + { + Id = couponId, + PercentOff = 20m + }, + End = DateTime.UtcNow.AddDays(30) // Expires in 30 days + }; + + // Act + var result = new SubscriptionInfo.BillingCustomerDiscount(discount); + + // Assert + Assert.False(result.Active); // Should be inactive because End is not null + } + + [Theory] + [BitAutoData] + public void Constructor_WithPastEndDate_SetsActiveToFalse(string couponId) + { + // Arrange - Discount already expired + var discount = new Discount + { + Coupon = new Coupon + { + Id = couponId, + PercentOff = 20m + }, + End = DateTime.UtcNow.AddDays(-30) // Expired 30 days ago + }; + + // Act + var result = new SubscriptionInfo.BillingCustomerDiscount(discount); + + // Assert + Assert.False(result.Active); // Should be inactive because End is not null + } + + [Fact] + public void Constructor_WithNullCouponId_SetsIdToNull() + { + // Arrange + var discount = new Discount + { + Coupon = new Coupon + { + Id = null, + PercentOff = 20m + }, + End = null + }; + + // Act + var result = new SubscriptionInfo.BillingCustomerDiscount(discount); + + // Assert + Assert.Null(result.Id); + Assert.True(result.Active); + Assert.Equal(20m, result.PercentOff); + } + + [Theory] + [BitAutoData] + public void Constructor_WithNullPercentOff_SetsPercentOffToNull(string couponId) + { + // Arrange + var discount = new Discount + { + Coupon = new Coupon + { + Id = couponId, + PercentOff = null, + AmountOff = 1000 + }, + End = null + }; + + // Act + var result = new SubscriptionInfo.BillingCustomerDiscount(discount); + + // Assert + Assert.Null(result.PercentOff); + Assert.Equal(10.00m, result.AmountOff); + } + + [Fact] + public void Constructor_WithCompleteStripeDiscount_MapsAllProperties() + { + // Arrange - Comprehensive test with all Stripe Discount properties set + var discount = new Discount + { + Coupon = new Coupon + { + Id = "premium_discount_2024", + PercentOff = 25m, + AmountOff = 1500, // $15.00 + AppliesTo = new CouponAppliesTo + { + Products = new List { "prod_premium", "prod_family", "prod_teams" } + } + }, + End = null // Active + }; + + // Act + var result = new SubscriptionInfo.BillingCustomerDiscount(discount); + + // Assert - Verify all properties mapped correctly + Assert.Equal("premium_discount_2024", result.Id); + Assert.True(result.Active); + Assert.Equal(25m, result.PercentOff); + Assert.Equal(15.00m, result.AmountOff); + Assert.NotNull(result.AppliesTo); + Assert.Equal(3, result.AppliesTo.Count); + Assert.Contains("prod_premium", result.AppliesTo); + Assert.Contains("prod_family", result.AppliesTo); + Assert.Contains("prod_teams", result.AppliesTo); + } + + [Fact] + public void Constructor_WithMinimalStripeDiscount_HandlesNullsGracefully() + { + // Arrange - Minimal Stripe Discount with most properties null + var discount = new Discount + { + Coupon = new Coupon + { + Id = null, + PercentOff = null, + AmountOff = null, + AppliesTo = null + }, + End = DateTime.UtcNow.AddDays(10) // Has end date + }; + + // Act + var result = new SubscriptionInfo.BillingCustomerDiscount(discount); + + // Assert - Should handle all nulls gracefully + Assert.Null(result.Id); + Assert.False(result.Active); + Assert.Null(result.PercentOff); + Assert.Null(result.AmountOff); + Assert.Null(result.AppliesTo); + } + + [Theory] + [BitAutoData] + public void Constructor_WithEmptyProductsList_PreservesEmptyList(string couponId) + { + // Arrange + var discount = new Discount + { + Coupon = new Coupon + { + Id = couponId, + PercentOff = 10m, + AppliesTo = new CouponAppliesTo + { + Products = new List() // Empty but not null + } + }, + End = null + }; + + // Act + var result = new SubscriptionInfo.BillingCustomerDiscount(discount); + + // Assert + Assert.NotNull(result.AppliesTo); + Assert.Empty(result.AppliesTo); + } +} diff --git a/test/Core.Test/Models/Business/CompleteSubscriptionUpdateTests.cs b/test/Core.Test/Models/Business/CompleteSubscriptionUpdateTests.cs index dee805033a..39374755eb 100644 --- a/test/Core.Test/Models/Business/CompleteSubscriptionUpdateTests.cs +++ b/test/Core.Test/Models/Business/CompleteSubscriptionUpdateTests.cs @@ -2,7 +2,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.Business; using Bit.Core.Test.AutoFixture.OrganizationFixtures; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture.Attributes; using Stripe; using Xunit; @@ -17,7 +17,7 @@ public class CompleteSubscriptionUpdateTests public void UpgradeItemOptions_TeamsStarterToTeams_ReturnsCorrectOptions( Organization organization) { - var teamsStarterPlan = StaticStore.GetPlan(PlanType.TeamsStarter); + var teamsStarterPlan = MockPlans.Get(PlanType.TeamsStarter); var subscription = new Subscription { @@ -35,7 +35,7 @@ public class CompleteSubscriptionUpdateTests } }; - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); var updatedSubscriptionData = new SubscriptionData { @@ -66,7 +66,7 @@ public class CompleteSubscriptionUpdateTests // 5 purchased, 1 base organization.MaxStorageGb = 6; - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); var subscription = new Subscription { @@ -102,7 +102,7 @@ public class CompleteSubscriptionUpdateTests } }; - var enterpriseMonthlyPlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); + var enterpriseMonthlyPlan = MockPlans.Get(PlanType.EnterpriseMonthly); var updatedSubscriptionData = new SubscriptionData { @@ -173,7 +173,7 @@ public class CompleteSubscriptionUpdateTests // 5 purchased, 1 base organization.MaxStorageGb = 6; - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); var subscription = new Subscription { @@ -209,7 +209,7 @@ public class CompleteSubscriptionUpdateTests } }; - var enterpriseMonthlyPlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); + var enterpriseMonthlyPlan = MockPlans.Get(PlanType.EnterpriseMonthly); var updatedSubscriptionData = new SubscriptionData { @@ -277,8 +277,8 @@ public class CompleteSubscriptionUpdateTests public void RevertItemOptions_TeamsStarterToTeams_ReturnsCorrectOptions( Organization organization) { - var teamsStarterPlan = StaticStore.GetPlan(PlanType.TeamsStarter); - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); + var teamsStarterPlan = MockPlans.Get(PlanType.TeamsStarter); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); var subscription = new Subscription { @@ -325,8 +325,8 @@ public class CompleteSubscriptionUpdateTests // 5 purchased, 1 base organization.MaxStorageGb = 6; - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); - var enterpriseMonthlyPlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); + var enterpriseMonthlyPlan = MockPlans.Get(PlanType.EnterpriseMonthly); var subscription = new Subscription { @@ -431,8 +431,8 @@ public class CompleteSubscriptionUpdateTests // 5 purchased, 1 base organization.MaxStorageGb = 6; - var teamsMonthlyPlan = StaticStore.GetPlan(PlanType.TeamsMonthly); - var enterpriseMonthlyPlan = StaticStore.GetPlan(PlanType.EnterpriseMonthly); + var teamsMonthlyPlan = MockPlans.Get(PlanType.TeamsMonthly); + var enterpriseMonthlyPlan = MockPlans.Get(PlanType.EnterpriseMonthly); var subscription = new Subscription { diff --git a/test/Core.Test/Models/Business/SeatSubscriptionUpdateTests.cs b/test/Core.Test/Models/Business/SeatSubscriptionUpdateTests.cs index b6e9f63640..d96f9fea95 100644 --- a/test/Core.Test/Models/Business/SeatSubscriptionUpdateTests.cs +++ b/test/Core.Test/Models/Business/SeatSubscriptionUpdateTests.cs @@ -1,7 +1,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Enums; using Bit.Core.Models.Business; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture.Attributes; using Stripe; using Xunit; @@ -27,7 +27,7 @@ public class SeatSubscriptionUpdateTests public void UpgradeItemsOptions_ReturnsCorrectOptions(PlanType planType, Organization organization) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); organization.PlanType = planType; var subscription = new Subscription { @@ -69,7 +69,7 @@ public class SeatSubscriptionUpdateTests [BitAutoData(PlanType.TeamsAnnually)] public void RevertItemsOptions_ReturnsCorrectOptions(PlanType planType, Organization organization) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); organization.PlanType = planType; var subscription = new Subscription { diff --git a/test/Core.Test/Models/Business/SecretsManagerSubscriptionUpdateTests.cs b/test/Core.Test/Models/Business/SecretsManagerSubscriptionUpdateTests.cs index 6a411363a0..1f75b6a23a 100644 --- a/test/Core.Test/Models/Business/SecretsManagerSubscriptionUpdateTests.cs +++ b/test/Core.Test/Models/Business/SecretsManagerSubscriptionUpdateTests.cs @@ -4,7 +4,7 @@ using Bit.Core.Exceptions; using Bit.Core.Models.Business; using Bit.Core.Models.StaticStore; using Bit.Core.Test.AutoFixture.OrganizationFixtures; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; @@ -16,13 +16,13 @@ public class SecretsManagerSubscriptionUpdateTests private static TheoryData ToPlanTheory(List types) { var theoryData = new TheoryData(); - var plans = types.Select(StaticStore.GetPlan).ToArray(); + var plans = types.Select(MockPlans.Get).ToArray(); theoryData.AddRange(plans); return theoryData; } public static TheoryData NonSmPlans => - ToPlanTheory([PlanType.Custom, PlanType.FamiliesAnnually, PlanType.FamiliesAnnually2019]); + ToPlanTheory([PlanType.Custom, PlanType.FamiliesAnnually, PlanType.FamiliesAnnually2025, PlanType.FamiliesAnnually2019]); public static TheoryData SmPlans => ToPlanTheory([ PlanType.EnterpriseAnnually2019, diff --git a/test/Core.Test/Models/Business/ServiceAccountSubscriptionUpdateTests.cs b/test/Core.Test/Models/Business/ServiceAccountSubscriptionUpdateTests.cs index 3663277933..a1e9669c87 100644 --- a/test/Core.Test/Models/Business/ServiceAccountSubscriptionUpdateTests.cs +++ b/test/Core.Test/Models/Business/ServiceAccountSubscriptionUpdateTests.cs @@ -1,7 +1,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Enums; using Bit.Core.Models.Business; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture.Attributes; using Stripe; using Xunit; @@ -27,7 +27,7 @@ public class ServiceAccountSubscriptionUpdateTests public void UpgradeItemsOptions_ReturnsCorrectOptions(PlanType planType, Organization organization) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); organization.PlanType = planType; var subscription = new Subscription { @@ -69,7 +69,7 @@ public class ServiceAccountSubscriptionUpdateTests [BitAutoData(PlanType.TeamsAnnually)] public void RevertItemsOptions_ReturnsCorrectOptions(PlanType planType, Organization organization) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); organization.PlanType = planType; var quantity = 5; var subscription = new Subscription diff --git a/test/Core.Test/Models/Business/SmSeatSubscriptionUpdateTests.cs b/test/Core.Test/Models/Business/SmSeatSubscriptionUpdateTests.cs index ee9dc615b6..d9fcaf991e 100644 --- a/test/Core.Test/Models/Business/SmSeatSubscriptionUpdateTests.cs +++ b/test/Core.Test/Models/Business/SmSeatSubscriptionUpdateTests.cs @@ -1,7 +1,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Enums; using Bit.Core.Models.Business; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture.Attributes; using Stripe; using Xunit; @@ -27,7 +27,7 @@ public class SmSeatSubscriptionUpdateTests public void UpgradeItemsOptions_ReturnsCorrectOptions(PlanType planType, Organization organization) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); organization.PlanType = planType; var quantity = 3; var subscription = new Subscription @@ -70,7 +70,7 @@ public class SmSeatSubscriptionUpdateTests [BitAutoData(PlanType.TeamsAnnually)] public void RevertItemsOptions_ReturnsCorrectOptions(PlanType planType, Organization organization) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); organization.PlanType = planType; var quantity = 5; var subscription = new Subscription diff --git a/test/Core.Test/Models/Business/StorageSubscriptionUpdateTests.cs b/test/Core.Test/Models/Business/StorageSubscriptionUpdateTests.cs index 79b29fcd0c..21326c5324 100644 --- a/test/Core.Test/Models/Business/StorageSubscriptionUpdateTests.cs +++ b/test/Core.Test/Models/Business/StorageSubscriptionUpdateTests.cs @@ -1,6 +1,6 @@ using Bit.Core.Billing.Enums; using Bit.Core.Models.Business; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture.Attributes; using Stripe; using Xunit; @@ -26,7 +26,7 @@ public class StorageSubscriptionUpdateTests public void UpgradeItemsOptions_ReturnsCorrectOptions(PlanType planType) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); var subscription = new Subscription { Items = new StripeList @@ -77,7 +77,7 @@ public class StorageSubscriptionUpdateTests [BitAutoData(PlanType.TeamsStarter)] public void RevertItemsOptions_ReturnsCorrectOptions(PlanType planType) { - var plan = StaticStore.GetPlan(planType); + var plan = MockPlans.Get(planType); var subscription = new Subscription { Items = new StripeList diff --git a/test/Core.Test/Models/Business/SubscriptionInfoTests.cs b/test/Core.Test/Models/Business/SubscriptionInfoTests.cs new file mode 100644 index 0000000000..ef6a61ad5d --- /dev/null +++ b/test/Core.Test/Models/Business/SubscriptionInfoTests.cs @@ -0,0 +1,125 @@ +using Bit.Core.Models.Business; +using Stripe; +using Xunit; + +namespace Bit.Core.Test.Models.Business; + +public class SubscriptionInfoTests +{ + [Fact] + public void BillingSubscriptionItem_NullPlan_HandlesGracefully() + { + // Arrange - SubscriptionItem with null Plan + var subscriptionItem = new SubscriptionItem + { + Plan = null, + Quantity = 1 + }; + + // Act + var result = new SubscriptionInfo.BillingSubscription.BillingSubscriptionItem(subscriptionItem); + + // Assert - Should handle null Plan gracefully + Assert.Null(result.ProductId); + Assert.Null(result.Name); + Assert.Equal(0m, result.Amount); // Defaults to 0 when Plan is null + Assert.Null(result.Interval); + Assert.Equal(1, result.Quantity); + Assert.False(result.SponsoredSubscriptionItem); + Assert.False(result.AddonSubscriptionItem); + } + + [Fact] + public void BillingSubscriptionItem_NullAmount_SetsToZero() + { + // Arrange - SubscriptionItem with Plan but null Amount + var subscriptionItem = new SubscriptionItem + { + Plan = new Plan + { + ProductId = "prod_test", + Nickname = "Test Plan", + Amount = null, // Null amount + Interval = "month" + }, + Quantity = 1 + }; + + // Act + var result = new SubscriptionInfo.BillingSubscription.BillingSubscriptionItem(subscriptionItem); + + // Assert - Should default to 0 when Amount is null + Assert.Equal("prod_test", result.ProductId); + Assert.Equal("Test Plan", result.Name); + Assert.Equal(0m, result.Amount); // Business rule: defaults to 0 when null + Assert.Equal("month", result.Interval); + Assert.Equal(1, result.Quantity); + } + + [Fact] + public void BillingSubscriptionItem_ZeroAmount_PreservesZero() + { + // Arrange - SubscriptionItem with Plan and zero Amount + var subscriptionItem = new SubscriptionItem + { + Plan = new Plan + { + ProductId = "prod_test", + Nickname = "Test Plan", + Amount = 0, // Zero amount (0 cents) + Interval = "month" + }, + Quantity = 1 + }; + + // Act + var result = new SubscriptionInfo.BillingSubscription.BillingSubscriptionItem(subscriptionItem); + + // Assert - Should preserve zero amount + Assert.Equal("prod_test", result.ProductId); + Assert.Equal("Test Plan", result.Name); + Assert.Equal(0m, result.Amount); // Zero amount preserved + Assert.Equal("month", result.Interval); + } + + [Fact] + public void BillingUpcomingInvoice_ZeroAmountDue_ConvertsToZero() + { + // Arrange - Invoice with zero AmountDue + // Note: Stripe's Invoice.AmountDue is non-nullable long, so we test with 0 + // The null-coalescing operator (?? 0) in the constructor handles the case where + // ConvertFromStripeMinorUnits returns null, but since AmountDue is non-nullable, + // this test verifies the conversion path works correctly for zero values + var invoice = new Invoice + { + AmountDue = 0, // Zero amount due (0 cents) + Created = DateTime.UtcNow + }; + + // Act + var result = new SubscriptionInfo.BillingUpcomingInvoice(invoice); + + // Assert - Should convert zero correctly + Assert.Equal(0m, result.Amount); + Assert.NotNull(result.Date); + } + + [Fact] + public void BillingUpcomingInvoice_ValidAmountDue_ConvertsCorrectly() + { + // Arrange - Invoice with valid AmountDue + var invoice = new Invoice + { + AmountDue = 2500, // 2500 cents = $25.00 + Created = DateTime.UtcNow + }; + + // Act + var result = new SubscriptionInfo.BillingUpcomingInvoice(invoice); + + // Assert - Should convert correctly + Assert.Equal(25.00m, result.Amount); // Converted from cents + Assert.NotNull(result.Date); + } +} + diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommandTestsBase.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommandTestsBase.cs index 786a6f6c0d..a6db6ae8fd 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommandTestsBase.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/CancelSponsorshipCommandTestsBase.cs @@ -1,4 +1,5 @@ using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Repositories; using Bit.Core.Services; @@ -12,7 +13,7 @@ public abstract class CancelSponsorshipCommandTestsBase : FamiliesForEnterpriseT protected async Task AssertRemovedSponsoredPaymentAsync(Organization sponsoredOrg, OrganizationSponsorship sponsorship, SutProvider sutProvider) { - await sutProvider.GetDependency().Received(1) + await sutProvider.GetDependency().Received(1) .RemoveOrganizationSponsorshipAsync(sponsoredOrg, sponsorship); await sutProvider.GetDependency().Received(1).UpsertAsync(sponsoredOrg); if (sponsorship != null) @@ -46,7 +47,7 @@ OrganizationSponsorship sponsorship, SutProvider sutProvider) protected static async Task AssertDidNotRemoveSponsoredPaymentAsync(SutProvider sutProvider) { - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() .RemoveOrganizationSponsorshipAsync(default, default); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() .UpsertAsync(default); diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommandTests.cs index 69e7183c65..127cc7e502 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/Cloud/SetUpSponsorshipCommandTests.cs @@ -1,10 +1,10 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Exceptions; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Cloud; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Core.Test.AutoFixture.OrganizationSponsorshipFixtures; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; @@ -82,7 +82,7 @@ public class SetUpSponsorshipCommandTests : FamiliesForEnterpriseTestsBase private static async Task AssertDidNotSetUpAsync(SutProvider sutProvider) { - await sutProvider.GetDependency() + await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() .SponsorOrganizationAsync(default, default); await sutProvider.GetDependency() diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/FamiliesForEnterpriseTestsBase.cs b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/FamiliesForEnterpriseTestsBase.cs index 5feee0f13a..515b4d7ba1 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/FamiliesForEnterpriseTestsBase.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/FamiliesForEnterpriseTestsBase.cs @@ -1,22 +1,22 @@ using Bit.Core.Billing.Enums; using Bit.Core.Enums; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; namespace Bit.Core.Test.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise; public abstract class FamiliesForEnterpriseTestsBase { public static IEnumerable EnterprisePlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).ProductTier == ProductTierType.Enterprise).Select(p => new object[] { p }); + Enum.GetValues().Where(p => MockPlans.Get(p).ProductTier == ProductTierType.Enterprise).Select(p => new object[] { p }); public static IEnumerable NonEnterprisePlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).ProductTier != ProductTierType.Enterprise).Select(p => new object[] { p }); + Enum.GetValues().Where(p => MockPlans.Get(p).ProductTier != ProductTierType.Enterprise).Select(p => new object[] { p }); public static IEnumerable FamiliesPlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).ProductTier == ProductTierType.Families).Select(p => new object[] { p }); + Enum.GetValues().Where(p => MockPlans.Get(p).ProductTier == ProductTierType.Families).Select(p => new object[] { p }); public static IEnumerable NonFamiliesPlanTypes => - Enum.GetValues().Where(p => StaticStore.GetPlan(p).ProductTier != ProductTierType.Families).Select(p => new object[] { p }); + Enum.GetValues().Where(p => MockPlans.Get(p).ProductTier != ProductTierType.Families).Select(p => new object[] { p }); public static IEnumerable NonConfirmedOrganizationUsersStatuses => Enum.GetValues() diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/AddSecretsManagerSubscriptionCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/AddSecretsManagerSubscriptionCommandTests.cs index 02ae40798b..83e1487b01 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/AddSecretsManagerSubscriptionCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/AddSecretsManagerSubscriptionCommandTests.cs @@ -4,12 +4,13 @@ using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Exceptions; using Bit.Core.Models.Business; using Bit.Core.Models.StaticStore; using Bit.Core.OrganizationFeatures.OrganizationSubscriptions; using Bit.Core.Services; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -42,7 +43,7 @@ public class AddSecretsManagerSubscriptionCommandTests { organization.PlanType = planType; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(plan); await sutProvider.Sut.SignUpAsync(organization, additionalSmSeats, additionalServiceAccounts); @@ -54,7 +55,7 @@ public class AddSecretsManagerSubscriptionCommandTests c.AdditionalServiceAccounts == additionalServiceAccounts && c.AdditionalSeats == organization.Seats.GetValueOrDefault())); - await sutProvider.GetDependency().Received() + await sutProvider.GetDependency().Received() .AddSecretsManagerToSubscription(organization, plan, additionalSmSeats, additionalServiceAccounts); // TODO: call ReferenceEventService - see AC-1481 @@ -88,7 +89,7 @@ public class AddSecretsManagerSubscriptionCommandTests organization.GatewayCustomerId = null; organization.PlanType = PlanType.EnterpriseAnnually; sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.SignUpAsync(organization, additionalSmSeats, additionalServiceAccounts)); Assert.Contains("No payment method found.", exception.Message); @@ -106,7 +107,7 @@ public class AddSecretsManagerSubscriptionCommandTests organization.GatewaySubscriptionId = null; organization.PlanType = PlanType.EnterpriseAnnually; sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.SignUpAsync(organization, additionalSmSeats, additionalServiceAccounts)); Assert.Contains("No subscription found.", exception.Message); @@ -139,7 +140,7 @@ public class AddSecretsManagerSubscriptionCommandTests provider.Type = ProviderType.Msp; sutProvider.GetDependency().GetByOrganizationIdAsync(organization.Id).Returns(provider); sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType) - .Returns(StaticStore.GetPlan(organization.PlanType)); + .Returns(MockPlans.Get(organization.PlanType)); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.SignUpAsync(organization, 10, 10)); @@ -150,7 +151,7 @@ public class AddSecretsManagerSubscriptionCommandTests private static async Task VerifyDependencyNotCalledAsync(SutProvider sutProvider) { - await sutProvider.GetDependency().DidNotReceive() + await sutProvider.GetDependency().DidNotReceive() .AddSecretsManagerToSubscription(Arg.Any(), Arg.Any(), Arg.Any(), Arg.Any()); // TODO: call ReferenceEventService - see AC-1481 diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpdateSecretsManagerSubscriptionCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpdateSecretsManagerSubscriptionCommandTests.cs index 8b00741215..510433a2fa 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpdateSecretsManagerSubscriptionCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpdateSecretsManagerSubscriptionCommandTests.cs @@ -1,5 +1,6 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.Business; @@ -11,7 +12,7 @@ using Bit.Core.SecretsManager.Repositories; using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Test.AutoFixture.OrganizationFixtures; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -26,7 +27,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests private static TheoryData ToPlanTheory(List types) { var theoryData = new TheoryData(); - var plans = types.Select(StaticStore.GetPlan).ToArray(); + var plans = types.Select(MockPlans.Get).ToArray(); theoryData.AddRange(plans); return theoryData; } @@ -86,9 +87,9 @@ public class UpdateSecretsManagerSubscriptionCommandTests await sutProvider.Sut.UpdateSubscriptionAsync(update); - await sutProvider.GetDependency().Received(1) + await sutProvider.GetDependency().Received(1) .AdjustSmSeatsAsync(organization, plan, update.SmSeatsExcludingBase); - await sutProvider.GetDependency().Received(1) + await sutProvider.GetDependency().Received(1) .AdjustServiceAccountsAsync(organization, plan, update.SmServiceAccountsExcludingBase); // TODO: call ReferenceEventService - see AC-1481 @@ -136,9 +137,9 @@ public class UpdateSecretsManagerSubscriptionCommandTests await sutProvider.Sut.UpdateSubscriptionAsync(update); - await sutProvider.GetDependency().Received(1) + await sutProvider.GetDependency().Received(1) .AdjustSmSeatsAsync(organization, plan, update.SmSeatsExcludingBase); - await sutProvider.GetDependency().Received(1) + await sutProvider.GetDependency().Received(1) .AdjustServiceAccountsAsync(organization, plan, update.SmServiceAccountsExcludingBase); // TODO: call ReferenceEventService - see AC-1481 @@ -164,7 +165,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests Organization organization, SutProvider sutProvider) { - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, autoscaling).AdjustSeats(2); sutProvider.GetDependency().SelfHosted.Returns(true); @@ -180,7 +181,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests SutProvider sutProvider, Organization organization) { - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); organization.UseSecretsManager = false; var update = new SecretsManagerSubscriptionUpdate(organization, plan, false); @@ -258,7 +259,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests await sutProvider.Sut.UpdateSubscriptionAsync(update); - await sutProvider.GetDependency().Received(1).AdjustServiceAccountsAsync( + await sutProvider.GetDependency().Received(1).AdjustServiceAccountsAsync( Arg.Is(o => o.Id == organizationId), plan, expectedSmServiceAccountsExcludingBase); @@ -278,21 +279,27 @@ public class UpdateSecretsManagerSubscriptionCommandTests SutProvider sutProvider) { // Arrange - const int seatCount = 10; - var existingSeatCount = 9; - // Make sure Password Manager seats is greater or equal to Secrets Manager seats - organization.Seats = seatCount; - var plan = StaticStore.GetPlan(organization.PlanType); + const int initialSeatCount = 9; + const int maxSeatCount = 20; + // This represents the total number of users allowed in the organization. + organization.Seats = maxSeatCount; + // This represents the number of Secrets Manager users allowed in the organization. + organization.SmSeats = initialSeatCount; + // This represents the upper limit of Secrets Manager seats that can be automatically scaled. + organization.MaxAutoscaleSmSeats = maxSeatCount; + + organization.PlanType = PlanType.EnterpriseAnnually; + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { - SmSeats = seatCount, - MaxAutoscaleSmSeats = seatCount + SmSeats = 8, + MaxAutoscaleSmSeats = maxSeatCount }; sutProvider.GetDependency() .GetOccupiedSmSeatCountByOrganizationIdAsync(organization.Id) - .Returns(existingSeatCount); + .Returns(5); // Act await sutProvider.Sut.UpdateSubscriptionAsync(update); @@ -316,21 +323,29 @@ public class UpdateSecretsManagerSubscriptionCommandTests SutProvider sutProvider) { // Arrange - const int seatCount = 10; - const int existingSeatCount = 10; - var ownerDetailsList = new List { new() { Email = "owner@example.com" } }; + const int initialSeatCount = 5; + const int maxSeatCount = 10; + + // This represents the total number of users allowed in the organization. + organization.Seats = maxSeatCount; + // This represents the number of Secrets Manager users allowed in the organization. + organization.SmSeats = initialSeatCount; + // This represents the upper limit of Secrets Manager seats that can be automatically scaled. + organization.MaxAutoscaleSmSeats = maxSeatCount; + + var ownerDetailsList = new List { new() { Email = "owner@example.com" } }; + organization.PlanType = PlanType.EnterpriseAnnually; + var plan = MockPlans.Get(organization.PlanType); - // The amount of seats for users in an organization - var plan = StaticStore.GetPlan(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { - SmSeats = seatCount, - MaxAutoscaleSmSeats = seatCount + SmSeats = maxSeatCount, + MaxAutoscaleSmSeats = maxSeatCount }; sutProvider.GetDependency() .GetOccupiedSmSeatCountByOrganizationIdAsync(organization.Id) - .Returns(existingSeatCount); + .Returns(maxSeatCount); sutProvider.GetDependency() .GetManyByMinimumRoleAsync(organization.Id, OrganizationUserType.Owner) .Returns(ownerDetailsList); @@ -340,15 +355,14 @@ public class UpdateSecretsManagerSubscriptionCommandTests // Assert - // Currently being called once each for different validation methods await sutProvider.GetDependency() - .Received(2) + .Received(1) .GetOccupiedSmSeatCountByOrganizationIdAsync(organization.Id); await sutProvider.GetDependency() .Received(1) .SendSecretsManagerMaxSeatLimitReachedEmailAsync(Arg.Is(organization), - Arg.Is(seatCount), + Arg.Is(maxSeatCount), Arg.Is>(emails => emails.Contains(ownerDetailsList[0].Email))); } @@ -359,7 +373,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests SutProvider sutProvider) { organization.SmSeats = null; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false).AdjustSeats(1); var exception = await Assert.ThrowsAsync( @@ -375,7 +389,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests Organization organization, SutProvider sutProvider) { - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, true).AdjustSeats(-2); var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscriptionAsync(update)); @@ -391,7 +405,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests SutProvider sutProvider) { organization.PlanType = planType; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false).AdjustSeats(1); var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscriptionAsync(update)); @@ -409,7 +423,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests organization.SmSeats = 9; organization.MaxAutoscaleSmSeats = 10; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, true).AdjustSeats(2); var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscriptionAsync(update)); @@ -423,7 +437,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests Organization organization, SutProvider sutProvider) { - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { SmSeats = organization.SmSeats + 10, @@ -442,7 +456,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests Organization organization, SutProvider sutProvider) { - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { SmSeats = 0, @@ -462,7 +476,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests SutProvider sutProvider) { organization.SmSeats = 8; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { SmSeats = 7, @@ -485,7 +499,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests var smServiceAccounts = 300; var existingServiceAccountCount = 299; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { SmServiceAccounts = smServiceAccounts, @@ -518,7 +532,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests SutProvider sutProvider) { var smServiceAccounts = 300; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { SmServiceAccounts = smServiceAccounts, @@ -558,7 +572,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests SutProvider sutProvider) { organization.SmServiceAccounts = null; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false).AdjustServiceAccounts(1); var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscriptionAsync(update)); @@ -572,7 +586,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests Organization organization, SutProvider sutProvider) { - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, true).AdjustServiceAccounts(-2); var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscriptionAsync(update)); @@ -588,7 +602,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests SutProvider sutProvider) { organization.PlanType = planType; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false).AdjustServiceAccounts(1); var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscriptionAsync(update)); @@ -606,7 +620,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests organization.SmServiceAccounts = 9; organization.MaxAutoscaleSmServiceAccounts = 10; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, true).AdjustServiceAccounts(2); var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscriptionAsync(update)); @@ -626,7 +640,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests organization.SmServiceAccounts = smServiceAccount - 5; organization.MaxAutoscaleSmServiceAccounts = 2 * smServiceAccount; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { SmServiceAccounts = smServiceAccount, @@ -649,7 +663,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests organization.SmServiceAccounts = newSmServiceAccounts - 10; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { SmServiceAccounts = newSmServiceAccounts, @@ -694,7 +708,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests organization.SmSeats = smSeats - 1; organization.MaxAutoscaleSmSeats = smSeats * 2; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { SmSeats = smSeats, @@ -715,7 +729,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests { organization.PlanType = planType; organization.SmSeats = 2; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { MaxAutoscaleSmSeats = 3 @@ -735,7 +749,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests { organization.PlanType = planType; organization.SmSeats = 2; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { MaxAutoscaleSmSeats = 2 @@ -756,7 +770,7 @@ public class UpdateSecretsManagerSubscriptionCommandTests organization.PlanType = planType; organization.SmServiceAccounts = 3; - var plan = StaticStore.GetPlan(organization.PlanType); + var plan = MockPlans.Get(organization.PlanType); var update = new SecretsManagerSubscriptionUpdate(organization, plan, false) { MaxAutoscaleSmServiceAccounts = 3 }; var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.UpdateSubscriptionAsync(update)); @@ -766,9 +780,9 @@ public class UpdateSecretsManagerSubscriptionCommandTests private static async Task VerifyDependencyNotCalledAsync(SutProvider sutProvider) { - await sutProvider.GetDependency().DidNotReceive() + await sutProvider.GetDependency().DidNotReceive() .AdjustSmSeatsAsync(Arg.Any(), Arg.Any(), Arg.Any()); - await sutProvider.GetDependency().DidNotReceive() + await sutProvider.GetDependency().DidNotReceive() .AdjustServiceAccountsAsync(Arg.Any(), Arg.Any(), Arg.Any()); // TODO: call ReferenceEventService - see AC-1481 await sutProvider.GetDependency().DidNotReceive() diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpgradeOrganizationPlanCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpgradeOrganizationPlanCommandTests.cs index 704f89ba3f..223047ee07 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpgradeOrganizationPlanCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpgradeOrganizationPlanCommandTests.cs @@ -1,6 +1,8 @@ using Bit.Core.Billing.Enums; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; using Bit.Core.Exceptions; +using Bit.Core.KeyManagement.Models.Data; using Bit.Core.Models.Business; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.OrganizationFeatures.OrganizationSubscriptions; @@ -8,7 +10,7 @@ using Bit.Core.Repositories; using Bit.Core.SecretsManager.Repositories; using Bit.Core.Services; using Bit.Core.Test.AutoFixture.OrganizationFixtures; -using Bit.Core.Utilities; +using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -45,7 +47,7 @@ public class UpgradeOrganizationPlanCommandTests SutProvider sutProvider) { upgrade.Plan = organization.PlanType; - sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(StaticStore.GetPlan(organization.PlanType)); + sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType)); sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.UpgradePlanAsync(organization.Id, upgrade)); @@ -61,7 +63,7 @@ public class UpgradeOrganizationPlanCommandTests upgrade.AdditionalSmSeats = 10; upgrade.AdditionalServiceAccounts = 10; sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(StaticStore.GetPlan(organization.PlanType)); + sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType)); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.UpgradePlanAsync(organization.Id, upgrade)); Assert.Contains("already on this plan", exception.Message); @@ -73,11 +75,11 @@ public class UpgradeOrganizationPlanCommandTests SutProvider sutProvider) { sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); - sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(StaticStore.GetPlan(organization.PlanType)); + sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType)); upgrade.AdditionalSmSeats = 10; upgrade.AdditionalSeats = 10; upgrade.Plan = PlanType.TeamsAnnually; - sutProvider.GetDependency().GetPlanOrThrow(upgrade.Plan).Returns(StaticStore.GetPlan(upgrade.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(upgrade.Plan).Returns(MockPlans.Get(upgrade.Plan)); sutProvider.GetDependency() .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id).Returns(new OrganizationSeatCounts { @@ -104,7 +106,7 @@ public class UpgradeOrganizationPlanCommandTests organization.PlanType = PlanType.FamiliesAnnually; - sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(StaticStore.GetPlan(organization.PlanType)); + sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType)); organizationUpgrade.AdditionalSeats = 30; organizationUpgrade.UseSecretsManager = true; @@ -113,7 +115,7 @@ public class UpgradeOrganizationPlanCommandTests organizationUpgrade.AdditionalStorageGb = 3; organizationUpgrade.Plan = planType; - sutProvider.GetDependency().GetPlanOrThrow(organizationUpgrade.Plan).Returns(StaticStore.GetPlan(organizationUpgrade.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(organizationUpgrade.Plan).Returns(MockPlans.Get(organizationUpgrade.Plan)); sutProvider.GetDependency() .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id).Returns(new OrganizationSeatCounts { @@ -121,9 +123,9 @@ public class UpgradeOrganizationPlanCommandTests Users = 1 }); await sutProvider.Sut.UpgradePlanAsync(organization.Id, organizationUpgrade); - await sutProvider.GetDependency().Received(1).AdjustSubscription( + await sutProvider.GetDependency().Received(1).AdjustSubscription( organization, - StaticStore.GetPlan(planType), + MockPlans.Get(planType), organizationUpgrade.AdditionalSeats, organizationUpgrade.UseSecretsManager, organizationUpgrade.AdditionalSmSeats, @@ -141,12 +143,12 @@ public class UpgradeOrganizationPlanCommandTests public async Task UpgradePlan_SM_Passes(PlanType planType, Organization organization, OrganizationUpgrade upgrade, SutProvider sutProvider) { - sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(StaticStore.GetPlan(organization.PlanType)); + sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType)); upgrade.Plan = planType; - sutProvider.GetDependency().GetPlanOrThrow(upgrade.Plan).Returns(StaticStore.GetPlan(upgrade.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(upgrade.Plan).Returns(MockPlans.Get(upgrade.Plan)); - var plan = StaticStore.GetPlan(upgrade.Plan); + var plan = MockPlans.Get(upgrade.Plan); sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); @@ -184,10 +186,10 @@ public class UpgradeOrganizationPlanCommandTests upgrade.AdditionalSeats = 15; upgrade.AdditionalSmSeats = 1; upgrade.AdditionalServiceAccounts = 0; - sutProvider.GetDependency().GetPlanOrThrow(upgrade.Plan).Returns(StaticStore.GetPlan(upgrade.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(upgrade.Plan).Returns(MockPlans.Get(upgrade.Plan)); organization.SmSeats = 2; - sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(StaticStore.GetPlan(organization.PlanType)); + sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType)); sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); sutProvider.GetDependency() @@ -218,11 +220,11 @@ public class UpgradeOrganizationPlanCommandTests upgrade.AdditionalSeats = 15; upgrade.AdditionalSmSeats = 1; upgrade.AdditionalServiceAccounts = 0; - sutProvider.GetDependency().GetPlanOrThrow(upgrade.Plan).Returns(StaticStore.GetPlan(upgrade.Plan)); + sutProvider.GetDependency().GetPlanOrThrow(upgrade.Plan).Returns(MockPlans.Get(upgrade.Plan)); organization.SmSeats = 1; organization.SmServiceAccounts = currentServiceAccounts; - sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(StaticStore.GetPlan(organization.PlanType)); + sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType)); sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); sutProvider.GetDependency() @@ -241,4 +243,134 @@ public class UpgradeOrganizationPlanCommandTests await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().ReplaceAndUpdateCacheAsync(default); } + + [Theory] + [FreeOrganizationUpgradeCustomize, BitAutoData] + public async Task UpgradePlan_WhenOrganizationIsMissingPublicAndPrivateKeys_Backfills( + Organization organization, + OrganizationUpgrade upgrade, + string newPublicKey, + string newPrivateKey, + SutProvider sutProvider) + { + organization.PublicKey = null; + organization.PrivateKey = null; + + upgrade.Plan = PlanType.TeamsAnnually; + upgrade.Keys = new PublicKeyEncryptionKeyPairData( + wrappedPrivateKey: newPrivateKey, + publicKey: newPublicKey); + upgrade.AdditionalSeats = 10; + + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + sutProvider.GetDependency() + .GetPlanOrThrow(organization.PlanType) + .Returns(MockPlans.Get(organization.PlanType)); + sutProvider.GetDependency() + .GetPlanOrThrow(upgrade.Plan) + .Returns(MockPlans.Get(upgrade.Plan)); + sutProvider.GetDependency() + .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id) + .Returns(new OrganizationSeatCounts { Sponsored = 0, Users = 1 }); + + // Act + await sutProvider.Sut.UpgradePlanAsync(organization.Id, upgrade); + + // Assert + Assert.Equal(newPublicKey, organization.PublicKey); + Assert.Equal(newPrivateKey, organization.PrivateKey); + await sutProvider.GetDependency() + .Received(1) + .ReplaceAndUpdateCacheAsync(organization); + } + + [Theory] + [FreeOrganizationUpgradeCustomize, BitAutoData] + public async Task UpgradePlan_WhenOrganizationAlreadyHasPublicAndPrivateKeys_DoesNotOverwriteWithNull( + Organization organization, + OrganizationUpgrade upgrade, + SutProvider sutProvider) + { + // Arrange + const string existingPublicKey = "existing-public-key"; + const string existingPrivateKey = "existing-private-key"; + + organization.PublicKey = existingPublicKey; + organization.PrivateKey = existingPrivateKey; + + upgrade.Plan = PlanType.TeamsAnnually; + upgrade.Keys = null; + upgrade.AdditionalSeats = 10; + + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + sutProvider.GetDependency() + .GetPlanOrThrow(organization.PlanType) + .Returns(MockPlans.Get(organization.PlanType)); + sutProvider.GetDependency() + .GetPlanOrThrow(upgrade.Plan) + .Returns(MockPlans.Get(upgrade.Plan)); + sutProvider.GetDependency() + .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id) + .Returns(new OrganizationSeatCounts { Sponsored = 0, Users = 1 }); + + // Act + await sutProvider.Sut.UpgradePlanAsync(organization.Id, upgrade); + + // Assert + Assert.Equal(existingPublicKey, organization.PublicKey); + Assert.Equal(existingPrivateKey, organization.PrivateKey); + await sutProvider.GetDependency() + .Received(1) + .ReplaceAndUpdateCacheAsync(organization); + } + + [Theory] + [FreeOrganizationUpgradeCustomize, BitAutoData] + public async Task UpgradePlan_WhenOrganizationAlreadyHasPublicAndPrivateKeys_DoesNotBackfillWithNewKeys( + Organization organization, + OrganizationUpgrade upgrade, + SutProvider sutProvider) + { + // Arrange + const string existingPublicKey = "existing-public-key"; + const string existingPrivateKey = "existing-private-key"; + const string newPublicKey = "new-public-key"; + const string newPrivateKey = "new-private-key"; + + organization.PublicKey = existingPublicKey; + organization.PrivateKey = existingPrivateKey; + + upgrade.Plan = PlanType.TeamsAnnually; + upgrade.Keys = new PublicKeyEncryptionKeyPairData( + wrappedPrivateKey: newPrivateKey, + publicKey: newPublicKey); + upgrade.AdditionalSeats = 10; + + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + sutProvider.GetDependency() + .GetPlanOrThrow(organization.PlanType) + .Returns(MockPlans.Get(organization.PlanType)); + sutProvider.GetDependency() + .GetPlanOrThrow(upgrade.Plan) + .Returns(MockPlans.Get(upgrade.Plan)); + sutProvider.GetDependency() + .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id) + .Returns(new OrganizationSeatCounts { Sponsored = 0, Users = 1 }); + + // Act + await sutProvider.Sut.UpgradePlanAsync(organization.Id, upgrade); + + // Assert + Assert.Equal(existingPublicKey, organization.PublicKey); + Assert.Equal(existingPrivateKey, organization.PrivateKey); + await sutProvider.GetDependency() + .Received(1) + .ReplaceAndUpdateCacheAsync(organization); + } } diff --git a/test/Core.Test/Platform/Mailer/HandlebarMailRendererTests.cs b/test/Core.Test/Platform/Mailer/HandlebarMailRendererTests.cs new file mode 100644 index 0000000000..2559ae2b5f --- /dev/null +++ b/test/Core.Test/Platform/Mailer/HandlebarMailRendererTests.cs @@ -0,0 +1,172 @@ +using Bit.Core.Platform.Mail.Mailer; +using Bit.Core.Settings; +using Bit.Core.Test.Platform.Mailer.TestMail; +using Microsoft.Extensions.Logging; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.Platform.Mailer; + +public class HandlebarMailRendererTests +{ + [Fact] + public async Task RenderAsync_ReturnsExpectedHtmlAndTxt() + { + var logger = Substitute.For>(); + var globalSettings = new GlobalSettings { SelfHosted = false }; + var renderer = new HandlebarMailRenderer(logger, globalSettings); + + var view = new TestMailView { Name = "John Smith" }; + + var (html, txt) = await renderer.RenderAsync(view); + + Assert.Equal("Hello John Smith", html.Trim()); + Assert.Equal("Hello John Smith", txt.Trim()); + } + + [Fact] + public async Task RenderAsync_LoadsFromDisk_WhenSelfHostedAndFileExists() + { + var logger = Substitute.For>(); + var tempDir = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString()); + Directory.CreateDirectory(tempDir); + + try + { + var globalSettings = new GlobalSettings + { + SelfHosted = true, + MailTemplateDirectory = tempDir + }; + + // Create test template files on disk + var htmlTemplatePath = Path.Combine(tempDir, "Bit.Core.Test.Platform.Mailer.TestMail.TestMailView.html.hbs"); + var txtTemplatePath = Path.Combine(tempDir, "Bit.Core.Test.Platform.Mailer.TestMail.TestMailView.text.hbs"); + await File.WriteAllTextAsync(htmlTemplatePath, "Custom HTML: {{Name}}"); + await File.WriteAllTextAsync(txtTemplatePath, "Custom TXT: {{Name}}"); + + var renderer = new HandlebarMailRenderer(logger, globalSettings); + var view = new TestMailView { Name = "Jane Doe" }; + + var (html, txt) = await renderer.RenderAsync(view); + + Assert.Equal("Custom HTML: Jane Doe", html.Trim()); + Assert.Equal("Custom TXT: Jane Doe", txt.Trim()); + } + finally + { + // Cleanup + if (Directory.Exists(tempDir)) + { + Directory.Delete(tempDir, true); + } + } + } + + [Theory] + [InlineData("../../../etc/passwd")] + [InlineData("../../../../malicious.txt")] + [InlineData("../../malicious.txt")] + [InlineData("../malicious.txt")] + public async Task ReadSourceFromDiskAsync_PrevenetsPathTraversal_WhenMaliciousPathProvided(string maliciousPath) + { + var logger = Substitute.For>(); + var tempDir = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString()); + Directory.CreateDirectory(tempDir); + + try + { + var globalSettings = new GlobalSettings + { + SelfHosted = true, + MailTemplateDirectory = tempDir + }; + + // Create a malicious file outside the template directory + var maliciousFile = Path.Combine(Path.GetTempPath(), "malicious.txt"); + await File.WriteAllTextAsync(maliciousFile, "Malicious Content"); + + var renderer = new HandlebarMailRenderer(logger, globalSettings); + + // Use reflection to call the private ReadSourceFromDiskAsync method + var method = typeof(HandlebarMailRenderer).GetMethod("ReadSourceFromDiskAsync", + System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); + var task = (Task)method!.Invoke(renderer, new object[] { maliciousPath })!; + var result = await task; + + // Should return null and not load the malicious file + Assert.Null(result); + + // Verify that a warning was logged for the path traversal attempt + logger.Received(1).Log( + LogLevel.Warning, + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any>()); + + // Cleanup malicious file + if (File.Exists(maliciousFile)) + { + File.Delete(maliciousFile); + } + } + finally + { + // Cleanup + if (Directory.Exists(tempDir)) + { + Directory.Delete(tempDir, true); + } + } + } + + [Fact] + public async Task ReadSourceFromDiskAsync_AllowsValidFileWithDifferentCase_WhenCaseInsensitiveFileSystem() + { + var logger = Substitute.For>(); + var tempDir = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString()); + Directory.CreateDirectory(tempDir); + + try + { + var globalSettings = new GlobalSettings + { + SelfHosted = true, + MailTemplateDirectory = tempDir + }; + + // Create a test template file + var templateFileName = "TestTemplate.hbs"; + var templatePath = Path.Combine(tempDir, templateFileName); + await File.WriteAllTextAsync(templatePath, "Test Content"); + + var renderer = new HandlebarMailRenderer(logger, globalSettings); + + // Try to read with different case (should work on case-insensitive file systems like Windows/macOS) + var method = typeof(HandlebarMailRenderer).GetMethod("ReadSourceFromDiskAsync", + System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); + var task = (Task)method!.Invoke(renderer, new object[] { templateFileName })!; + var result = await task; + + // Should successfully read the file + Assert.Equal("Test Content", result); + + // Verify no warning was logged + logger.DidNotReceive().Log( + LogLevel.Warning, + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any>()); + } + finally + { + // Cleanup + if (Directory.Exists(tempDir)) + { + Directory.Delete(tempDir, true); + } + } + } +} diff --git a/test/Core.Test/Platform/Mailer/MailerTest.cs b/test/Core.Test/Platform/Mailer/MailerTest.cs new file mode 100644 index 0000000000..ca9cb2a874 --- /dev/null +++ b/test/Core.Test/Platform/Mailer/MailerTest.cs @@ -0,0 +1,42 @@ +using Bit.Core.Models.Mail; +using Bit.Core.Platform.Mail.Delivery; +using Bit.Core.Platform.Mail.Mailer; +using Bit.Core.Settings; +using Bit.Core.Test.Platform.Mailer.TestMail; +using Microsoft.Extensions.Logging; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.Platform.Mailer; + +public class MailerTest +{ + [Fact] + public async Task SendEmailAsync() + { + var logger = Substitute.For>(); + var globalSettings = new GlobalSettings { SelfHosted = false }; + var deliveryService = Substitute.For(); + + var mailer = new Core.Platform.Mail.Mailer.Mailer(new HandlebarMailRenderer(logger, globalSettings), deliveryService); + + var mail = new TestMail.TestMail() + { + ToEmails = ["test@bw.com"], + View = new TestMailView() { Name = "John Smith" } + }; + + MailMessage? sentMessage = null; + await deliveryService.SendEmailAsync(Arg.Do(message => + sentMessage = message + )); + + await mailer.SendEmail(mail); + + Assert.NotNull(sentMessage); + Assert.Contains("test@bw.com", sentMessage.ToEmails); + Assert.Equal("Test Email", sentMessage.Subject); + Assert.Equivalent("Hello John Smith", sentMessage.TextContent.Trim()); + Assert.Equivalent("Hello John Smith", sentMessage.HtmlContent.Trim()); + } +} diff --git a/test/Core.Test/Platform/Mailer/TestMail/TestMailView.cs b/test/Core.Test/Platform/Mailer/TestMail/TestMailView.cs new file mode 100644 index 0000000000..e1b98f87d3 --- /dev/null +++ b/test/Core.Test/Platform/Mailer/TestMail/TestMailView.cs @@ -0,0 +1,13 @@ +using Bit.Core.Platform.Mail.Mailer; + +namespace Bit.Core.Test.Platform.Mailer.TestMail; + +public class TestMailView : BaseMailView +{ + public required string Name { get; init; } +} + +public class TestMail : BaseMail +{ + public override string Subject { get; } = "Test Email"; +} diff --git a/test/Core.Test/Platform/Mailer/TestMail/TestMailView.html.hbs b/test/Core.Test/Platform/Mailer/TestMail/TestMailView.html.hbs new file mode 100644 index 0000000000..c80512793e --- /dev/null +++ b/test/Core.Test/Platform/Mailer/TestMail/TestMailView.html.hbs @@ -0,0 +1 @@ +Hello {{ Name }} diff --git a/test/Core.Test/Platform/Mailer/TestMail/TestMailView.text.hbs b/test/Core.Test/Platform/Mailer/TestMail/TestMailView.text.hbs new file mode 100644 index 0000000000..a1a5777674 --- /dev/null +++ b/test/Core.Test/Platform/Mailer/TestMail/TestMailView.text.hbs @@ -0,0 +1 @@ +Hello {{ Name }} diff --git a/test/Core.Test/Platform/Push/Engines/AzureQueuePushEngineTests.cs b/test/Core.Test/Platform/Push/Engines/AzureQueuePushEngineTests.cs index 9c46211517..3f31f1fad4 100644 --- a/test/Core.Test/Platform/Push/Engines/AzureQueuePushEngineTests.cs +++ b/test/Core.Test/Platform/Push/Engines/AzureQueuePushEngineTests.cs @@ -358,20 +358,28 @@ public class AzureQueuePushEngineTests } [Theory] - [InlineData(true)] - [InlineData(false)] - public async Task PushLogOutAsync_SendsExpectedResponse(bool excludeCurrentContext) + [InlineData(true, null)] + [InlineData(true, PushNotificationLogOutReason.KdfChange)] + [InlineData(false, null)] + [InlineData(false, PushNotificationLogOutReason.KdfChange)] + public async Task PushLogOutAsync_SendsExpectedResponse(bool excludeCurrentContext, + PushNotificationLogOutReason? reason) { var userId = Guid.NewGuid(); + var payload = new JsonObject + { + ["UserId"] = userId + }; + if (reason != null) + { + payload["Reason"] = (int)reason; + } + var expectedPayload = new JsonObject { ["Type"] = 11, - ["Payload"] = new JsonObject - { - ["UserId"] = userId, - ["Date"] = _fakeTimeProvider.GetUtcNow().UtcDateTime, - }, + ["Payload"] = payload, }; if (excludeCurrentContext) @@ -380,7 +388,7 @@ public class AzureQueuePushEngineTests } await VerifyNotificationAsync( - async sut => await sut.PushLogOutAsync(userId, excludeCurrentContext), + async sut => await sut.PushLogOutAsync(userId, excludeCurrentContext, reason), expectedPayload ); } diff --git a/test/Core.Test/Platform/Push/Engines/NotificationsApiPushEngineTests.cs b/test/Core.Test/Platform/Push/Engines/NotificationsApiPushEngineTests.cs index c61c2f37d0..7f230c4e5c 100644 --- a/test/Core.Test/Platform/Push/Engines/NotificationsApiPushEngineTests.cs +++ b/test/Core.Test/Platform/Push/Engines/NotificationsApiPushEngineTests.cs @@ -1,6 +1,7 @@ using System.Text.Json.Nodes; using Bit.Core.AdminConsole.Entities; using Bit.Core.Auth.Entities; +using Bit.Core.Enums; using Bit.Core.NotificationCenter.Entities; using Bit.Core.Platform.Push.Internal; using Bit.Core.Tools.Entities; @@ -193,7 +194,8 @@ public class NotificationsApiPushEngineTests : PushTestBase }; } - protected override JsonNode GetPushLogOutPayload(Guid userId, bool excludeCurrentContext) + protected override JsonNode GetPushLogOutPayload(Guid userId, bool excludeCurrentContext, + PushNotificationLogOutReason? reason) { JsonNode? contextId = excludeCurrentContext ? DeviceIdentifier : null; @@ -203,7 +205,7 @@ public class NotificationsApiPushEngineTests : PushTestBase ["Payload"] = new JsonObject { ["UserId"] = userId, - ["Date"] = FakeTimeProvider.GetUtcNow().UtcDateTime, + ["Reason"] = reason != null ? (int)reason : null }, ["ContextId"] = contextId, }; diff --git a/test/Core.Test/Platform/Push/Engines/PushTestBase.cs b/test/Core.Test/Platform/Push/Engines/PushTestBase.cs index e0eeeda97d..c0037f57aa 100644 --- a/test/Core.Test/Platform/Push/Engines/PushTestBase.cs +++ b/test/Core.Test/Platform/Push/Engines/PushTestBase.cs @@ -86,7 +86,8 @@ public abstract class PushTestBase protected abstract JsonNode GetPushSyncOrganizationsPayload(Guid userId); protected abstract JsonNode GetPushSyncOrgKeysPayload(Guid userId); protected abstract JsonNode GetPushSyncSettingsPayload(Guid userId); - protected abstract JsonNode GetPushLogOutPayload(Guid userId, bool excludeCurrentContext); + protected abstract JsonNode GetPushLogOutPayload(Guid userId, bool excludeCurrentContext, + PushNotificationLogOutReason? reason); protected abstract JsonNode GetPushSendCreatePayload(Send send); protected abstract JsonNode GetPushSendUpdatePayload(Send send); protected abstract JsonNode GetPushSendDeletePayload(Send send); @@ -263,15 +264,18 @@ public abstract class PushTestBase } [Theory] - [InlineData(true)] - [InlineData(false)] - public async Task PushLogOutAsync_SendsExpectedResponse(bool excludeCurrentContext) + [InlineData(true, null)] + [InlineData(true, PushNotificationLogOutReason.KdfChange)] + [InlineData(false, null)] + [InlineData(false, PushNotificationLogOutReason.KdfChange)] + public async Task PushLogOutAsync_SendsExpectedResponse(bool excludeCurrentContext, + PushNotificationLogOutReason? reason) { var userId = Guid.NewGuid(); await VerifyNotificationAsync( - async sut => await sut.PushLogOutAsync(userId, excludeCurrentContext), - GetPushLogOutPayload(userId, excludeCurrentContext) + async sut => await sut.PushLogOutAsync(userId, excludeCurrentContext, reason), + GetPushLogOutPayload(userId, excludeCurrentContext, reason) ); } diff --git a/test/Core.Test/Platform/Push/Engines/RelayPushEngineTests.cs b/test/Core.Test/Platform/Push/Engines/RelayPushEngineTests.cs index 010ad40d13..f8ae07f647 100644 --- a/test/Core.Test/Platform/Push/Engines/RelayPushEngineTests.cs +++ b/test/Core.Test/Platform/Push/Engines/RelayPushEngineTests.cs @@ -4,6 +4,7 @@ using System.Text.Json.Nodes; using Bit.Core.AdminConsole.Entities; using Bit.Core.Auth.Entities; using Bit.Core.Entities; +using Bit.Core.Enums; using Bit.Core.NotificationCenter.Entities; using Bit.Core.Platform.Push.Internal; using Bit.Core.Repositories; @@ -64,7 +65,7 @@ public class RelayPushNotificationServiceTests : PushTestBase ["UserId"] = cipher.UserId, ["OrganizationId"] = null, // Currently CollectionIds are not passed along from the method signature - // to the request body. + // to the request body. ["CollectionIds"] = null, ["RevisionDate"] = cipher.RevisionDate, }, @@ -88,7 +89,7 @@ public class RelayPushNotificationServiceTests : PushTestBase ["UserId"] = cipher.UserId, ["OrganizationId"] = null, // Currently CollectionIds are not passed along from the method signature - // to the request body. + // to the request body. ["CollectionIds"] = null, ["RevisionDate"] = cipher.RevisionDate, }, @@ -274,7 +275,8 @@ public class RelayPushNotificationServiceTests : PushTestBase }; } - protected override JsonNode GetPushLogOutPayload(Guid userId, bool excludeCurrentContext) + protected override JsonNode GetPushLogOutPayload(Guid userId, bool excludeCurrentContext, + PushNotificationLogOutReason? reason) { JsonNode? identifier = excludeCurrentContext ? DeviceIdentifier : null; @@ -288,7 +290,7 @@ public class RelayPushNotificationServiceTests : PushTestBase ["Payload"] = new JsonObject { ["UserId"] = userId, - ["Date"] = FakeTimeProvider.GetUtcNow().UtcDateTime, + ["Reason"] = reason != null ? (int)reason : null }, ["ClientType"] = null, ["InstallationId"] = null, diff --git a/test/Core.Test/Platform/Push/NotificationHub/NotificationHubPushEngineTests.cs b/test/Core.Test/Platform/Push/NotificationHub/NotificationHubPushEngineTests.cs index a32b112675..f5f257c741 100644 --- a/test/Core.Test/Platform/Push/NotificationHub/NotificationHubPushEngineTests.cs +++ b/test/Core.Test/Platform/Push/NotificationHub/NotificationHubPushEngineTests.cs @@ -404,16 +404,18 @@ public class NotificationHubPushNotificationServiceTests } [Theory] - [InlineData(true)] - [InlineData(false)] - public async Task PushLogOutAsync_SendExpectedData(bool excludeCurrentContext) + [InlineData(true, null)] + [InlineData(true, PushNotificationLogOutReason.KdfChange)] + [InlineData(false, null)] + [InlineData(false, PushNotificationLogOutReason.KdfChange)] + public async Task PushLogOutAsync_SendExpectedData(bool excludeCurrentContext, PushNotificationLogOutReason? reason) { var userId = Guid.NewGuid(); var expectedPayload = new JsonObject { ["UserId"] = userId, - ["Date"] = _now, + ["Reason"] = reason != null ? (int)reason : null, }; var expectedTag = excludeCurrentContext @@ -421,7 +423,7 @@ public class NotificationHubPushNotificationServiceTests : $"(template:payload_userId:{userId})"; await VerifyNotificationAsync( - async sut => await sut.PushLogOutAsync(userId, excludeCurrentContext), + async sut => await sut.PushLogOutAsync(userId, excludeCurrentContext, reason), PushType.LogOut, expectedPayload, expectedTag diff --git a/test/Core.Test/Services/AmazonSesMailDeliveryServiceTests.cs b/test/Core.Test/Services/AmazonSesMailDeliveryServiceTests.cs index 71bbc9f13e..99d967dc57 100644 --- a/test/Core.Test/Services/AmazonSesMailDeliveryServiceTests.cs +++ b/test/Core.Test/Services/AmazonSesMailDeliveryServiceTests.cs @@ -1,7 +1,7 @@ using Amazon.SimpleEmail; using Amazon.SimpleEmail.Model; using Bit.Core.Models.Mail; -using Bit.Core.Services; +using Bit.Core.Platform.Mail.Delivery; using Bit.Core.Settings; using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Logging; diff --git a/test/Core.Test/Services/HandlebarsMailServiceTests.cs b/test/Core.Test/Services/HandlebarsMailServiceTests.cs index 30eebfb30f..b98c4580f5 100644 --- a/test/Core.Test/Services/HandlebarsMailServiceTests.cs +++ b/test/Core.Test/Services/HandlebarsMailServiceTests.cs @@ -6,7 +6,10 @@ using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models.Business; using Bit.Core.Entities; using Bit.Core.Models.Mail; +using Bit.Core.Platform.Mail.Delivery; +using Bit.Core.Platform.Mail.Enqueuing; using Bit.Core.Services; +using Bit.Core.Services.Mail; using Bit.Core.Settings; using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.Logging; @@ -265,4 +268,115 @@ public class HandlebarsMailServiceTests // Assert await _mailDeliveryService.Received(1).SendEmailAsync(Arg.Any()); } + + [Fact] + public async Task SendIndividualUserWelcomeEmailAsync_SendsCorrectEmail() + { + // Arrange + var user = new User + { + Id = Guid.NewGuid(), + Email = "test@example.com" + }; + + // Act + await _sut.SendIndividualUserWelcomeEmailAsync(user); + + // Assert + await _mailDeliveryService.Received(1).SendEmailAsync(Arg.Is(m => + m.MetaData != null && + m.ToEmails.Contains("test@example.com") && + m.Subject == "Welcome to Bitwarden!" && + m.Category == "Welcome")); + } + + [Fact] + public async Task SendOrganizationUserWelcomeEmailAsync_SendsCorrectEmailWithOrganizationName() + { + // Arrange + var user = new User + { + Id = Guid.NewGuid(), + Email = "user@company.com" + }; + var organizationName = "Bitwarden Corp"; + + // Act + await _sut.SendOrganizationUserWelcomeEmailAsync(user, organizationName); + + // Assert + await _mailDeliveryService.Received(1).SendEmailAsync(Arg.Is(m => + m.MetaData != null && + m.ToEmails.Contains("user@company.com") && + m.Subject == "Welcome to Bitwarden!" && + m.HtmlContent.Contains("Bitwarden Corp") && + m.Category == "Welcome")); + } + + [Fact] + public async Task SendFreeOrgOrFamilyOrgUserWelcomeEmailAsync_SendsCorrectEmailWithFamilyTemplate() + { + // Arrange + var user = new User + { + Id = Guid.NewGuid(), + Email = "family@example.com" + }; + var familyOrganizationName = "Smith Family"; + + // Act + await _sut.SendFreeOrgOrFamilyOrgUserWelcomeEmailAsync(user, familyOrganizationName); + + // Assert + await _mailDeliveryService.Received(1).SendEmailAsync(Arg.Is(m => + m.MetaData != null && + m.ToEmails.Contains("family@example.com") && + m.Subject == "Welcome to Bitwarden!" && + m.HtmlContent.Contains("Smith Family") && + m.Category == "Welcome")); + } + + [Theory] + [InlineData("Acme Corp", "Acme Corp")] + [InlineData("Company & Associates", "Company & Associates")] + [InlineData("Test \"Quoted\" Org", "Test "Quoted" Org")] + public async Task SendOrganizationUserWelcomeEmailAsync_SanitizesOrganizationNameForEmail(string inputOrgName, string expectedSanitized) + { + // Arrange + var user = new User + { + Id = Guid.NewGuid(), + Email = "test@example.com" + }; + + // Act + await _sut.SendOrganizationUserWelcomeEmailAsync(user, inputOrgName); + + // Assert + await _mailDeliveryService.Received(1).SendEmailAsync(Arg.Is(m => + m.HtmlContent.Contains(expectedSanitized) && + !m.HtmlContent.Contains("