diff --git a/CLAUDE.md b/.claude/CLAUDE.md similarity index 86% rename from CLAUDE.md rename to .claude/CLAUDE.md index d07bd3f3e1..c1349e8c9d 100644 --- a/CLAUDE.md +++ b/.claude/CLAUDE.md @@ -1,24 +1,29 @@ # Bitwarden Server - Claude Code Configuration +## Project Context Files + +**Read these files before reviewing to ensure that you fully understand the project and contributing guidelines** + +1. @README.md +2. @CONTRIBUTING.md +3. @.github/PULL_REQUEST_TEMPLATE.md + ## Critical Rules -- **NEVER** edit: `/bin/`, `/obj/`, `/.git/`, `/.vs/`, `/packages/` which are generated files - **NEVER** use code regions: If complexity suggests regions, refactor for better readability + - **NEVER** compromise zero-knowledge principles: User vault data must remain encrypted and inaccessible to Bitwarden + - **NEVER** log or expose sensitive data: No PII, passwords, keys, or vault data in logs or error messages + - **ALWAYS** use secure communication channels: Enforce confidentiality, integrity, and authenticity + - **ALWAYS** encrypt sensitive data: All vault data must be encrypted at rest, in transit, and in use + - **ALWAYS** prioritize cryptographic integrity and data protection + - **ALWAYS** add unit tests (with mocking) for any new feature development -## Project Context - -- **Architecture**: Feature and team-based organization -- **Framework**: .NET 8.0, ASP.NET Core -- **Database**: SQL Server primary, EF Core supports PostgreSQL, MySQL/MariaDB, SQLite -- **Testing**: xUnit, NSubstitute -- **Container**: Docker, Docker Compose, Kubernetes/Helm deployable - ## Project Structure - **Source Code**: `/src/` - Services and core infrastructure @@ -42,7 +47,7 @@ - **Database update**: `pwsh dev/migrate.ps1` - **Generate OpenAPI**: `pwsh dev/generate_openapi_files.ps1` -## Code Review Checklist +## Development Workflow - Security impact assessed - xUnit tests added / updated diff --git a/.claude/prompts/review-code.md b/.claude/prompts/review-code.md new file mode 100644 index 0000000000..4e5f40b274 --- /dev/null +++ b/.claude/prompts/review-code.md @@ -0,0 +1,25 @@ +Please review this pull request with a focus on: + +- Code quality and best practices +- Potential bugs or issues +- Security implications +- Performance considerations + +Note: The PR branch is already checked out in the current working directory. + +Provide a comprehensive review including: + +- Summary of changes since last review +- Critical issues found (be thorough) +- Suggested improvements (be thorough) +- Good practices observed (be concise - list only the most notable items without elaboration) +- Action items for the author +- Leverage collapsible
sections where appropriate for lengthy explanations or code snippets to enhance human readability + +When reviewing subsequent commits: + +- Track status of previously identified issues (fixed/unfixed/reopened) +- Identify NEW problems introduced since last review +- Note if fixes introduced new issues + +IMPORTANT: Be comprehensive about issues and improvements. For good practices, be brief - just note what was done well without explaining why or praising excessively. diff --git a/.editorconfig b/.editorconfig index 21d7ac4a3a..fd68808456 100644 --- a/.editorconfig +++ b/.editorconfig @@ -123,3 +123,12 @@ csharp_style_namespace_declarations = file_scoped:warning # Switch expression dotnet_diagnostic.CS8509.severity = error # missing switch case for named enum value dotnet_diagnostic.CS8524.severity = none # missing switch case for unnamed enum value + +# CA2253: Named placeholders should nto be numeric values +dotnet_diagnostic.CA2253.severity = suggestion + +# CA2254: Template should be a static expression +dotnet_diagnostic.CA2254.severity = warning + +# CA1727: Use PascalCase for named placeholders +dotnet_diagnostic.CA1727.severity = suggestion diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 2f1c5f18fb..65780bdb63 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -4,11 +4,12 @@ # # https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners -## Docker files have shared ownership ## -**/Dockerfile -**/*.Dockerfile -**/.dockerignore -**/entrypoint.sh +## Docker-related files +**/Dockerfile @bitwarden/team-appsec @bitwarden/dept-bre +**/*.Dockerfile @bitwarden/team-appsec @bitwarden/dept-bre +**/*.dockerignore @bitwarden/team-appsec @bitwarden/dept-bre +**/docker-compose.yml @bitwarden/team-appsec @bitwarden/dept-bre +**/entrypoint.sh @bitwarden/team-appsec @bitwarden/dept-bre ## BRE team owns these workflows ## .github/workflows/publish.yml @bitwarden/dept-bre @@ -95,6 +96,14 @@ src/Admin/Views/Tools @bitwarden/team-billing-dev # The PushType enum is expected to be editted by anyone without need for Platform review src/Core/Platform/Push/PushType.cs +# SDK +util/RustSdk @bitwarden/team-sdk-sme + # Multiple owners - DO NOT REMOVE (BRE) **/packages.lock.json Directory.Build.props + +# Claude related files +.claude/ @bitwarden/team-ai-sme +.github/workflows/respond.yml @bitwarden/team-ai-sme +.github/workflows/review-code.yml @bitwarden/team-ai-sme diff --git a/.github/renovate.json5 b/.github/renovate.json5 index 5c01832c06..bc377ed46c 100644 --- a/.github/renovate.json5 +++ b/.github/renovate.json5 @@ -2,6 +2,7 @@ $schema: "https://docs.renovatebot.com/renovate-schema.json", extends: ["github>bitwarden/renovate-config"], // Extends our default configuration for pinned dependencies enabledManagers: [ + "cargo", "dockerfile", "docker-compose", "github-actions", @@ -9,6 +10,11 @@ "nuget", ], packageRules: [ + { + groupName: "cargo minor", + matchManagers: ["cargo"], + matchUpdateTypes: ["minor"], + }, { groupName: "dockerfile minor", matchManagers: ["dockerfile"], @@ -35,6 +41,10 @@ matchUpdateTypes: ["patch"], dependencyDashboardApproval: false, }, + { + matchSourceUrls: ["https://github.com/bitwarden/sdk-internal"], + groupName: "sdk-internal", + }, { matchManagers: ["dockerfile", "docker-compose"], commitMessagePrefix: "[deps] BRE:", diff --git a/.github/workflows/_move_edd_db_scripts.yml b/.github/workflows/_move_edd_db_scripts.yml index b38a3e0dff..7e97fa2a07 100644 --- a/.github/workflows/_move_edd_db_scripts.yml +++ b/.github/workflows/_move_edd_db_scripts.yml @@ -41,18 +41,19 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: token: ${{ steps.retrieve-secrets.outputs.github-pat-bitwarden-devops-bot-repo-scope }} + persist-credentials: false - name: Get script prefix id: prefix - run: echo "prefix=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT + run: echo "prefix=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT" - name: Check if any files in DB transition or finalization directories id: check-script-existence run: | if [ -f util/Migrator/DbScripts_transition/* -o -f util/Migrator/DbScripts_finalization/* ]; then - echo "copy_edd_scripts=true" >> $GITHUB_OUTPUT + echo "copy_edd_scripts=true" >> "$GITHUB_OUTPUT" else - echo "copy_edd_scripts=false" >> $GITHUB_OUTPUT + echo "copy_edd_scripts=false" >> "$GITHUB_OUTPUT" fi move-scripts: @@ -70,17 +71,18 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: fetch-depth: 0 + persist-credentials: true - name: Generate branch name id: branch_name env: PREFIX: ${{ needs.setup.outputs.migration_filename_prefix }} - run: echo "branch_name=move_edd_db_scripts_$PREFIX" >> $GITHUB_OUTPUT + run: echo "branch_name=move_edd_db_scripts_$PREFIX" >> "$GITHUB_OUTPUT" - name: "Create branch" env: BRANCH: ${{ steps.branch_name.outputs.branch_name }} - run: git switch -c $BRANCH + run: git switch -c "$BRANCH" - name: Move scripts and finalization database schema id: move-files @@ -120,7 +122,7 @@ jobs: # sync finalization schema back to dbo, maintaining structure rsync -r "$src_dir/" "$dest_dir/" - rm -rf $src_dir/* + rm -rf "${src_dir}"/* # Replace any finalization references due to the move find ./src/Sql/dbo -name "*.sql" -type f -exec sed -i \ @@ -131,7 +133,7 @@ jobs: moved_files="$moved_files \n $file" done - echo "moved_files=$moved_files" >> $GITHUB_OUTPUT + echo "moved_files=$moved_files" >> "$GITHUB_OUTPUT" - name: Log in to Azure uses: bitwarden/gh-actions/azure-login@main @@ -162,18 +164,20 @@ jobs: - name: Commit and push changes id: commit + env: + BRANCH_NAME: ${{ steps.branch_name.outputs.branch_name }} run: | git config --local user.email "106330231+bitwarden-devops-bot@users.noreply.github.com" git config --local user.name "bitwarden-devops-bot" if [ -n "$(git status --porcelain)" ]; then git add . git commit -m "Move EDD database scripts" -a - git push -u origin ${{ steps.branch_name.outputs.branch_name }} - echo "pr_needed=true" >> $GITHUB_OUTPUT + git push -u origin "${BRANCH_NAME}" + echo "pr_needed=true" >> "$GITHUB_OUTPUT" else echo "No changes to commit!"; - echo "pr_needed=false" >> $GITHUB_OUTPUT - echo "### :mega: No changes to commit! PR was ommited." >> $GITHUB_STEP_SUMMARY + echo "pr_needed=false" >> "$GITHUB_OUTPUT" + echo "### :mega: No changes to commit! PR was ommited." >> "$GITHUB_STEP_SUMMARY" fi - name: Create PR for ${{ steps.branch_name.outputs.branch_name }} @@ -195,7 +199,7 @@ jobs: Files moved: $(echo -e "$MOVED_FILES") ") - echo "pr_url=${PR_URL}" >> $GITHUB_OUTPUT + echo "pr_url=${PR_URL}" >> "$GITHUB_OUTPUT" - name: Notify Slack about creation of PR if: ${{ steps.commit.outputs.pr_needed == 'true' }} diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 30fcf29206..2d92c68b93 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -28,9 +28,10 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: ref: ${{ github.event.pull_request.head.sha }} + persist-credentials: false - name: Set up .NET - uses: actions/setup-dotnet@67a3573c9a986a3f9c594539f4ab511d57bb3ce9 # v4.3.1 + uses: actions/setup-dotnet@d4c94342e560b34958eacfc5d055d21461ed1c5d # v5.0.0 - name: Verify format run: dotnet format --verify-no-changes @@ -45,8 +46,10 @@ jobs: permissions: security-events: write id-token: write + timeout-minutes: 45 strategy: fail-fast: false + max-parallel: 5 matrix: include: - project_name: Admin @@ -97,30 +100,31 @@ jobs: id: check-secrets run: | has_secrets=${{ secrets.AZURE_CLIENT_ID != '' }} - echo "has_secrets=$has_secrets" >> $GITHUB_OUTPUT + echo "has_secrets=$has_secrets" >> "$GITHUB_OUTPUT" - name: Check out repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: ref: ${{ github.event.pull_request.head.sha }} + persist-credentials: false - name: Check branch to publish env: PUBLISH_BRANCHES: "main,rc,hotfix-rc" id: publish-branch-check run: | - IFS="," read -a publish_branches <<< $PUBLISH_BRANCHES + IFS="," read -a publish_branches <<< "$PUBLISH_BRANCHES" if [[ " ${publish_branches[*]} " =~ " ${GITHUB_REF:11} " ]]; then - echo "is_publish_branch=true" >> $GITHUB_ENV + echo "is_publish_branch=true" >> "$GITHUB_ENV" else - echo "is_publish_branch=false" >> $GITHUB_ENV + echo "is_publish_branch=false" >> "$GITHUB_ENV" fi - name: Set up .NET - uses: actions/setup-dotnet@67a3573c9a986a3f9c594539f4ab511d57bb3ce9 # v4.3.1 + uses: actions/setup-dotnet@d4c94342e560b34958eacfc5d055d21461ed1c5d # v5.0.0 - name: Set up Node - uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0 + uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # v6.0.0 with: cache: "npm" cache-dependency-path: "**/package-lock.json" @@ -157,7 +161,7 @@ jobs: ls -atlh ../../../ - name: Upload project artifact - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 + uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 if: ${{ matrix.dotnet }} with: name: ${{ matrix.project_name }}.zip @@ -209,8 +213,8 @@ jobs: IMAGE_TAG=dev fi - echo "image_tag=$IMAGE_TAG" >> $GITHUB_OUTPUT - echo "### :mega: Docker Image Tag: $IMAGE_TAG" >> $GITHUB_STEP_SUMMARY + echo "image_tag=$IMAGE_TAG" >> "$GITHUB_OUTPUT" + echo "### :mega: Docker Image Tag: $IMAGE_TAG" >> "$GITHUB_STEP_SUMMARY" - name: Set up project name id: setup @@ -218,7 +222,7 @@ jobs: PROJECT_NAME=$(echo "${{ matrix.project_name }}" | awk '{print tolower($0)}') echo "Matrix name: ${{ matrix.project_name }}" echo "PROJECT_NAME: $PROJECT_NAME" - echo "project_name=$PROJECT_NAME" >> $GITHUB_OUTPUT + echo "project_name=$PROJECT_NAME" >> "$GITHUB_OUTPUT" - name: Generate image tags(s) id: image-tags @@ -228,12 +232,12 @@ jobs: SHA: ${{ github.sha }} run: | TAGS="${_AZ_REGISTRY}/${PROJECT_NAME}:${IMAGE_TAG}" - echo "primary_tag=$TAGS" >> $GITHUB_OUTPUT + echo "primary_tag=$TAGS" >> "$GITHUB_OUTPUT" if [[ "${IMAGE_TAG}" == "dev" ]]; then - SHORT_SHA=$(git rev-parse --short ${SHA}) + SHORT_SHA=$(git rev-parse --short "${SHA}") TAGS=$TAGS",${_AZ_REGISTRY}/${PROJECT_NAME}:dev-${SHORT_SHA}" fi - echo "tags=$TAGS" >> $GITHUB_OUTPUT + echo "tags=$TAGS" >> "$GITHUB_OUTPUT" - name: Build Docker image id: build-artifacts @@ -260,16 +264,17 @@ jobs: DIGEST: ${{ steps.build-artifacts.outputs.digest }} TAGS: ${{ steps.image-tags.outputs.tags }} run: | - IFS="," read -a tags <<< "${TAGS}" - images="" - for tag in "${tags[@]}"; do - images+="${tag}@${DIGEST} " + IFS=',' read -r -a tags_array <<< "${TAGS}" + images=() + for tag in "${tags_array[@]}"; do + images+=("${tag}@${DIGEST}") done - cosign sign --yes ${images} + cosign sign --yes ${images[@]} + echo "images=${images[*]}" >> "$GITHUB_OUTPUT" - name: Scan Docker image id: container-scan - uses: anchore/scan-action@2c901ab7378897c01b8efaa2d0c9bf519cc64b9e # v6.2.0 + uses: anchore/scan-action@f6601287cdb1efc985d6b765bbf99cb4c0ac29d8 # v7.0.0 with: image: ${{ steps.image-tags.outputs.primary_tag }} fail-build: false @@ -297,9 +302,10 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: ref: ${{ github.event.pull_request.head.sha }} + persist-credentials: false - name: Set up .NET - uses: actions/setup-dotnet@67a3573c9a986a3f9c594539f4ab511d57bb3ce9 # v4.3.1 + uses: actions/setup-dotnet@d4c94342e560b34958eacfc5d055d21461ed1c5d # v5.0.0 - name: Log in to Azure uses: bitwarden/gh-actions/azure-login@main @@ -309,7 +315,7 @@ jobs: client_id: ${{ secrets.AZURE_CLIENT_ID }} - name: Log in to ACR - production subscription - run: az acr login -n $_AZ_REGISTRY --only-show-errors + run: az acr login -n "$_AZ_REGISTRY" --only-show-errors - name: Make Docker stubs if: | @@ -332,26 +338,26 @@ jobs: STUB_OUTPUT=$(pwd)/docker-stub # Run setup - docker run -i --rm --name setup -v $STUB_OUTPUT/US:/bitwarden $SETUP_IMAGE \ + docker run -i --rm --name setup -v "$STUB_OUTPUT/US:/bitwarden" "$SETUP_IMAGE" \ /app/Setup -stub 1 -install 1 -domain bitwarden.example.com -os lin -cloud-region US - docker run -i --rm --name setup -v $STUB_OUTPUT/EU:/bitwarden $SETUP_IMAGE \ + docker run -i --rm --name setup -v "$STUB_OUTPUT/EU:/bitwarden" "$SETUP_IMAGE" \ /app/Setup -stub 1 -install 1 -domain bitwarden.example.com -os lin -cloud-region EU - sudo chown -R $(whoami):$(whoami) $STUB_OUTPUT + sudo chown -R "$(whoami):$(whoami)" "$STUB_OUTPUT" # Remove extra directories and files - rm -rf $STUB_OUTPUT/US/letsencrypt - rm -rf $STUB_OUTPUT/EU/letsencrypt - rm $STUB_OUTPUT/US/env/uid.env $STUB_OUTPUT/US/config.yml - rm $STUB_OUTPUT/EU/env/uid.env $STUB_OUTPUT/EU/config.yml + rm -rf "$STUB_OUTPUT/US/letsencrypt" + rm -rf "$STUB_OUTPUT/EU/letsencrypt" + rm "$STUB_OUTPUT/US/env/uid.env" "$STUB_OUTPUT/US/config.yml" + rm "$STUB_OUTPUT/EU/env/uid.env" "$STUB_OUTPUT/EU/config.yml" # Create uid environment files - touch $STUB_OUTPUT/US/env/uid.env - touch $STUB_OUTPUT/EU/env/uid.env + touch "$STUB_OUTPUT/US/env/uid.env" + touch "$STUB_OUTPUT/EU/env/uid.env" # Zip up the Docker stub files - cd docker-stub/US; zip -r ../../docker-stub-US.zip *; cd ../.. - cd docker-stub/EU; zip -r ../../docker-stub-EU.zip *; cd ../.. + cd docker-stub/US; zip -r ../../docker-stub-US.zip ./*; cd ../.. + cd docker-stub/EU; zip -r ../../docker-stub-EU.zip ./*; cd ../.. - name: Log out from Azure uses: bitwarden/gh-actions/azure-logout@main @@ -360,7 +366,7 @@ jobs: if: | github.event_name != 'pull_request' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/rc' || github.ref == 'refs/heads/hotfix-rc') - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 + uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 with: name: docker-stub-US.zip path: docker-stub-US.zip @@ -370,7 +376,7 @@ jobs: if: | github.event_name != 'pull_request' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/rc' || github.ref == 'refs/heads/hotfix-rc') - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 + uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 with: name: docker-stub-EU.zip path: docker-stub-EU.zip @@ -382,21 +388,21 @@ jobs: pwsh ./generate_openapi_files.ps1 - name: Upload Public API Swagger artifact - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 + uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 with: name: swagger.json path: api.public.json if-no-files-found: error - name: Upload Internal API Swagger artifact - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 + uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 with: name: internal.json path: api.json if-no-files-found: error - name: Upload Identity Swagger artifact - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 + uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 with: name: identity.json path: identity.json @@ -423,9 +429,10 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: ref: ${{ github.event.pull_request.head.sha }} + persist-credentials: false - name: Set up .NET - uses: actions/setup-dotnet@67a3573c9a986a3f9c594539f4ab511d57bb3ce9 # v4.3.1 + uses: actions/setup-dotnet@d4c94342e560b34958eacfc5d055d21461ed1c5d # v5.0.0 - name: Print environment run: | @@ -441,7 +448,7 @@ jobs: - name: Upload project artifact for Windows if: ${{ contains(matrix.target, 'win') == true }} - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 + uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 with: name: MsSqlMigratorUtility-${{ matrix.target }} path: util/MsSqlMigratorUtility/obj/build-output/publish/MsSqlMigratorUtility.exe @@ -449,7 +456,7 @@ jobs: - name: Upload project artifact if: ${{ contains(matrix.target, 'win') == false }} - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 + uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 with: name: MsSqlMigratorUtility-${{ matrix.target }} path: util/MsSqlMigratorUtility/obj/build-output/publish/MsSqlMigratorUtility @@ -484,7 +491,7 @@ jobs: uses: bitwarden/gh-actions/azure-logout@main - name: Trigger self-host build - uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: github-token: ${{ steps.retrieve-secret-pat.outputs.github-pat-bitwarden-devops-bot-repo-scope }} script: | @@ -525,7 +532,7 @@ jobs: uses: bitwarden/gh-actions/azure-logout@main - name: Trigger k8s deploy - uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: github-token: ${{ steps.retrieve-secret-pat.outputs.github-pat-bitwarden-devops-bot-repo-scope }} script: | diff --git a/.github/workflows/cleanup-after-pr.yml b/.github/workflows/cleanup-after-pr.yml index e39bf8ea3a..4e59f1fa96 100644 --- a/.github/workflows/cleanup-after-pr.yml +++ b/.github/workflows/cleanup-after-pr.yml @@ -22,7 +22,7 @@ jobs: client_id: ${{ secrets.AZURE_CLIENT_ID }} - name: Log in to Azure ACR - run: az acr login -n $_AZ_REGISTRY --only-show-errors + run: az acr login -n "$_AZ_REGISTRY" --only-show-errors ########## Remove Docker images ########## - name: Remove the Docker image from ACR @@ -45,20 +45,20 @@ jobs: - Setup - Sso run: | - for SERVICE in $(echo "${{ env.SERVICES }}" | yq e ".services[]" - ) + for SERVICE in $(echo "${SERVICES}" | yq e ".services[]" - ) do - SERVICE_NAME=$(echo $SERVICE | awk '{print tolower($0)}') + SERVICE_NAME=$(echo "$SERVICE" | awk '{print tolower($0)}') IMAGE_TAG=$(echo "${REF}" | sed "s#/#-#g") # slash safe branch name echo "[*] Checking if remote exists: $_AZ_REGISTRY/$SERVICE_NAME:$IMAGE_TAG" TAG_EXISTS=$( - az acr repository show-tags --name $_AZ_REGISTRY --repository $SERVICE_NAME \ - | jq --arg $TAG "$IMAGE_TAG" -e '. | any(. == "$TAG")' + az acr repository show-tags --name "$_AZ_REGISTRY" --repository "$SERVICE_NAME" \ + | jq --arg TAG "$IMAGE_TAG" -e '. | any(. == $TAG)' ) if [[ "$TAG_EXISTS" == "true" ]]; then echo "[*] Tag exists. Removing tag" - az acr repository delete --name $_AZ_REGISTRY --image $SERVICE_NAME:$IMAGE_TAG --yes + az acr repository delete --name "$_AZ_REGISTRY" --image "$SERVICE_NAME:$IMAGE_TAG" --yes else echo "[*] Tag does not exist. No action needed" fi diff --git a/.github/workflows/cleanup-rc-branch.yml b/.github/workflows/cleanup-rc-branch.yml index 5c74284423..63079826c7 100644 --- a/.github/workflows/cleanup-rc-branch.yml +++ b/.github/workflows/cleanup-rc-branch.yml @@ -35,6 +35,8 @@ jobs: with: ref: main token: ${{ steps.retrieve-bot-secrets.outputs.github-pat-bitwarden-devops-bot-repo-scope }} + persist-credentials: false + fetch-depth: 0 - name: Check if a RC branch exists id: branch-check @@ -43,11 +45,11 @@ jobs: rc_branch_check=$(git ls-remote --heads origin rc | wc -l) if [[ "${hotfix_rc_branch_check}" -gt 0 ]]; then - echo "hotfix-rc branch exists." | tee -a $GITHUB_STEP_SUMMARY - echo "name=hotfix-rc" >> $GITHUB_OUTPUT + echo "hotfix-rc branch exists." | tee -a "$GITHUB_STEP_SUMMARY" + echo "name=hotfix-rc" >> "$GITHUB_OUTPUT" elif [[ "${rc_branch_check}" -gt 0 ]]; then - echo "rc branch exists." | tee -a $GITHUB_STEP_SUMMARY - echo "name=rc" >> $GITHUB_OUTPUT + echo "rc branch exists." | tee -a "$GITHUB_STEP_SUMMARY" + echo "name=rc" >> "$GITHUB_OUTPUT" fi - name: Delete RC branch @@ -55,6 +57,6 @@ jobs: BRANCH_NAME: ${{ steps.branch-check.outputs.name }} run: | if ! [[ -z "$BRANCH_NAME" ]]; then - git push --quiet origin --delete $BRANCH_NAME - echo "Deleted $BRANCH_NAME branch." | tee -a $GITHUB_STEP_SUMMARY + git push --quiet origin --delete "$BRANCH_NAME" + echo "Deleted $BRANCH_NAME branch." | tee -a "$GITHUB_STEP_SUMMARY" fi diff --git a/.github/workflows/code-references.yml b/.github/workflows/code-references.yml index 75e0c43306..35e6cfdd40 100644 --- a/.github/workflows/code-references.yml +++ b/.github/workflows/code-references.yml @@ -19,9 +19,9 @@ jobs: id: check-secret-access run: | if [ "${{ secrets.AZURE_CLIENT_ID }}" != '' ]; then - echo "available=true" >> $GITHUB_OUTPUT; + echo "available=true" >> "$GITHUB_OUTPUT"; else - echo "available=false" >> $GITHUB_OUTPUT; + echo "available=false" >> "$GITHUB_OUTPUT"; fi refs: @@ -37,6 +37,8 @@ jobs: steps: - name: Check out repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Log in to Azure uses: bitwarden/gh-actions/azure-login@main @@ -65,14 +67,14 @@ jobs: - name: Add label if: steps.collect.outputs.any-changed == 'true' - run: gh pr edit $PR_NUMBER --add-label feature-flag + run: gh pr edit "$PR_NUMBER" --add-label feature-flag env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} PR_NUMBER: ${{ github.event.pull_request.number }} - name: Remove label if: steps.collect.outputs.any-changed == 'false' - run: gh pr edit $PR_NUMBER --remove-label feature-flag + run: gh pr edit "$PR_NUMBER" --remove-label feature-flag env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} PR_NUMBER: ${{ github.event.pull_request.number }} diff --git a/.github/workflows/enforce-labels.yml b/.github/workflows/enforce-labels.yml index 353127c751..1759b29787 100644 --- a/.github/workflows/enforce-labels.yml +++ b/.github/workflows/enforce-labels.yml @@ -17,5 +17,5 @@ jobs: - name: Check for label run: | echo "PRs with the hold, needs-qa or ephemeral-environment labels cannot be merged" - echo "### :x: PRs with the hold, needs-qa or ephemeral-environment labels cannot be merged" >> $GITHUB_STEP_SUMMARY + echo "### :x: PRs with the hold, needs-qa or ephemeral-environment labels cannot be merged" >> "$GITHUB_STEP_SUMMARY" exit 1 diff --git a/.github/workflows/ephemeral-environment.yml b/.github/workflows/ephemeral-environment.yml index d85fcf2fd4..456ca573cc 100644 --- a/.github/workflows/ephemeral-environment.yml +++ b/.github/workflows/ephemeral-environment.yml @@ -16,5 +16,5 @@ jobs: with: project: server pull_request_number: ${{ github.event.number }} - sync_environment: true + sync_environment: false secrets: inherit diff --git a/.github/workflows/load-test.yml b/.github/workflows/load-test.yml index 9bc6da89e7..cdb53109f5 100644 --- a/.github/workflows/load-test.yml +++ b/.github/workflows/load-test.yml @@ -63,13 +63,15 @@ jobs: # Datadog agent for collecting OTEL metrics from k6 - name: Start Datadog agent + env: + DD_API_KEY: ${{ steps.get-kv-secrets.outputs.DD-API-KEY }} run: | docker run --detach \ --name datadog-agent \ -p 4317:4317 \ -p 5555:5555 \ -e DD_SITE=us3.datadoghq.com \ - -e DD_API_KEY=${{ steps.get-kv-secrets.outputs.DD-API-KEY }} \ + -e DD_API_KEY="${DD_API_KEY}" \ -e DD_DOGSTATSD_NON_LOCAL_TRAFFIC=1 \ -e DD_OTLP_CONFIG_RECEIVER_PROTOCOLS_GRPC_ENDPOINT=0.0.0.0:4317 \ -e DD_HEALTH_PORT=5555 \ diff --git a/.github/workflows/protect-files.yml b/.github/workflows/protect-files.yml index 546b8344a6..a939be6fdb 100644 --- a/.github/workflows/protect-files.yml +++ b/.github/workflows/protect-files.yml @@ -34,6 +34,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: fetch-depth: 2 + persist-credentials: false - name: Check for file changes id: check-changes @@ -43,9 +44,9 @@ jobs: for file in $MODIFIED_FILES do if [[ $file == *"${{ matrix.path }}"* ]]; then - echo "changes_detected=true" >> $GITHUB_OUTPUT + echo "changes_detected=true" >> "$GITHUB_OUTPUT" break - else echo "changes_detected=false" >> $GITHUB_OUTPUT + else echo "changes_detected=false" >> "$GITHUB_OUTPUT" fi done diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 444c2289d1..2272387d84 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -36,21 +36,23 @@ jobs: steps: - name: Version output id: version-output + env: + INPUT_VERSION: ${{ inputs.version }} run: | - if [[ "${{ inputs.version }}" == "latest" || "${{ inputs.version }}" == "" ]]; then + if [[ "${INPUT_VERSION}" == "latest" || "${INPUT_VERSION}" == "" ]]; then VERSION=$(curl "https://api.github.com/repos/bitwarden/server/releases" | jq -c '.[] | select(.tag_name) | .tag_name' | head -1 | grep -ohE '20[0-9]{2}\.([1-9]|1[0-2])\.[0-9]+') echo "Latest Released Version: $VERSION" - echo "version=$VERSION" >> $GITHUB_OUTPUT + echo "version=$VERSION" >> "$GITHUB_OUTPUT" else - echo "Release Version: ${{ inputs.version }}" - echo "version=${{ inputs.version }}" >> $GITHUB_OUTPUT + echo "Release Version: ${INPUT_VERSION}" + echo "version=${INPUT_VERSION}" >> "$GITHUB_OUTPUT" fi - name: Get branch name id: branch run: | - BRANCH_NAME=$(basename ${{ github.ref }}) - echo "branch-name=$BRANCH_NAME" >> $GITHUB_OUTPUT + BRANCH_NAME=$(basename "${GITHUB_REF}") + echo "branch-name=$BRANCH_NAME" >> "$GITHUB_OUTPUT" - name: Create GitHub deployment uses: chrnorm/deployment-action@55729fcebec3d284f60f5bcabbd8376437d696b1 # v2.0.7 @@ -105,6 +107,9 @@ jobs: - name: Check out repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + persist-credentials: false - name: Set up project name id: setup @@ -112,7 +117,7 @@ jobs: PROJECT_NAME=$(echo "${{ matrix.project_name }}" | awk '{print tolower($0)}') echo "Matrix name: ${{ matrix.project_name }}" echo "PROJECT_NAME: $PROJECT_NAME" - echo "project_name=$PROJECT_NAME" >> $GITHUB_OUTPUT + echo "project_name=$PROJECT_NAME" >> "$GITHUB_OUTPUT" ########## ACR PROD ########## - name: Log in to Azure @@ -123,16 +128,16 @@ jobs: client_id: ${{ secrets.AZURE_CLIENT_ID }} - name: Log in to Azure ACR - run: az acr login -n $_AZ_REGISTRY --only-show-errors + run: az acr login -n "$_AZ_REGISTRY" --only-show-errors - name: Pull latest project image env: PROJECT_NAME: ${{ steps.setup.outputs.project_name }} run: | if [[ "${{ inputs.publish_type }}" == "Dry Run" ]]; then - docker pull $_AZ_REGISTRY/$PROJECT_NAME:latest + docker pull "$_AZ_REGISTRY/$PROJECT_NAME:latest" else - docker pull $_AZ_REGISTRY/$PROJECT_NAME:$_BRANCH_NAME + docker pull "$_AZ_REGISTRY/$PROJECT_NAME:$_BRANCH_NAME" fi - name: Tag version and latest @@ -140,10 +145,10 @@ jobs: PROJECT_NAME: ${{ steps.setup.outputs.project_name }} run: | if [[ "${{ inputs.publish_type }}" == "Dry Run" ]]; then - docker tag $_AZ_REGISTRY/$PROJECT_NAME:latest $_AZ_REGISTRY/$PROJECT_NAME:dryrun + docker tag "$_AZ_REGISTRY/$PROJECT_NAME:latest" "$_AZ_REGISTRY/$PROJECT_NAME:dryrun" else - docker tag $_AZ_REGISTRY/$PROJECT_NAME:$_BRANCH_NAME $_AZ_REGISTRY/$PROJECT_NAME:$_RELEASE_VERSION - docker tag $_AZ_REGISTRY/$PROJECT_NAME:$_BRANCH_NAME $_AZ_REGISTRY/$PROJECT_NAME:latest + docker tag "$_AZ_REGISTRY/$PROJECT_NAME:$_BRANCH_NAME" "$_AZ_REGISTRY/$PROJECT_NAME:$_RELEASE_VERSION" + docker tag "$_AZ_REGISTRY/$PROJECT_NAME:$_BRANCH_NAME" "$_AZ_REGISTRY/$PROJECT_NAME:latest" fi - name: Push version and latest image @@ -151,10 +156,10 @@ jobs: PROJECT_NAME: ${{ steps.setup.outputs.project_name }} run: | if [[ "${{ inputs.publish_type }}" == "Dry Run" ]]; then - docker push $_AZ_REGISTRY/$PROJECT_NAME:dryrun + docker push "$_AZ_REGISTRY/$PROJECT_NAME:dryrun" else - docker push $_AZ_REGISTRY/$PROJECT_NAME:$_RELEASE_VERSION - docker push $_AZ_REGISTRY/$PROJECT_NAME:latest + docker push "$_AZ_REGISTRY/$PROJECT_NAME:$_RELEASE_VERSION" + docker push "$_AZ_REGISTRY/$PROJECT_NAME:latest" fi - name: Log out of Docker diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8bb19b4da1..75b4df4e5c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -40,6 +40,9 @@ jobs: - name: Check out repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + persist-credentials: false - name: Check release version id: version @@ -52,8 +55,8 @@ jobs: - name: Get branch name id: branch run: | - BRANCH_NAME=$(basename ${{ github.ref }}) - echo "branch-name=$BRANCH_NAME" >> $GITHUB_OUTPUT + BRANCH_NAME=$(basename "${GITHUB_REF}") + echo "branch-name=$BRANCH_NAME" >> "$GITHUB_OUTPUT" release: name: Create GitHub release diff --git a/.github/workflows/repository-management.yml b/.github/workflows/repository-management.yml index 67e1d8a926..92452102cf 100644 --- a/.github/workflows/repository-management.yml +++ b/.github/workflows/repository-management.yml @@ -46,7 +46,7 @@ jobs: BRANCH="hotfix-rc" fi - echo "branch=$BRANCH" >> $GITHUB_OUTPUT + echo "branch=$BRANCH" >> "$GITHUB_OUTPUT" bump_version: name: Bump Version @@ -95,6 +95,7 @@ jobs: with: ref: main token: ${{ steps.app-token.outputs.token }} + persist-credentials: true - name: Configure Git run: | @@ -110,7 +111,7 @@ jobs: id: current-version run: | CURRENT_VERSION=$(xmllint -xpath "/Project/PropertyGroup/Version/text()" Directory.Build.props) - echo "version=$CURRENT_VERSION" >> $GITHUB_OUTPUT + echo "version=$CURRENT_VERSION" >> "$GITHUB_OUTPUT" - name: Verify input version if: ${{ inputs.version_number_override != '' }} @@ -120,16 +121,15 @@ jobs: run: | # Error if version has not changed. if [[ "$NEW_VERSION" == "$CURRENT_VERSION" ]]; then - echo "Specified override version is the same as the current version." >> $GITHUB_STEP_SUMMARY + echo "Specified override version is the same as the current version." >> "$GITHUB_STEP_SUMMARY" exit 1 fi # Check if version is newer. - printf '%s\n' "${CURRENT_VERSION}" "${NEW_VERSION}" | sort -C -V - if [ $? -eq 0 ]; then + if printf '%s\n' "${CURRENT_VERSION}" "${NEW_VERSION}" | sort -C -V; then echo "Version is newer than the current version." else - echo "Version is older than the current version." >> $GITHUB_STEP_SUMMARY + echo "Version is older than the current version." >> "$GITHUB_STEP_SUMMARY" exit 1 fi @@ -160,15 +160,20 @@ jobs: id: set-final-version-output env: VERSION: ${{ inputs.version_number_override }} + BUMP_VERSION_OVERRIDE_OUTCOME: ${{ steps.bump-version-override.outcome }} + BUMP_VERSION_AUTOMATIC_OUTCOME: ${{ steps.bump-version-automatic.outcome }} + CALCULATE_NEXT_VERSION: ${{ steps.calculate-next-version.outputs.version }} run: | - if [[ "${{ steps.bump-version-override.outcome }}" = "success" ]]; then - echo "version=$VERSION" >> $GITHUB_OUTPUT - elif [[ "${{ steps.bump-version-automatic.outcome }}" = "success" ]]; then - echo "version=${{ steps.calculate-next-version.outputs.version }}" >> $GITHUB_OUTPUT + if [[ "${BUMP_VERSION_OVERRIDE_OUTCOME}" = "success" ]]; then + echo "version=$VERSION" >> "$GITHUB_OUTPUT" + elif [[ "${BUMP_VERSION_AUTOMATIC_OUTCOME}" = "success" ]]; then + echo "version=${CALCULATE_NEXT_VERSION}" >> "$GITHUB_OUTPUT" fi - name: Commit files - run: git commit -m "Bumped version to ${{ steps.set-final-version-output.outputs.version }}" -a + env: + FINAL_VERSION: ${{ steps.set-final-version-output.outputs.version }} + run: git commit -m "Bumped version to $FINAL_VERSION" -a - name: Push changes run: git push @@ -213,13 +218,15 @@ jobs: with: ref: ${{ inputs.target_ref }} token: ${{ steps.app-token.outputs.token }} + persist-credentials: true + fetch-depth: 0 - name: Check if ${{ needs.setup.outputs.branch }} branch exists env: BRANCH_NAME: ${{ needs.setup.outputs.branch }} run: | - if [[ $(git ls-remote --heads origin $BRANCH_NAME) ]]; then - echo "$BRANCH_NAME already exists! Please delete $BRANCH_NAME before running again." >> $GITHUB_STEP_SUMMARY + if [[ $(git ls-remote --heads origin "$BRANCH_NAME") ]]; then + echo "$BRANCH_NAME already exists! Please delete $BRANCH_NAME before running again." >> "$GITHUB_STEP_SUMMARY" exit 1 fi @@ -227,8 +234,8 @@ jobs: env: BRANCH_NAME: ${{ needs.setup.outputs.branch }} run: | - git switch --quiet --create $BRANCH_NAME - git push --quiet --set-upstream origin $BRANCH_NAME + git switch --quiet --create "$BRANCH_NAME" + git push --quiet --set-upstream origin "$BRANCH_NAME" move_edd_db_scripts: name: Move EDD database scripts diff --git a/.github/workflows/respond.yml b/.github/workflows/respond.yml new file mode 100644 index 0000000000..d940ceee75 --- /dev/null +++ b/.github/workflows/respond.yml @@ -0,0 +1,28 @@ +name: Respond + +on: + issue_comment: + types: [created] + pull_request_review_comment: + types: [created] + issues: + types: [opened, assigned] + pull_request_review: + types: [submitted] + +permissions: {} + +jobs: + respond: + name: Respond + uses: bitwarden/gh-actions/.github/workflows/_respond.yml@main + secrets: + AZURE_SUBSCRIPTION_ID: ${{ secrets.AZURE_SUBSCRIPTION_ID }} + AZURE_TENANT_ID: ${{ secrets.AZURE_TENANT_ID }} + AZURE_CLIENT_ID: ${{ secrets.AZURE_CLIENT_ID }} + permissions: + actions: read + contents: write + id-token: write + issues: write + pull-requests: write diff --git a/.github/workflows/review-code.yml b/.github/workflows/review-code.yml new file mode 100644 index 0000000000..0e0597fccf --- /dev/null +++ b/.github/workflows/review-code.yml @@ -0,0 +1,21 @@ +name: Code Review + +on: + pull_request: + types: [opened, synchronize, reopened, ready_for_review] + +permissions: {} + +jobs: + review: + name: Review + uses: bitwarden/gh-actions/.github/workflows/_review-code.yml@main + secrets: + AZURE_SUBSCRIPTION_ID: ${{ secrets.AZURE_SUBSCRIPTION_ID }} + AZURE_TENANT_ID: ${{ secrets.AZURE_TENANT_ID }} + AZURE_CLIENT_ID: ${{ secrets.AZURE_CLIENT_ID }} + permissions: + actions: read + contents: read + id-token: write + pull-requests: write diff --git a/.github/workflows/test-database.yml b/.github/workflows/test-database.yml index 6bbc33299f..fb1c18b158 100644 --- a/.github/workflows/test-database.yml +++ b/.github/workflows/test-database.yml @@ -45,9 +45,11 @@ jobs: steps: - name: Check out repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Set up .NET - uses: actions/setup-dotnet@67a3573c9a986a3f9c594539f4ab511d57bb3ce9 # v4.3.1 + uses: actions/setup-dotnet@d4c94342e560b34958eacfc5d055d21461ed1c5d # v5.0.0 - name: Restore tools run: dotnet tool restore @@ -139,26 +141,26 @@ jobs: - name: Print MySQL Logs if: failure() - run: 'docker logs $(docker ps --quiet --filter "name=mysql")' + run: 'docker logs "$(docker ps --quiet --filter "name=mysql")"' - name: Print MariaDB Logs if: failure() - run: 'docker logs $(docker ps --quiet --filter "name=mariadb")' + run: 'docker logs "$(docker ps --quiet --filter "name=mariadb")"' - name: Print Postgres Logs if: failure() - run: 'docker logs $(docker ps --quiet --filter "name=postgres")' + run: 'docker logs "$(docker ps --quiet --filter "name=postgres")"' - name: Print MSSQL Logs if: failure() - run: 'docker logs $(docker ps --quiet --filter "name=mssql")' + run: 'docker logs "$(docker ps --quiet --filter "name=mssql")"' - name: Report test results uses: dorny/test-reporter@890a17cecf52a379fc869ab770a71657660be727 # v2.1.0 if: ${{ github.event.pull_request.head.repo.full_name == github.repository && !cancelled() }} with: name: Test Results - path: "**/*-test-results.trx" + path: "./**/*-test-results.trx" reporter: dotnet-trx fail-on-error: true @@ -177,9 +179,11 @@ jobs: steps: - name: Check out repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Set up .NET - uses: actions/setup-dotnet@67a3573c9a986a3f9c594539f4ab511d57bb3ce9 # v4.3.1 + uses: actions/setup-dotnet@d4c94342e560b34958eacfc5d055d21461ed1c5d # v5.0.0 - name: Print environment run: | @@ -193,7 +197,7 @@ jobs: shell: pwsh - name: Upload DACPAC - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 + uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 with: name: sql.dacpac path: Sql.dacpac @@ -219,7 +223,7 @@ jobs: shell: pwsh - name: Report validation results - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 + uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 with: name: report.xml path: | diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4eed6df7ab..36ab8785d5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -28,9 +28,19 @@ jobs: steps: - name: Check out repo uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + persist-credentials: false - name: Set up .NET - uses: actions/setup-dotnet@67a3573c9a986a3f9c594539f4ab511d57bb3ce9 # v4.3.1 + uses: actions/setup-dotnet@d4c94342e560b34958eacfc5d055d21461ed1c5d # v5.0.0 + + - name: Install rust + uses: dtolnay/rust-toolchain@b3b07ba8b418998c39fb20f53e8b695cdcc8de1b # stable + with: + toolchain: stable + + - name: Cache cargo registry + uses: Swatinem/rust-cache@f0deed1e0edfc6a9be95417288c0e1099b1eeec3 # v2.7.7 - name: Print environment run: | diff --git a/.gitignore b/.gitignore index 3b1f40e673..60fc894285 100644 --- a/.gitignore +++ b/.gitignore @@ -215,6 +215,9 @@ bitwarden_license/src/Sso/wwwroot/assets **/**.swp .mono src/Core/MailTemplates/Mjml/out +src/Core/MailTemplates/Mjml/out-hbs +NativeMethods.g.cs +util/RustSdk/rust/target src/Admin/Admin.zip src/Api/Api.zip @@ -231,3 +234,6 @@ bitwarden_license/src/Sso/Sso.zip /identity.json /api.json /api.public.json + +# Serena +.serena/ diff --git a/Directory.Build.props b/Directory.Build.props index 71303d3529..4511202024 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -3,12 +3,10 @@ net8.0 - 2025.9.2 + 2025.11.0 Bit.$(MSBuildProjectName) enable - false - true annotations enable @@ -32,19 +30,4 @@ 4.18.1 - - - - - - - - - - - <_Parameter1>GitHash - <_Parameter2>$(SourceRevisionId) - - - \ No newline at end of file diff --git a/bitwarden-server.sln b/bitwarden-server.sln index dbc37372a1..6786ad610c 100644 --- a/bitwarden-server.sln +++ b/bitwarden-server.sln @@ -133,8 +133,11 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Seeder", "util\Seeder\Seede EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "DbSeederUtility", "util\DbSeederUtility\DbSeederUtility.csproj", "{17A89266-260A-4A03-81AE-C0468C6EE06E}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "RustSdk", "util\RustSdk\RustSdk.csproj", "{D1513D90-E4F5-44A9-9121-5E46E3E4A3F7}" Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SharedWeb.Test", "test\SharedWeb.Test\SharedWeb.Test.csproj", "{AD59537D-5259-4B7A-948F-0CF58E80B359}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SSO.Test", "bitwarden_license\test\SSO.Test\SSO.Test.csproj", "{7D98784C-C253-43FB-9873-25B65C6250D6}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -339,10 +342,18 @@ Global {17A89266-260A-4A03-81AE-C0468C6EE06E}.Debug|Any CPU.Build.0 = Debug|Any CPU {17A89266-260A-4A03-81AE-C0468C6EE06E}.Release|Any CPU.ActiveCfg = Release|Any CPU {17A89266-260A-4A03-81AE-C0468C6EE06E}.Release|Any CPU.Build.0 = Release|Any CPU + {D1513D90-E4F5-44A9-9121-5E46E3E4A3F7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {D1513D90-E4F5-44A9-9121-5E46E3E4A3F7}.Debug|Any CPU.Build.0 = Debug|Any CPU + {D1513D90-E4F5-44A9-9121-5E46E3E4A3F7}.Release|Any CPU.ActiveCfg = Release|Any CPU + {D1513D90-E4F5-44A9-9121-5E46E3E4A3F7}.Release|Any CPU.Build.0 = Release|Any CPU {AD59537D-5259-4B7A-948F-0CF58E80B359}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {AD59537D-5259-4B7A-948F-0CF58E80B359}.Debug|Any CPU.Build.0 = Debug|Any CPU {AD59537D-5259-4B7A-948F-0CF58E80B359}.Release|Any CPU.ActiveCfg = Release|Any CPU {AD59537D-5259-4B7A-948F-0CF58E80B359}.Release|Any CPU.Build.0 = Release|Any CPU + {7D98784C-C253-43FB-9873-25B65C6250D6}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {7D98784C-C253-43FB-9873-25B65C6250D6}.Debug|Any CPU.Build.0 = Debug|Any CPU + {7D98784C-C253-43FB-9873-25B65C6250D6}.Release|Any CPU.ActiveCfg = Release|Any CPU + {7D98784C-C253-43FB-9873-25B65C6250D6}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -397,7 +408,9 @@ Global {3631BA42-6731-4118-A917-DAA43C5032B9} = {DD5BD056-4AAE-43EF-BBD2-0B569B8DA84F} {9A612EBA-1C0E-42B8-982B-62F0EE81000A} = {DD5BD056-4AAE-43EF-BBD2-0B569B8DA84E} {17A89266-260A-4A03-81AE-C0468C6EE06E} = {DD5BD056-4AAE-43EF-BBD2-0B569B8DA84E} + {D1513D90-E4F5-44A9-9121-5E46E3E4A3F7} = {DD5BD056-4AAE-43EF-BBD2-0B569B8DA84E} {AD59537D-5259-4B7A-948F-0CF58E80B359} = {DD5BD056-4AAE-43EF-BBD2-0B569B8DA84F} + {7D98784C-C253-43FB-9873-25B65C6250D6} = {287CFF34-BBDB-4BC4-AF88-1E19A5A4679B} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {E01CBF68-2E20-425F-9EDB-E0A6510CA92F} diff --git a/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs b/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs index 9ade2d660a..994b305349 100644 --- a/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs +++ b/bitwarden_license/src/Commercial.Core/AdminConsole/Providers/RemoveOrganizationFromProviderCommand.cs @@ -148,22 +148,30 @@ public class RemoveOrganizationFromProviderCommand : IRemoveOrganizationFromProv } else if (organization.IsStripeEnabled()) { - var subscription = await _stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId); + var subscription = await _stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId, new SubscriptionGetOptions + { + Expand = ["customer"] + }); + if (subscription.Status is StripeConstants.SubscriptionStatus.Canceled or StripeConstants.SubscriptionStatus.IncompleteExpired) { return; } - await _stripeAdapter.CustomerUpdateAsync(organization.GatewayCustomerId, new CustomerUpdateOptions + await _stripeAdapter.CustomerUpdateAsync(subscription.CustomerId, new CustomerUpdateOptions { - Coupon = string.Empty, Email = organization.BillingEmail }); + if (subscription.Customer.Discount?.Coupon != null) + { + await _stripeAdapter.CustomerDeleteDiscountAsync(subscription.CustomerId); + } + await _stripeAdapter.SubscriptionUpdateAsync(organization.GatewaySubscriptionId, new SubscriptionUpdateOptions { CollectionMethod = StripeConstants.CollectionMethod.SendInvoice, - DaysUntilDue = 30 + DaysUntilDue = 30, }); await _subscriberService.RemovePaymentSource(organization); diff --git a/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs b/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs index aa19ad5382..89ef251fd6 100644 --- a/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs +++ b/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs @@ -12,7 +12,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models; +using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Providers.Services; using Bit.Core.Context; @@ -35,8 +35,9 @@ public class ProviderService : IProviderService { private static readonly PlanType[] _resellerDisallowedOrganizationTypes = [ PlanType.Free, - PlanType.FamiliesAnnually, - PlanType.FamiliesAnnually2019 + PlanType.FamiliesAnnually2025, + PlanType.FamiliesAnnually2019, + PlanType.FamiliesAnnually ]; private readonly IDataProtector _dataProtector; @@ -90,7 +91,7 @@ public class ProviderService : IProviderService _providerClientOrganizationSignUpCommand = providerClientOrganizationSignUpCommand; } - public async Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo, TokenizedPaymentSource tokenizedPaymentSource = null) + public async Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TokenizedPaymentMethod paymentMethod, BillingAddress billingAddress) { var owner = await _userService.GetUserByIdAsync(ownerUserId); if (owner == null) @@ -115,21 +116,7 @@ public class ProviderService : IProviderService throw new BadRequestException("Invalid owner."); } - if (taxInfo == null || string.IsNullOrEmpty(taxInfo.BillingAddressCountry) || string.IsNullOrEmpty(taxInfo.BillingAddressPostalCode)) - { - throw new BadRequestException("Both address and postal code are required to set up your provider."); - } - - if (tokenizedPaymentSource is not - { - Type: PaymentMethodType.BankAccount or PaymentMethodType.Card or PaymentMethodType.PayPal, - Token: not null and not "" - }) - { - throw new BadRequestException("A payment method is required to set up your provider."); - } - - var customer = await _providerBillingService.SetupCustomer(provider, taxInfo, tokenizedPaymentSource); + var customer = await _providerBillingService.SetupCustomer(provider, paymentMethod, billingAddress); provider.GatewayCustomerId = customer.Id; var subscription = await _providerBillingService.SetupSubscription(provider); provider.GatewaySubscriptionId = subscription.Id; diff --git a/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs b/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs index 398674c7b6..e352297f1e 100644 --- a/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs +++ b/bitwarden_license/src/Commercial.Core/Billing/Providers/Services/ProviderBillingService.cs @@ -14,6 +14,7 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; +using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Providers.Entities; using Bit.Core.Billing.Providers.Models; @@ -21,10 +22,8 @@ using Bit.Core.Billing.Providers.Repositories; using Bit.Core.Billing.Providers.Services; using Bit.Core.Billing.Services; using Bit.Core.Billing.Tax.Models; -using Bit.Core.Billing.Tax.Services; using Bit.Core.Enums; using Bit.Core.Exceptions; -using Bit.Core.Models.Business; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; @@ -38,6 +37,9 @@ using Subscription = Stripe.Subscription; namespace Bit.Commercial.Core.Billing.Providers.Services; +using static Constants; +using static StripeConstants; + public class ProviderBillingService( IBraintreeGateway braintreeGateway, IEventService eventService, @@ -51,8 +53,7 @@ public class ProviderBillingService( IProviderUserRepository providerUserRepository, ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, - ISubscriberService subscriberService, - ITaxService taxService) + ISubscriberService subscriberService) : IProviderBillingService { public async Task AddExistingOrganization( @@ -61,10 +62,7 @@ public class ProviderBillingService( string key) { await stripeAdapter.SubscriptionUpdateAsync(organization.GatewaySubscriptionId, - new SubscriptionUpdateOptions - { - CancelAtPeriodEnd = false - }); + new SubscriptionUpdateOptions { CancelAtPeriodEnd = false }); var subscription = await stripeAdapter.SubscriptionCancelAsync(organization.GatewaySubscriptionId, @@ -83,7 +81,7 @@ public class ProviderBillingService( var wasTrialing = subscription.TrialEnd.HasValue && subscription.TrialEnd.Value > now; - if (!wasTrialing && subscription.LatestInvoice.Status == StripeConstants.InvoiceStatus.Draft) + if (!wasTrialing && subscription.LatestInvoice.Status == InvoiceStatus.Draft) { await stripeAdapter.InvoiceFinalizeInvoiceAsync(subscription.LatestInvoiceId, new InvoiceFinalizeOptions { AutoAdvance = true }); @@ -184,16 +182,8 @@ public class ProviderBillingService( { Items = [ - new SubscriptionItemOptions - { - Price = newPriceId, - Quantity = oldSubscriptionItem!.Quantity - }, - new SubscriptionItemOptions - { - Id = oldSubscriptionItem.Id, - Deleted = true - } + new SubscriptionItemOptions { Price = newPriceId, Quantity = oldSubscriptionItem!.Quantity }, + new SubscriptionItemOptions { Id = oldSubscriptionItem.Id, Deleted = true } ] }; @@ -202,7 +192,8 @@ public class ProviderBillingService( // Refactor later to ?ChangeClientPlanCommand? (ProviderPlanId, ProviderId, OrganizationId) // 1. Retrieve PlanType and PlanName for ProviderPlan // 2. Assign PlanType & PlanName to Organization - var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerPlan.ProviderId); + var providerOrganizations = + await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerPlan.ProviderId); var newPlan = await pricingClient.GetPlanOrThrow(newPlanType); @@ -213,6 +204,7 @@ public class ProviderBillingService( { throw new ConflictException($"Organization '{providerOrganization.Id}' not found."); } + organization.PlanType = newPlanType; organization.Plan = newPlan.Name; await organizationRepository.ReplaceAsync(organization); @@ -228,15 +220,15 @@ public class ProviderBillingService( if (!string.IsNullOrEmpty(organization.GatewayCustomerId)) { - logger.LogWarning("Client organization ({ID}) already has a populated {FieldName}", organization.Id, nameof(organization.GatewayCustomerId)); + logger.LogWarning("Client organization ({ID}) already has a populated {FieldName}", organization.Id, + nameof(organization.GatewayCustomerId)); return; } - var providerCustomer = await subscriberService.GetCustomerOrThrow(provider, new CustomerGetOptions - { - Expand = ["tax", "tax_ids"] - }); + var providerCustomer = + await subscriberService.GetCustomerOrThrow(provider, + new CustomerGetOptions { Expand = ["tax", "tax_ids"] }); var providerTaxId = providerCustomer.TaxIds.FirstOrDefault(); @@ -269,23 +261,18 @@ public class ProviderBillingService( } ] }, - Metadata = new Dictionary - { - { "region", globalSettings.BaseServiceUri.CloudRegion } - }, - TaxIdData = providerTaxId == null ? null : - [ - new CustomerTaxIdDataOptions - { - Type = providerTaxId.Type, - Value = providerTaxId.Value - } - ] + Metadata = new Dictionary { { "region", globalSettings.BaseServiceUri.CloudRegion } }, + TaxIdData = providerTaxId == null + ? null + : + [ + new CustomerTaxIdDataOptions { Type = providerTaxId.Type, Value = providerTaxId.Value } + ] }; - if (providerCustomer.Address is not { Country: Constants.CountryAbbreviations.UnitedStates }) + if (providerCustomer.Address is not { Country: CountryAbbreviations.UnitedStates }) { - customerCreateOptions.TaxExempt = StripeConstants.TaxExempt.Reverse; + customerCreateOptions.TaxExempt = TaxExempt.Reverse; } var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions); @@ -347,9 +334,9 @@ public class ProviderBillingService( .Where(pair => pair.subscription is { Status: - StripeConstants.SubscriptionStatus.Active or - StripeConstants.SubscriptionStatus.Trialing or - StripeConstants.SubscriptionStatus.PastDue + SubscriptionStatus.Active or + SubscriptionStatus.Trialing or + SubscriptionStatus.PastDue }).ToList(); if (active.Count == 0) @@ -474,36 +461,25 @@ public class ProviderBillingService( // Below the limit to above the limit (currentlyAssignedSeatTotal <= seatMinimum && newlyAssignedSeatTotal > seatMinimum) || // Above the limit to further above the limit - (currentlyAssignedSeatTotal > seatMinimum && newlyAssignedSeatTotal > seatMinimum && newlyAssignedSeatTotal > currentlyAssignedSeatTotal); + (currentlyAssignedSeatTotal > seatMinimum && newlyAssignedSeatTotal > seatMinimum && + newlyAssignedSeatTotal > currentlyAssignedSeatTotal); } public async Task SetupCustomer( Provider provider, - TaxInfo taxInfo, - TokenizedPaymentSource tokenizedPaymentSource) + TokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress) { - ArgumentNullException.ThrowIfNull(tokenizedPaymentSource); - - if (taxInfo is not - { - BillingAddressCountry: not null and not "", - BillingAddressPostalCode: not null and not "" - }) - { - logger.LogError("Cannot create customer for provider ({ProviderID}) without both a country and postal code", provider.Id); - throw new BillingException(); - } - var options = new CustomerCreateOptions { Address = new AddressOptions { - Country = taxInfo.BillingAddressCountry, - PostalCode = taxInfo.BillingAddressPostalCode, - Line1 = taxInfo.BillingAddressLine1, - Line2 = taxInfo.BillingAddressLine2, - City = taxInfo.BillingAddressCity, - State = taxInfo.BillingAddressState + Country = billingAddress.Country, + PostalCode = billingAddress.PostalCode, + Line1 = billingAddress.Line1, + Line2 = billingAddress.Line2, + City = billingAddress.City, + State = billingAddress.State }, Description = provider.DisplayBusinessName(), Email = provider.BillingEmail, @@ -520,93 +496,61 @@ public class ProviderBillingService( } ] }, - Metadata = new Dictionary - { - { "region", globalSettings.BaseServiceUri.CloudRegion } - } + Metadata = new Dictionary { { "region", globalSettings.BaseServiceUri.CloudRegion } }, + TaxExempt = billingAddress.Country != CountryAbbreviations.UnitedStates ? TaxExempt.Reverse : TaxExempt.None }; - if (taxInfo.BillingAddressCountry is not Constants.CountryAbbreviations.UnitedStates) + if (billingAddress.TaxId != null) { - options.TaxExempt = StripeConstants.TaxExempt.Reverse; - } - - if (!string.IsNullOrEmpty(taxInfo.TaxIdNumber)) - { - var taxIdType = taxService.GetStripeTaxCode( - taxInfo.BillingAddressCountry, - taxInfo.TaxIdNumber); - - if (taxIdType == null) - { - logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.", - taxInfo.BillingAddressCountry, - taxInfo.TaxIdNumber); - - throw new BadRequestException("billingTaxIdTypeInferenceError"); - } - options.TaxIdData = [ - new CustomerTaxIdDataOptions { Type = taxIdType, Value = taxInfo.TaxIdNumber } + new CustomerTaxIdDataOptions { Type = billingAddress.TaxId.Code, Value = billingAddress.TaxId.Value } ]; - if (taxIdType == StripeConstants.TaxIdType.SpanishNIF) + if (billingAddress.TaxId.Code == TaxIdType.SpanishNIF) { options.TaxIdData.Add(new CustomerTaxIdDataOptions { - Type = StripeConstants.TaxIdType.EUVAT, - Value = $"ES{taxInfo.TaxIdNumber}" + Type = TaxIdType.EUVAT, + Value = $"ES{billingAddress.TaxId.Value}" }); } } - if (!string.IsNullOrEmpty(provider.DiscountId)) - { - options.Coupon = provider.DiscountId; - } - var braintreeCustomerId = ""; - if (tokenizedPaymentSource is not - { - Type: PaymentMethodType.BankAccount or PaymentMethodType.Card or PaymentMethodType.PayPal, - Token: not null and not "" - }) - { - logger.LogError("Cannot create customer for provider ({ProviderID}) with invalid payment method", provider.Id); - throw new BillingException(); - } - - var (type, token) = tokenizedPaymentSource; - // ReSharper disable once SwitchStatementMissingSomeEnumCasesNoDefault - switch (type) + switch (paymentMethod.Type) { - case PaymentMethodType.BankAccount: + case TokenizablePaymentMethodType.BankAccount: { var setupIntent = - (await stripeAdapter.SetupIntentList(new SetupIntentListOptions { PaymentMethod = token })) + (await stripeAdapter.SetupIntentList(new SetupIntentListOptions + { + PaymentMethod = paymentMethod.Token + })) .FirstOrDefault(); if (setupIntent == null) { - logger.LogError("Cannot create customer for provider ({ProviderID}) without a setup intent for their bank account", provider.Id); + logger.LogError( + "Cannot create customer for provider ({ProviderID}) without a setup intent for their bank account", + provider.Id); throw new BillingException(); } await setupIntentCache.Set(provider.Id, setupIntent.Id); break; } - case PaymentMethodType.Card: + case TokenizablePaymentMethodType.Card: { - options.PaymentMethod = token; - options.InvoiceSettings.DefaultPaymentMethod = token; + options.PaymentMethod = paymentMethod.Token; + options.InvoiceSettings.DefaultPaymentMethod = paymentMethod.Token; break; } - case PaymentMethodType.PayPal: + case TokenizablePaymentMethodType.PayPal: { - braintreeCustomerId = await subscriberService.CreateBraintreeCustomer(provider, token); + braintreeCustomerId = await subscriberService.CreateBraintreeCustomer(provider, paymentMethod.Token); options.Metadata[BraintreeCustomerIdKey] = braintreeCustomerId; break; } @@ -616,8 +560,7 @@ public class ProviderBillingService( { return await stripeAdapter.CustomerCreateAsync(options); } - catch (StripeException stripeException) when (stripeException.StripeError?.Code == - StripeConstants.ErrorCodes.TaxIdInvalid) + catch (StripeException stripeException) when (stripeException.StripeError?.Code == ErrorCodes.TaxIdInvalid) { await Revert(); throw new BadRequestException( @@ -632,9 +575,9 @@ public class ProviderBillingService( async Task Revert() { // ReSharper disable once SwitchStatementMissingSomeEnumCasesNoDefault - switch (tokenizedPaymentSource.Type) + switch (paymentMethod.Type) { - case PaymentMethodType.BankAccount: + case TokenizablePaymentMethodType.BankAccount: { var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(provider.Id); await stripeAdapter.SetupIntentCancel(setupIntentId, @@ -642,7 +585,7 @@ public class ProviderBillingService( await setupIntentCache.RemoveSetupIntentForSubscriber(provider.Id); break; } - case PaymentMethodType.PayPal when !string.IsNullOrEmpty(braintreeCustomerId): + case TokenizablePaymentMethodType.PayPal when !string.IsNullOrEmpty(braintreeCustomerId): { await braintreeGateway.Customer.DeleteAsync(braintreeCustomerId); break; @@ -661,9 +604,10 @@ public class ProviderBillingService( var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); - if (providerPlans == null || providerPlans.Count == 0) + if (providerPlans.Count == 0) { - logger.LogError("Cannot start subscription for provider ({ProviderID}) that has no configured plans", provider.Id); + logger.LogError("Cannot start subscription for provider ({ProviderID}) that has no configured plans", + provider.Id); throw new BillingException(); } @@ -676,7 +620,9 @@ public class ProviderBillingService( if (!providerPlan.IsConfigured()) { - logger.LogError("Cannot start subscription for provider ({ProviderID}) that has no configured {ProviderName} plan", provider.Id, plan.Name); + logger.LogError( + "Cannot start subscription for provider ({ProviderID}) that has no configured {ProviderName} plan", + provider.Id, plan.Name); throw new BillingException(); } @@ -692,16 +638,14 @@ public class ProviderBillingService( var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(provider.Id); var setupIntent = !string.IsNullOrEmpty(setupIntentId) - ? await stripeAdapter.SetupIntentGet(setupIntentId, new SetupIntentGetOptions - { - Expand = ["payment_method"] - }) + ? await stripeAdapter.SetupIntentGet(setupIntentId, + new SetupIntentGetOptions { Expand = ["payment_method"] }) : null; var usePaymentMethod = !string.IsNullOrEmpty(customer.InvoiceSettings?.DefaultPaymentMethodId) || - (customer.Metadata?.ContainsKey(BraintreeCustomerIdKey) == true) || - (setupIntent?.IsUnverifiedBankAccount() == true); + customer.Metadata?.ContainsKey(BraintreeCustomerIdKey) == true || + setupIntent?.IsUnverifiedBankAccount() == true; int? trialPeriodDays = provider.Type switch { @@ -712,30 +656,28 @@ public class ProviderBillingService( var subscriptionCreateOptions = new SubscriptionCreateOptions { - CollectionMethod = usePaymentMethod ? - StripeConstants.CollectionMethod.ChargeAutomatically : StripeConstants.CollectionMethod.SendInvoice, + CollectionMethod = + usePaymentMethod + ? CollectionMethod.ChargeAutomatically + : CollectionMethod.SendInvoice, Customer = customer.Id, DaysUntilDue = usePaymentMethod ? null : 30, + Discounts = !string.IsNullOrEmpty(provider.DiscountId) ? [new SubscriptionDiscountOptions { Coupon = provider.DiscountId }] : null, Items = subscriptionItemOptionsList, - Metadata = new Dictionary - { - { "providerId", provider.Id.ToString() } - }, + Metadata = new Dictionary { { "providerId", provider.Id.ToString() } }, OffSession = true, - ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations, - TrialPeriodDays = trialPeriodDays + ProrationBehavior = ProrationBehavior.CreateProrations, + TrialPeriodDays = trialPeriodDays, + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } }; - - subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; - try { var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); if (subscription is { - Status: StripeConstants.SubscriptionStatus.Active or StripeConstants.SubscriptionStatus.Trialing + Status: SubscriptionStatus.Active or SubscriptionStatus.Trialing }) { return subscription; @@ -749,9 +691,11 @@ public class ProviderBillingService( throw new BillingException(); } - catch (StripeException stripeException) when (stripeException.StripeError?.Code == StripeConstants.ErrorCodes.CustomerTaxLocationInvalid) + catch (StripeException stripeException) when (stripeException.StripeError?.Code == + ErrorCodes.CustomerTaxLocationInvalid) { - throw new BadRequestException("Your location wasn't recognized. Please ensure your country and postal code are valid."); + throw new BadRequestException( + "Your location wasn't recognized. Please ensure your country and postal code are valid."); } } @@ -765,7 +709,7 @@ public class ProviderBillingService( subscriberService.UpdateTaxInformation(provider, taxInformation)); await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, - new SubscriptionUpdateOptions { CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically }); + new SubscriptionUpdateOptions { CollectionMethod = CollectionMethod.ChargeAutomatically }); } public async Task UpdateSeatMinimums(UpdateProviderSeatMinimumsCommand command) @@ -865,13 +809,9 @@ public class ProviderBillingService( await stripeAdapter.SubscriptionUpdateAsync(provider.GatewaySubscriptionId, new SubscriptionUpdateOptions { - Items = [ - new SubscriptionItemOptions - { - Id = item.Id, - Price = priceId, - Quantity = newlySubscribedSeats - } + Items = + [ + new SubscriptionItemOptions { Id = item.Id, Price = priceId, Quantity = newlySubscribedSeats } ] }); @@ -894,7 +834,8 @@ public class ProviderBillingService( var plan = await pricingClient.GetPlanOrThrow(planType); return providerOrganizations - .Where(providerOrganization => providerOrganization.Plan == plan.Name && providerOrganization.Status == OrganizationStatusType.Managed) + .Where(providerOrganization => providerOrganization.Plan == plan.Name && + providerOrganization.Status == OrganizationStatusType.Managed) .Sum(providerOrganization => providerOrganization.Seats ?? 0); } diff --git a/bitwarden_license/src/Commercial.Core/Commercial.Core.csproj b/bitwarden_license/src/Commercial.Core/Commercial.Core.csproj index 57babb4043..9209917d1e 100644 --- a/bitwarden_license/src/Commercial.Core/Commercial.Core.csproj +++ b/bitwarden_license/src/Commercial.Core/Commercial.Core.csproj @@ -5,7 +5,7 @@ - + diff --git a/bitwarden_license/src/Commercial.Core/SecretsManager/Commands/ServiceAccounts/CreateServiceAccountCommand.cs b/bitwarden_license/src/Commercial.Core/SecretsManager/Commands/ServiceAccounts/CreateServiceAccountCommand.cs index 12c7f679bd..b73b358925 100644 --- a/bitwarden_license/src/Commercial.Core/SecretsManager/Commands/ServiceAccounts/CreateServiceAccountCommand.cs +++ b/bitwarden_license/src/Commercial.Core/SecretsManager/Commands/ServiceAccounts/CreateServiceAccountCommand.cs @@ -1,10 +1,13 @@ // FIXME: Update this file to be null safe and then delete the line below #nullable disable +using Bit.Core.Context; +using Bit.Core.Enums; using Bit.Core.Repositories; using Bit.Core.SecretsManager.Commands.ServiceAccounts.Interfaces; using Bit.Core.SecretsManager.Entities; using Bit.Core.SecretsManager.Repositories; +using Bit.Core.Services; namespace Bit.Commercial.Core.SecretsManager.Commands.ServiceAccounts; @@ -13,15 +16,21 @@ public class CreateServiceAccountCommand : ICreateServiceAccountCommand private readonly IAccessPolicyRepository _accessPolicyRepository; private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IServiceAccountRepository _serviceAccountRepository; + private readonly IEventService _eventService; + private readonly ICurrentContext _currentContext; public CreateServiceAccountCommand( IAccessPolicyRepository accessPolicyRepository, IOrganizationUserRepository organizationUserRepository, - IServiceAccountRepository serviceAccountRepository) + IServiceAccountRepository serviceAccountRepository, + IEventService eventService, + ICurrentContext currentContext) { _accessPolicyRepository = accessPolicyRepository; _organizationUserRepository = organizationUserRepository; _serviceAccountRepository = serviceAccountRepository; + _eventService = eventService; + _currentContext = currentContext; } public async Task CreateAsync(ServiceAccount serviceAccount, Guid userId) @@ -38,6 +47,7 @@ public class CreateServiceAccountCommand : ICreateServiceAccountCommand Write = true, }; await _accessPolicyRepository.CreateManyAsync(new List { accessPolicy }); + await _eventService.LogServiceAccountPeopleEventAsync(user.Id, accessPolicy, EventType.ServiceAccount_UserAdded, _currentContext.IdentityClientType); return createdServiceAccount; } } diff --git a/bitwarden_license/src/Sso/Controllers/AccountController.cs b/bitwarden_license/src/Sso/Controllers/AccountController.cs index 98a581e8ca..a0842daa34 100644 --- a/bitwarden_license/src/Sso/Controllers/AccountController.cs +++ b/bitwarden_license/src/Sso/Controllers/AccountController.cs @@ -1,7 +1,4 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Security.Claims; +using System.Security.Claims; using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; @@ -57,6 +54,7 @@ public class AccountController : Controller private readonly IDataProtectorTokenFactory _dataProtector; private readonly IOrganizationDomainRepository _organizationDomainRepository; private readonly IRegisterUserCommand _registerUserCommand; + private readonly IFeatureService _featureService; public AccountController( IAuthenticationSchemeProvider schemeProvider, @@ -77,7 +75,8 @@ public class AccountController : Controller Core.Services.IEventService eventService, IDataProtectorTokenFactory dataProtector, IOrganizationDomainRepository organizationDomainRepository, - IRegisterUserCommand registerUserCommand) + IRegisterUserCommand registerUserCommand, + IFeatureService featureService) { _schemeProvider = schemeProvider; _clientStore = clientStore; @@ -98,10 +97,11 @@ public class AccountController : Controller _dataProtector = dataProtector; _organizationDomainRepository = organizationDomainRepository; _registerUserCommand = registerUserCommand; + _featureService = featureService; } [HttpGet] - public async Task PreValidate(string domainHint) + public async Task PreValidateAsync(string domainHint) { try { @@ -160,10 +160,12 @@ public class AccountController : Controller } [HttpGet] - public async Task Login(string returnUrl) + public async Task LoginAsync(string returnUrl) { var context = await _interaction.GetAuthorizationContextAsync(returnUrl); + // FIXME: Update this file to be null safe and then delete the line below +#nullable disable if (!context.Parameters.AllKeys.Contains("domain_hint") || string.IsNullOrWhiteSpace(context.Parameters["domain_hint"])) { @@ -179,6 +181,7 @@ public class AccountController : Controller var domainHint = context.Parameters["domain_hint"]; var organization = await _organizationRepository.GetByIdentifierAsync(domainHint); +#nullable restore if (organization == null) { @@ -235,37 +238,73 @@ public class AccountController : Controller [HttpGet] public async Task ExternalCallback() { + // Feature flag (PM-24579): Prevent SSO on existing non-compliant users. + var preventOrgUserLoginIfStatusInvalid = + _featureService.IsEnabled(FeatureFlagKeys.PM24579_PreventSsoOnExistingNonCompliantUsers); + // Read external identity from the temporary cookie var result = await HttpContext.AuthenticateAsync( AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); - if (result?.Succeeded != true) - { - throw new Exception(_i18nService.T("ExternalAuthenticationError")); - } - // Debugging - var externalClaims = result.Principal.Claims.Select(c => $"{c.Type}: {c.Value}"); - _logger.LogDebug("External claims: {@claims}", externalClaims); + if (preventOrgUserLoginIfStatusInvalid) + { + if (!result.Succeeded) + { + throw new Exception(_i18nService.T("ExternalAuthenticationError")); + } + } + else + { + if (result?.Succeeded != true) + { + throw new Exception(_i18nService.T("ExternalAuthenticationError")); + } + } // See if the user has logged in with this SSO provider before and has already been provisioned. // This is signified by the user existing in the User table and the SSOUser table for the SSO provider they're using. - var (user, provider, providerUserId, claims, ssoConfigData) = await FindUserFromExternalProviderAsync(result); + var (possibleSsoLinkedUser, provider, providerUserId, claims, ssoConfigData) = await FindUserFromExternalProviderAsync(result); + + // We will look these up as required (lazy resolution) to avoid multiple DB hits. + Organization? organization = null; + OrganizationUser? orgUser = null; // The user has not authenticated with this SSO provider before. // They could have an existing Bitwarden account in the User table though. - if (user == null) + if (possibleSsoLinkedUser == null) { + // FIXME: Update this file to be null safe and then delete the line below +#nullable disable // If we're manually linking to SSO, the user's external identifier will be passed as query string parameter. - var userIdentifier = result.Properties.Items.Keys.Contains("user_identifier") ? - result.Properties.Items["user_identifier"] : null; - user = await AutoProvisionUserAsync(provider, providerUserId, claims, userIdentifier, ssoConfigData); + var userIdentifier = result.Properties.Items.Keys.Contains("user_identifier") + ? result.Properties.Items["user_identifier"] + : null; + + var (resolvedUser, foundOrganization, foundOrCreatedOrgUser) = + await CreateUserAndOrgUserConditionallyAsync( + provider, + providerUserId, + claims, + userIdentifier, + ssoConfigData); +#nullable restore + + possibleSsoLinkedUser = resolvedUser; + + if (preventOrgUserLoginIfStatusInvalid) + { + organization = foundOrganization; + orgUser = foundOrCreatedOrgUser; + } } - // Either the user already authenticated with the SSO provider, or we've just provisioned them. - // Either way, we have associated the SSO login with a Bitwarden user. - // We will now sign the Bitwarden user in. - if (user != null) + if (preventOrgUserLoginIfStatusInvalid) { + User resolvedSsoLinkedUser = possibleSsoLinkedUser + ?? throw new Exception(_i18nService.T("UserShouldBeFound")); + + await PreventOrgUserLoginIfStatusInvalidAsync(organization, provider, orgUser, resolvedSsoLinkedUser); + // This allows us to collect any additional claims or properties // for the specific protocols used and store them in the local auth cookie. // this is typically used to store data needed for signout from those protocols. @@ -278,19 +317,52 @@ public class AccountController : Controller ProcessLoginCallback(result, additionalLocalClaims, localSignInProps); // Issue authentication cookie for user - await HttpContext.SignInAsync(new IdentityServerUser(user.Id.ToString()) + await HttpContext.SignInAsync( + new IdentityServerUser(resolvedSsoLinkedUser.Id.ToString()) + { + DisplayName = resolvedSsoLinkedUser.Email, + IdentityProvider = provider, + AdditionalClaims = additionalLocalClaims.ToArray() + }, localSignInProps); + } + else + { + // PM-24579: remove this else block with feature flag removal. + // Either the user already authenticated with the SSO provider, or we've just provisioned them. + // Either way, we have associated the SSO login with a Bitwarden user. + // We will now sign the Bitwarden user in. + if (possibleSsoLinkedUser != null) { - DisplayName = user.Email, - IdentityProvider = provider, - AdditionalClaims = additionalLocalClaims.ToArray() - }, localSignInProps); + // This allows us to collect any additional claims or properties + // for the specific protocols used and store them in the local auth cookie. + // this is typically used to store data needed for signout from those protocols. + var additionalLocalClaims = new List(); + var localSignInProps = new AuthenticationProperties + { + IsPersistent = true, + ExpiresUtc = DateTimeOffset.UtcNow.AddMinutes(1) + }; + ProcessLoginCallback(result, additionalLocalClaims, localSignInProps); + + // Issue authentication cookie for user + await HttpContext.SignInAsync( + new IdentityServerUser(possibleSsoLinkedUser.Id.ToString()) + { + DisplayName = possibleSsoLinkedUser.Email, + IdentityProvider = provider, + AdditionalClaims = additionalLocalClaims.ToArray() + }, localSignInProps); + } } // Delete temporary cookie used during external authentication await HttpContext.SignOutAsync(AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); + // FIXME: Update this file to be null safe and then delete the line below +#nullable disable // Retrieve return URL var returnUrl = result.Properties.Items["return_url"] ?? "~/"; +#nullable restore // Check if external login is in the context of an OIDC request var context = await _interaction.GetAuthorizationContextAsync(returnUrl); @@ -309,8 +381,10 @@ public class AccountController : Controller return Redirect(returnUrl); } + // FIXME: Update this file to be null safe and then delete the line below +#nullable disable [HttpGet] - public async Task Logout(string logoutId) + public async Task LogoutAsync(string logoutId) { // Build a model so the logged out page knows what to display var (updatedLogoutId, redirectUri, externalAuthenticationScheme) = await GetLoggedOutDataAsync(logoutId); @@ -333,6 +407,7 @@ public class AccountController : Controller // This triggers a redirect to the external provider for sign-out return SignOut(new AuthenticationProperties { RedirectUri = url }, externalAuthenticationScheme); } + if (redirectUri != null) { return View("Redirect", new RedirectViewModel { RedirectUrl = redirectUri }); @@ -342,14 +417,22 @@ public class AccountController : Controller return Redirect("~/"); } } +#nullable restore /// /// Attempts to map the external identity to a Bitwarden user, through the SsoUser table, which holds the `externalId`. /// The claims on the external identity are used to determine an `externalId`, and that is used to find the appropriate `SsoUser` and `User` records. /// - private async Task<(User user, string provider, string providerUserId, IEnumerable claims, SsoConfigurationData config)> - FindUserFromExternalProviderAsync(AuthenticateResult result) + private async Task<( + User? possibleSsoUser, + string provider, + string providerUserId, + IEnumerable claims, + SsoConfigurationData config + )> FindUserFromExternalProviderAsync(AuthenticateResult result) { + // FIXME: Update this file to be null safe and then delete the line below +#nullable disable var provider = result.Properties.Items["scheme"]; var orgId = new Guid(provider); var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(orgId); @@ -374,9 +457,10 @@ public class AccountController : Controller // Ensure the NameIdentifier used is not a transient name ID, if so, we need a different attribute // for the user identifier. static bool nameIdIsNotTransient(Claim c) => c.Type == ClaimTypes.NameIdentifier - && (c.Properties == null - || !c.Properties.TryGetValue(SamlPropertyKeys.ClaimFormat, out var claimFormat) - || claimFormat != SamlNameIdFormats.Transient); + && (c.Properties == null + || !c.Properties.TryGetValue(SamlPropertyKeys.ClaimFormat, + out var claimFormat) + || claimFormat != SamlNameIdFormats.Transient); // Try to determine the unique id of the external user (issued by the provider) // the most common claim type for that are the sub claim and the NameIdentifier @@ -391,6 +475,7 @@ public class AccountController : Controller externalUser.FindFirst("upn") ?? externalUser.FindFirst("eppn") ?? throw new Exception(_i18nService.T("UnknownUserId")); +#nullable restore // Remove the user id claim so we don't include it as an extra claim if/when we provision the user var claims = externalUser.Claims.ToList(); @@ -399,13 +484,15 @@ public class AccountController : Controller // find external user var providerUserId = userIdClaim.Value; - var user = await _userRepository.GetBySsoUserAsync(providerUserId, orgId); + var possibleSsoUser = await _userRepository.GetBySsoUserAsync(providerUserId, orgId); - return (user, provider, providerUserId, claims, ssoConfigData); + return (possibleSsoUser, provider, providerUserId, claims, ssoConfigData); } /// - /// Provision an SSO-linked Bitwarden user. + /// This function seeks to set up the org user record or create a new user record based on the conditions + /// below. + /// /// This handles three different scenarios: /// 1. Creating an SsoUser link for an existing User and OrganizationUser /// - User is a member of the organization, but hasn't authenticated with the org's SSO provider before. @@ -418,77 +505,100 @@ public class AccountController : Controller /// The external identity provider's user identifier. /// The claims from the external IdP. /// The user identifier used for manual SSO linking. - /// The SSO configuration for the organization. - /// The User to sign in. + /// The SSO configuration for the organization. + /// Guaranteed to return the user to sign in as well as the found organization and org user. /// An exception if the user cannot be provisioned as requested. - private async Task AutoProvisionUserAsync(string provider, string providerUserId, - IEnumerable claims, string userIdentifier, SsoConfigurationData config) + private async Task<(User resolvedUser, Organization foundOrganization, OrganizationUser foundOrgUser)> CreateUserAndOrgUserConditionallyAsync( + string provider, + string providerUserId, + IEnumerable claims, + string userIdentifier, + SsoConfigurationData ssoConfigData + ) { - var name = GetName(claims, config.GetAdditionalNameClaimTypes()); - var email = GetEmailAddress(claims, config.GetAdditionalEmailClaimTypes()); - if (string.IsNullOrWhiteSpace(email) && providerUserId.Contains("@")) - { - email = providerUserId; - } + // Try to get the email from the claims as we don't know if we have a user record yet. + var name = GetName(claims, ssoConfigData.GetAdditionalNameClaimTypes()); + var email = TryGetEmailAddress(claims, ssoConfigData, providerUserId); - if (!Guid.TryParse(provider, out var orgId)) - { - // TODO: support non-org (server-wide) SSO in the future? - throw new Exception(_i18nService.T("SSOProviderIsNotAnOrgId", provider)); - } - - User existingUser = null; + User? possibleExistingUser; if (string.IsNullOrWhiteSpace(userIdentifier)) { if (string.IsNullOrWhiteSpace(email)) { throw new Exception(_i18nService.T("CannotFindEmailClaim")); } - existingUser = await _userRepository.GetByEmailAsync(email); + + possibleExistingUser = await _userRepository.GetByEmailAsync(email); } else { - existingUser = await GetUserFromManualLinkingData(userIdentifier); + possibleExistingUser = await GetUserFromManualLinkingDataAsync(userIdentifier); } - // Try to find the OrganizationUser if it exists. - var (organization, orgUser) = await FindOrganizationUser(existingUser, email, orgId); + // Find the org (we error if we can't find an org because no org is not valid) + var organization = await GetOrganizationByProviderAsync(provider); + + // Try to find an org user (null org user possible and valid here) + var possibleOrgUser = await GetOrganizationUserByUserAndOrgIdOrEmailAsync(possibleExistingUser, organization.Id, email); //---------------------------------------------------- // Scenario 1: We've found the user in the User table //---------------------------------------------------- - if (existingUser != null) + if (possibleExistingUser != null) { - if (existingUser.UsesKeyConnector && - (orgUser == null || orgUser.Status == OrganizationUserStatusType.Invited)) + User guaranteedExistingUser = possibleExistingUser; + + if (guaranteedExistingUser.UsesKeyConnector && + (possibleOrgUser == null || possibleOrgUser.Status == OrganizationUserStatusType.Invited)) { throw new Exception(_i18nService.T("UserAlreadyExistsKeyConnector")); } - // If the user already exists in Bitwarden, we require that the user already be in the org, - // and that they are either Accepted or Confirmed. - if (orgUser == null) + OrganizationUser guaranteedOrgUser = possibleOrgUser ?? throw new Exception(_i18nService.T("UserAlreadyExistsInviteProcess")); + + /* + * ---------------------------------------------------- + * Critical Code Check Here + * + * We want to ensure a user is not in the invited state + * explicitly. User's in the invited state should not + * be able to authenticate via SSO. + * + * See internal doc called "Added Context for SSO Login + * Flows" for further details. + * ---------------------------------------------------- + */ + if (guaranteedOrgUser.Status == OrganizationUserStatusType.Invited) { - // Org User is not created - no invite has been sent - throw new Exception(_i18nService.T("UserAlreadyExistsInviteProcess")); + // Org User is invited – must accept via email first + throw new Exception( + _i18nService.T("AcceptInviteBeforeUsingSSO", organization.DisplayName())); } - EnsureOrgUserStatusAllowed(orgUser.Status, organization.DisplayName(), - allowedStatuses: [OrganizationUserStatusType.Accepted, OrganizationUserStatusType.Confirmed]); - + // If the user already exists in Bitwarden, we require that the user already be in the org, + // and that they are either Accepted or Confirmed. + EnforceAllowedOrgUserStatus( + guaranteedOrgUser.Status, + allowedStatuses: [ + OrganizationUserStatusType.Accepted, + OrganizationUserStatusType.Confirmed + ], + organization.DisplayName()); // Since we're in the auto-provisioning logic, this means that the user exists, but they have not // authenticated with the org's SSO provider before now (otherwise we wouldn't be auto-provisioning them). // We've verified that the user is Accepted or Confnirmed, so we can create an SsoUser link and proceed // with authentication. - await CreateSsoUserRecord(providerUserId, existingUser.Id, orgId, orgUser); - return existingUser; + await CreateSsoUserRecordAsync(providerUserId, guaranteedExistingUser.Id, organization.Id, guaranteedOrgUser); + + return (guaranteedExistingUser, organization, guaranteedOrgUser); } // Before any user creation - if Org User doesn't exist at this point - make sure there are enough seats to add one - if (orgUser == null && organization.Seats.HasValue) + if (possibleOrgUser == null && organization.Seats.HasValue) { - var occupiedSeats = await _organizationRepository.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id); + var occupiedSeats = + await _organizationRepository.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id); var initialSeatCount = organization.Seats.Value; var availableSeats = initialSeatCount - occupiedSeats.Total; if (availableSeats < 1) @@ -506,8 +616,10 @@ public class AccountController : Controller { if (organization.Seats.Value != initialSeatCount) { - await _organizationService.AdjustSeatsAsync(orgId, initialSeatCount - organization.Seats.Value); + await _organizationService.AdjustSeatsAsync(organization.Id, + initialSeatCount - organization.Seats.Value); } + _logger.LogInformation(e, "SSO auto provisioning failed"); throw new Exception(_i18nService.T("NoSeatsAvailable", organization.DisplayName())); } @@ -515,40 +627,46 @@ public class AccountController : Controller } // If the email domain is verified, we can mark the email as verified + if (string.IsNullOrWhiteSpace(email)) + { + throw new Exception(_i18nService.T("CannotFindEmailClaim")); + } + var emailVerified = false; var emailDomain = CoreHelpers.GetEmailDomain(email); if (!string.IsNullOrWhiteSpace(emailDomain)) { - var organizationDomain = await _organizationDomainRepository.GetDomainByOrgIdAndDomainNameAsync(orgId, emailDomain); + var organizationDomain = + await _organizationDomainRepository.GetDomainByOrgIdAndDomainNameAsync(organization.Id, emailDomain); emailVerified = organizationDomain?.VerifiedDate.HasValue ?? false; } //-------------------------------------------------- // Scenarios 2 and 3: We need to register a new user //-------------------------------------------------- - var user = new User + var newUser = new User { Name = name, Email = email, EmailVerified = emailVerified, ApiKey = CoreHelpers.SecureRandomString(30) }; - await _registerUserCommand.RegisterUser(user); + await _registerUserCommand.RegisterUser(newUser); // If the organization has 2fa policy enabled, make sure to default jit user 2fa to email var twoFactorPolicy = - await _policyRepository.GetByOrganizationIdTypeAsync(orgId, PolicyType.TwoFactorAuthentication); + await _policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.TwoFactorAuthentication); if (twoFactorPolicy != null && twoFactorPolicy.Enabled) { - user.SetTwoFactorProviders(new Dictionary + newUser.SetTwoFactorProviders(new Dictionary { [TwoFactorProviderType.Email] = new TwoFactorProvider { - MetaData = new Dictionary { ["Email"] = user.Email.ToLowerInvariant() }, + MetaData = new Dictionary { ["Email"] = newUser.Email.ToLowerInvariant() }, Enabled = true } }); - await _userService.UpdateTwoFactorProviderAsync(user, TwoFactorProviderType.Email); + await _userService.UpdateTwoFactorProviderAsync(newUser, TwoFactorProviderType.Email); } //----------------------------------------------------------------- @@ -556,17 +674,18 @@ public class AccountController : Controller // This means that an invitation was not sent for this user and we // need to establish their invited status now. //----------------------------------------------------------------- - if (orgUser == null) + if (possibleOrgUser == null) { - orgUser = new OrganizationUser + possibleOrgUser = new OrganizationUser { - OrganizationId = orgId, - UserId = user.Id, + OrganizationId = organization.Id, + UserId = newUser.Id, Type = OrganizationUserType.User, Status = OrganizationUserStatusType.Invited }; - await _organizationUserRepository.CreateAsync(orgUser); + await _organizationUserRepository.CreateAsync(possibleOrgUser); } + //----------------------------------------------------------------- // Scenario 3: There is already an existing OrganizationUser // That was established through an invitation. We just need to @@ -574,24 +693,68 @@ public class AccountController : Controller //----------------------------------------------------------------- else { - orgUser.UserId = user.Id; - await _organizationUserRepository.ReplaceAsync(orgUser); + possibleOrgUser.UserId = newUser.Id; + await _organizationUserRepository.ReplaceAsync(possibleOrgUser); } // Create the SsoUser record to link the user to the SSO provider. - await CreateSsoUserRecord(providerUserId, user.Id, orgId, orgUser); + await CreateSsoUserRecordAsync(providerUserId, newUser.Id, organization.Id, possibleOrgUser); - return user; + return (newUser, organization, possibleOrgUser); } - private async Task GetUserFromManualLinkingData(string userIdentifier) + /// + /// Validates an organization user is allowed to log in via SSO and blocks invalid statuses. + /// Lazily resolves the organization and organization user if not provided. + /// + /// The target organization; if null, resolved from provider. + /// The SSO scheme provider value (organization id as a GUID string). + /// The organization-user record; if null, looked up by user/org or user email for invited users. + /// The user attempting to sign in (existing or newly provisioned). + /// Thrown if the organization cannot be resolved from provider; + /// the organization user cannot be found; or the organization user status is not allowed. + private async Task PreventOrgUserLoginIfStatusInvalidAsync( + Organization? organization, + string provider, + OrganizationUser? orgUser, + User user) { - User user = null; + // Lazily get organization if not already known + organization ??= await GetOrganizationByProviderAsync(provider); + + // Lazily get the org user if not already known + orgUser ??= await GetOrganizationUserByUserAndOrgIdOrEmailAsync( + user, + organization.Id, + user.Email); + + if (orgUser != null) + { + // Invited is allowed at this point because we know the user is trying to accept an org invite. + EnforceAllowedOrgUserStatus( + orgUser.Status, + allowedStatuses: [ + OrganizationUserStatusType.Invited, + OrganizationUserStatusType.Accepted, + OrganizationUserStatusType.Confirmed, + ], + organization.DisplayName()); + } + else + { + throw new Exception(_i18nService.T("CouldNotFindOrganizationUser", user.Id, organization.Id)); + } + } + + private async Task GetUserFromManualLinkingDataAsync(string userIdentifier) + { + User? user = null; var split = userIdentifier.Split(","); if (split.Length < 2) { throw new Exception(_i18nService.T("InvalidUserIdentifier")); } + var userId = split[0]; var token = split[1]; @@ -611,64 +774,94 @@ public class AccountController : Controller throw new Exception(_i18nService.T("UserIdAndTokenMismatch")); } } + return user; } - private async Task<(Organization, OrganizationUser)> FindOrganizationUser(User existingUser, string email, Guid orgId) + /// + /// Tries to get the organization by the provider which is org id for us as we use the scheme + /// to identify organizations - not identity providers. + /// + /// Org id string from SSO scheme property + /// Errors if the provider string is not a valid org id guid or if the org cannot be found by the id. + private async Task GetOrganizationByProviderAsync(string provider) { - OrganizationUser orgUser = null; - var organization = await _organizationRepository.GetByIdAsync(orgId); + if (!Guid.TryParse(provider, out var organizationId)) + { + // TODO: support non-org (server-wide) SSO in the future? + throw new Exception(_i18nService.T("SSOProviderIsNotAnOrgId", provider)); + } + + var organization = await _organizationRepository.GetByIdAsync(organizationId); + if (organization == null) { - throw new Exception(_i18nService.T("CouldNotFindOrganization", orgId)); + throw new Exception(_i18nService.T("CouldNotFindOrganization", organizationId)); } + return organization; + } + + /// + /// Attempts to get an for a given organization + /// by first checking for an existing user relationship, and if none is found, + /// by looking up an invited user via their email address. + /// + /// The existing user entity to be looked up in OrganizationUsers table. + /// Organization id from the provider data. + /// Email to use as a fallback in case of an invited user not in the Org Users + /// table yet. + private async Task GetOrganizationUserByUserAndOrgIdOrEmailAsync( + User? user, + Guid organizationId, + string? email) + { + OrganizationUser? orgUser = null; + // Try to find OrgUser via existing User Id. // This covers any OrganizationUser state after they have accepted an invite. - if (existingUser != null) + if (user != null) { - var orgUsersByUserId = await _organizationUserRepository.GetManyByUserAsync(existingUser.Id); - orgUser = orgUsersByUserId.SingleOrDefault(u => u.OrganizationId == orgId); + var orgUsersByUserId = await _organizationUserRepository.GetManyByUserAsync(user.Id); + orgUser = orgUsersByUserId.SingleOrDefault(u => u.OrganizationId == organizationId); } // If no Org User found by Existing User Id - search all the organization's users via email. // This covers users who are Invited but haven't accepted their invite yet. - orgUser ??= await _organizationUserRepository.GetByOrganizationEmailAsync(orgId, email); + if (email != null) + { + orgUser ??= await _organizationUserRepository.GetByOrganizationEmailAsync(organizationId, email); + } - return (organization, orgUser); + return orgUser; } - private void EnsureOrgUserStatusAllowed( - OrganizationUserStatusType status, - string organizationDisplayName, - params OrganizationUserStatusType[] allowedStatuses) + private void EnforceAllowedOrgUserStatus( + OrganizationUserStatusType statusToCheckAgainst, + OrganizationUserStatusType[] allowedStatuses, + string organizationDisplayNameForLogging) { // if this status is one of the allowed ones, just return - if (allowedStatuses.Contains(status)) + if (allowedStatuses.Contains(statusToCheckAgainst)) { return; } // otherwise throw the appropriate exception - switch (status) + switch (statusToCheckAgainst) { - case OrganizationUserStatusType.Invited: - // Org User is invited – must accept via email first - throw new Exception( - _i18nService.T("AcceptInviteBeforeUsingSSO", organizationDisplayName)); case OrganizationUserStatusType.Revoked: // Revoked users may not be (auto)‑provisioned throw new Exception( - _i18nService.T("OrganizationUserAccessRevoked", organizationDisplayName)); + _i18nService.T("OrganizationUserAccessRevoked", organizationDisplayNameForLogging)); default: // anything else is “unknown” throw new Exception( - _i18nService.T("OrganizationUserUnknownStatus", organizationDisplayName)); + _i18nService.T("OrganizationUserUnknownStatus", organizationDisplayNameForLogging)); } } - - private IActionResult InvalidJson(string errorMessageKey, Exception ex = null) + private IActionResult InvalidJson(string errorMessageKey, Exception? ex = null) { Response.StatusCode = ex == null ? 400 : 500; return Json(new ErrorResponseModel(_i18nService.T(errorMessageKey)) @@ -679,13 +872,13 @@ public class AccountController : Controller }); } - private string GetEmailAddress(IEnumerable claims, IEnumerable additionalClaimTypes) + private string? TryGetEmailAddressFromClaims(IEnumerable claims, IEnumerable additionalClaimTypes) { var filteredClaims = claims.Where(c => !string.IsNullOrWhiteSpace(c.Value) && c.Value.Contains("@")); var email = filteredClaims.GetFirstMatch(additionalClaimTypes.ToArray()) ?? - filteredClaims.GetFirstMatch(JwtClaimTypes.Email, ClaimTypes.Email, - SamlClaimTypes.Email, "mail", "emailaddress"); + filteredClaims.GetFirstMatch(JwtClaimTypes.Email, ClaimTypes.Email, + SamlClaimTypes.Email, "mail", "emailaddress"); if (!string.IsNullOrWhiteSpace(email)) { return email; @@ -701,13 +894,15 @@ public class AccountController : Controller return null; } + // FIXME: Update this file to be null safe and then delete the line below +#nullable disable private string GetName(IEnumerable claims, IEnumerable additionalClaimTypes) { var filteredClaims = claims.Where(c => !string.IsNullOrWhiteSpace(c.Value)); var name = filteredClaims.GetFirstMatch(additionalClaimTypes.ToArray()) ?? - filteredClaims.GetFirstMatch(JwtClaimTypes.Name, ClaimTypes.Name, - SamlClaimTypes.DisplayName, SamlClaimTypes.CommonName, "displayname", "cn"); + filteredClaims.GetFirstMatch(JwtClaimTypes.Name, ClaimTypes.Name, + SamlClaimTypes.DisplayName, SamlClaimTypes.CommonName, "displayname", "cn"); if (!string.IsNullOrWhiteSpace(name)) { return name; @@ -724,8 +919,10 @@ public class AccountController : Controller return null; } +#nullable restore - private async Task CreateSsoUserRecord(string providerUserId, Guid userId, Guid orgId, OrganizationUser orgUser) + private async Task CreateSsoUserRecordAsync(string providerUserId, Guid userId, Guid orgId, + OrganizationUser orgUser) { // Delete existing SsoUser (if any) - avoids error if providerId has changed and the sso link is stale var existingSsoUser = await _ssoUserRepository.GetByUserIdOrganizationIdAsync(orgId, userId); @@ -740,15 +937,12 @@ public class AccountController : Controller await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_FirstSsoLogin); } - var ssoUser = new SsoUser - { - ExternalId = providerUserId, - UserId = userId, - OrganizationId = orgId, - }; + var ssoUser = new SsoUser { ExternalId = providerUserId, UserId = userId, OrganizationId = orgId, }; await _ssoUserRepository.CreateAsync(ssoUser); } + // FIXME: Update this file to be null safe and then delete the line below +#nullable disable private void ProcessLoginCallback(AuthenticateResult externalResult, List localClaims, AuthenticationProperties localSignInProps) { @@ -769,18 +963,6 @@ public class AccountController : Controller } } - private async Task GetProviderAsync(string returnUrl) - { - var context = await _interaction.GetAuthorizationContextAsync(returnUrl); - if (context?.IdP != null && await _schemeProvider.GetSchemeAsync(context.IdP) != null) - { - return context.IdP; - } - var schemes = await _schemeProvider.GetAllSchemesAsync(); - var providers = schemes.Select(x => x.Name).ToList(); - return providers.FirstOrDefault(); - } - private async Task<(string, string, string)> GetLoggedOutDataAsync(string logoutId) { // Get context information (client name, post logout redirect URI and iframe for federated signout) @@ -811,10 +993,31 @@ public class AccountController : Controller return (logoutId, logout?.PostLogoutRedirectUri, externalAuthenticationScheme); } +#nullable restore + + /** + * Tries to get a user's email from the claims and SSO configuration data or the provider user id if + * the claims email extraction returns null. + */ + private string? TryGetEmailAddress( + IEnumerable claims, + SsoConfigurationData config, + string providerUserId) + { + var email = TryGetEmailAddressFromClaims(claims, config.GetAdditionalEmailClaimTypes()); + + // If email isn't populated from claims and providerUserId has @, assume it is the email. + if (string.IsNullOrWhiteSpace(email) && providerUserId.Contains("@")) + { + email = providerUserId; + } + + return email; + } public bool IsNativeClient(DIM.AuthorizationRequest context) { return !context.RedirectUri.StartsWith("https", StringComparison.Ordinal) - && !context.RedirectUri.StartsWith("http", StringComparison.Ordinal); + && !context.RedirectUri.StartsWith("http", StringComparison.Ordinal); } } diff --git a/bitwarden_license/src/Sso/Sso.csproj b/bitwarden_license/src/Sso/Sso.csproj index 1b6b666ab1..2a1c14ae5a 100644 --- a/bitwarden_license/src/Sso/Sso.csproj +++ b/bitwarden_license/src/Sso/Sso.csproj @@ -10,7 +10,7 @@ - + diff --git a/bitwarden_license/src/Sso/Startup.cs b/bitwarden_license/src/Sso/Startup.cs index 3aeb9c6beb..3ae8883ac4 100644 --- a/bitwarden_license/src/Sso/Startup.cs +++ b/bitwarden_license/src/Sso/Startup.cs @@ -157,6 +157,6 @@ public class Startup app.UseEndpoints(endpoints => endpoints.MapDefaultControllerRoute()); // Log startup - logger.LogInformation(Constants.BypassFiltersEventId, globalSettings.ProjectName + " started."); + logger.LogInformation(Constants.BypassFiltersEventId, "{Project} started.", globalSettings.ProjectName); } } diff --git a/bitwarden_license/src/Sso/package-lock.json b/bitwarden_license/src/Sso/package-lock.json index aeefbd69d7..f5e0468f87 100644 --- a/bitwarden_license/src/Sso/package-lock.json +++ b/bitwarden_license/src/Sso/package-lock.json @@ -17,9 +17,9 @@ "css-loader": "7.1.2", "expose-loader": "5.0.1", "mini-css-extract-plugin": "2.9.2", - "sass": "1.91.0", + "sass": "1.93.2", "sass-loader": "16.0.5", - "webpack": "5.101.3", + "webpack": "5.102.1", "webpack-cli": "5.1.4" } }, @@ -678,6 +678,7 @@ "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", "dev": true, "license": "MIT", + "peer": true, "bin": { "acorn": "bin/acorn" }, @@ -704,6 +705,7 @@ "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "fast-deep-equal": "^3.1.3", "fast-uri": "^3.0.1", @@ -746,6 +748,16 @@ "ajv": "^8.8.2" } }, + "node_modules/baseline-browser-mapping": { + "version": "2.8.18", + "resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.8.18.tgz", + "integrity": "sha512-UYmTpOBwgPScZpS4A+YbapwWuBwasxvO/2IOHArSsAhL/+ZdmATBXTex3t+l2hXwLVYK382ibr/nKoY9GKe86w==", + "dev": true, + "license": "Apache-2.0", + "bin": { + "baseline-browser-mapping": "dist/cli.js" + } + }, "node_modules/bootstrap": { "version": "5.3.6", "resolved": "https://registry.npmjs.org/bootstrap/-/bootstrap-5.3.6.tgz", @@ -780,9 +792,9 @@ } }, "node_modules/browserslist": { - "version": "4.25.4", - "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.25.4.tgz", - "integrity": "sha512-4jYpcjabC606xJ3kw2QwGEZKX0Aw7sgQdZCvIK9dhVSPh76BKo+C+btT1RRofH7B+8iNpEbgGNVWiLki5q93yg==", + "version": "4.26.3", + "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.26.3.tgz", + "integrity": "sha512-lAUU+02RFBuCKQPj/P6NgjlbCnLBMp4UtgTx7vNHd3XSIJF87s9a5rA3aH2yw3GS9DqZAUbOtZdCCiZeVRqt0w==", "dev": true, "funding": [ { @@ -799,10 +811,12 @@ } ], "license": "MIT", + "peer": true, "dependencies": { - "caniuse-lite": "^1.0.30001737", - "electron-to-chromium": "^1.5.211", - "node-releases": "^2.0.19", + "baseline-browser-mapping": "^2.8.9", + "caniuse-lite": "^1.0.30001746", + "electron-to-chromium": "^1.5.227", + "node-releases": "^2.0.21", "update-browserslist-db": "^1.1.3" }, "bin": { @@ -820,9 +834,9 @@ "license": "MIT" }, "node_modules/caniuse-lite": { - "version": "1.0.30001741", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001741.tgz", - "integrity": "sha512-QGUGitqsc8ARjLdgAfxETDhRbJ0REsP6O3I96TAth/mVjh2cYzN2u+3AzPP3aVSm2FehEItaJw1xd+IGBXWeSw==", + "version": "1.0.30001751", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001751.tgz", + "integrity": "sha512-A0QJhug0Ly64Ii3eIqHu5X51ebln3k4yTUkY1j8drqpWHVreg/VLijN48cZ1bYPiqOQuqpkIKnzr/Ul8V+p6Cw==", "dev": true, "funding": [ { @@ -974,9 +988,9 @@ } }, "node_modules/electron-to-chromium": { - "version": "1.5.215", - "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.215.tgz", - "integrity": "sha512-TIvGp57UpeNetj/wV/xpFNpWGb0b/ROw372lHPx5Aafx02gjTBtWnEEcaSX3W2dLM3OSdGGyHX/cHl01JQsLaQ==", + "version": "1.5.237", + "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.237.tgz", + "integrity": "sha512-icUt1NvfhGLar5lSWH3tHNzablaA5js3HVHacQimfP8ViEBOQv+L7DKEuHdbTZ0SKCO1ogTJTIL1Gwk9S6Qvcg==", "dev": true, "license": "ISC" }, @@ -1527,9 +1541,9 @@ "optional": true }, "node_modules/node-releases": { - "version": "2.0.20", - "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.20.tgz", - "integrity": "sha512-7gK6zSXEH6neM212JgfYFXe+GmZQM+fia5SsusuBIUgnPheLFBmIPhtFoAQRj8/7wASYQnbDlHPVwY0BefoFgA==", + "version": "2.0.26", + "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.26.tgz", + "integrity": "sha512-S2M9YimhSjBSvYnlr5/+umAnPHE++ODwt5e2Ij6FoX45HA/s4vHdkDx1eax2pAPeAOqu4s9b7ppahsyEFdVqQA==", "dev": true, "license": "MIT" }, @@ -1653,6 +1667,7 @@ } ], "license": "MIT", + "peer": true, "dependencies": { "nanoid": "^3.3.11", "picocolors": "^1.1.1", @@ -1859,11 +1874,12 @@ "license": "MIT" }, "node_modules/sass": { - "version": "1.91.0", - "resolved": "https://registry.npmjs.org/sass/-/sass-1.91.0.tgz", - "integrity": "sha512-aFOZHGf+ur+bp1bCHZ+u8otKGh77ZtmFyXDo4tlYvT7PWql41Kwd8wdkPqhhT+h2879IVblcHFglIMofsFd1EA==", + "version": "1.93.2", + "resolved": "https://registry.npmjs.org/sass/-/sass-1.93.2.tgz", + "integrity": "sha512-t+YPtOQHpGW1QWsh1CHQ5cPIr9lbbGZLZnbihP/D/qZj/yuV68m8qarcV17nvkOX81BCrvzAlq2klCQFZghyTg==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "chokidar": "^4.0.0", "immutable": "^5.0.2", @@ -1921,9 +1937,9 @@ } }, "node_modules/schema-utils": { - "version": "4.3.2", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.2.tgz", - "integrity": "sha512-Gn/JaSk/Mt9gYubxTtSn/QCV4em9mpAPiR1rqy/Ocu19u/G9J5WWdNoUT4SiV6mFC3y6cxyFcFwdzPM3FgxGAQ==", + "version": "4.3.3", + "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.3.tgz", + "integrity": "sha512-eflK8wEtyOE6+hsaRVPxvUKYCpRgzLqDTb8krvAsRIwOGlHoSgYLgBXoubGgLd2fT41/OUYdb48v4k4WWHQurA==", "dev": true, "license": "MIT", "dependencies": { @@ -2060,9 +2076,9 @@ } }, "node_modules/tapable": { - "version": "2.2.3", - "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.2.3.tgz", - "integrity": "sha512-ZL6DDuAlRlLGghwcfmSn9sK3Hr6ArtyudlSAiCqQ6IfE+b+HHbydbYDIG15IfS5do+7XQQBdBiubF/cV2dnDzg==", + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.3.0.tgz", + "integrity": "sha512-g9ljZiwki/LfxmQADO3dEY1CbpmXT5Hm2fJ+QaGKwSXUylMybePR7/67YW7jOrrvjEgL1Fmz5kzyAjWVWLlucg==", "dev": true, "license": "MIT", "engines": { @@ -2201,11 +2217,12 @@ } }, "node_modules/webpack": { - "version": "5.101.3", - "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.101.3.tgz", - "integrity": "sha512-7b0dTKR3Ed//AD/6kkx/o7duS8H3f1a4w3BYpIriX4BzIhjkn4teo05cptsxvLesHFKK5KObnadmCHBwGc+51A==", + "version": "5.102.1", + "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.102.1.tgz", + "integrity": "sha512-7h/weGm9d/ywQ6qzJ+Xy+r9n/3qgp/thalBbpOi5i223dPXKi04IBtqPN9nTd+jBc7QKfvDbaBnFipYp4sJAUQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@types/eslint-scope": "^3.7.7", "@types/estree": "^1.0.8", @@ -2215,7 +2232,7 @@ "@webassemblyjs/wasm-parser": "^1.14.1", "acorn": "^8.15.0", "acorn-import-phases": "^1.0.3", - "browserslist": "^4.24.0", + "browserslist": "^4.26.3", "chrome-trace-event": "^1.0.2", "enhanced-resolve": "^5.17.3", "es-module-lexer": "^1.2.1", @@ -2227,10 +2244,10 @@ "loader-runner": "^4.2.0", "mime-types": "^2.1.27", "neo-async": "^2.6.2", - "schema-utils": "^4.3.2", - "tapable": "^2.1.1", + "schema-utils": "^4.3.3", + "tapable": "^2.3.0", "terser-webpack-plugin": "^5.3.11", - "watchpack": "^2.4.1", + "watchpack": "^2.4.4", "webpack-sources": "^3.3.3" }, "bin": { @@ -2255,6 +2272,7 @@ "integrity": "sha512-pIDJHIEI9LR0yxHXQ+Qh95k2EvXpWzZ5l+d+jIo+RdSm9MiHfzazIxwwni/p7+x4eJZuvG1AJwgC4TNQ7NRgsg==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@discoveryjs/json-ext": "^0.5.0", "@webpack-cli/configtest": "^2.1.1", diff --git a/bitwarden_license/src/Sso/package.json b/bitwarden_license/src/Sso/package.json index 28f40f0d25..df46444aca 100644 --- a/bitwarden_license/src/Sso/package.json +++ b/bitwarden_license/src/Sso/package.json @@ -16,9 +16,9 @@ "css-loader": "7.1.2", "expose-loader": "5.0.1", "mini-css-extract-plugin": "2.9.2", - "sass": "1.91.0", + "sass": "1.93.2", "sass-loader": "16.0.5", - "webpack": "5.101.3", + "webpack": "5.102.1", "webpack-cli": "5.1.4" } } diff --git a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs index 9b9c41048b..2bb02c3cee 100644 --- a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/ProviderFeatures/RemoveOrganizationFromProviderCommandTests.cs @@ -156,16 +156,18 @@ public class RemoveOrganizationFromProviderCommandTests "b@example.com" ]); - sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId) - .Returns(GetSubscription(organization.GatewaySubscriptionId)); + sutProvider.GetDependency().SubscriptionGetAsync(organization.GatewaySubscriptionId, Arg.Is( + options => options.Expand.Contains("customer"))) + .Returns(GetSubscription(organization.GatewaySubscriptionId, organization.GatewayCustomerId)); await sutProvider.Sut.RemoveOrganizationFromProvider(provider, providerOrganization, organization); var stripeAdapter = sutProvider.GetDependency(); await stripeAdapter.Received(1).CustomerUpdateAsync(organization.GatewayCustomerId, - Arg.Is(options => - options.Coupon == string.Empty && options.Email == "a@example.com")); + Arg.Is(options => options.Email == "a@example.com")); + + await stripeAdapter.Received(1).CustomerDeleteDiscountAsync(organization.GatewayCustomerId); await stripeAdapter.Received(1).SubscriptionUpdateAsync(organization.GatewaySubscriptionId, Arg.Is(options => @@ -368,10 +370,21 @@ public class RemoveOrganizationFromProviderCommandTests Arg.Is>(emails => emails.FirstOrDefault() == "a@example.com")); } - private static Subscription GetSubscription(string subscriptionId) => + private static Subscription GetSubscription(string subscriptionId, string customerId) => new() { Id = subscriptionId, + CustomerId = customerId, + Customer = new Customer + { + Discount = new Discount + { + Coupon = new Coupon + { + Id = "coupon-id" + } + } + }, Status = StripeConstants.SubscriptionStatus.Active, Items = new StripeList { diff --git a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs index f2ba2fab8f..e61cf5f97e 100644 --- a/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/AdminConsole/Services/ProviderServiceTests.cs @@ -9,7 +9,7 @@ using Bit.Core.AdminConsole.Models.Data.Provider; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models; +using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Providers.Services; using Bit.Core.Context; @@ -41,7 +41,7 @@ public class ProviderServiceTests public async Task CompleteSetupAsync_UserIdIsInvalid_Throws(SutProvider sutProvider) { var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.CompleteSetupAsync(default, default, default, default, null)); + () => sutProvider.Sut.CompleteSetupAsync(default, default, default, default, null, null)); Assert.Contains("Invalid owner.", exception.Message); } @@ -53,83 +53,12 @@ public class ProviderServiceTests userService.GetUserByIdAsync(user.Id).Returns(user); var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.CompleteSetupAsync(provider, user.Id, default, default, null)); + () => sutProvider.Sut.CompleteSetupAsync(provider, user.Id, default, default, null, null)); Assert.Contains("Invalid token.", exception.Message); } [Theory, BitAutoData] - public async Task CompleteSetupAsync_InvalidTaxInfo_ThrowsBadRequestException( - User user, - Provider provider, - string key, - TaxInfo taxInfo, - TokenizedPaymentSource tokenizedPaymentSource, - [ProviderUser] ProviderUser providerUser, - SutProvider sutProvider) - { - providerUser.ProviderId = provider.Id; - providerUser.UserId = user.Id; - var userService = sutProvider.GetDependency(); - userService.GetUserByIdAsync(user.Id).Returns(user); - - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id).Returns(providerUser); - - var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName"); - var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); - sutProvider.GetDependency().CreateProtector("ProviderServiceDataProtector") - .Returns(protector); - - sutProvider.Create(); - - var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); - - taxInfo.BillingAddressCountry = null; - - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, taxInfo, tokenizedPaymentSource)); - - Assert.Equal("Both address and postal code are required to set up your provider.", exception.Message); - } - - [Theory, BitAutoData] - public async Task CompleteSetupAsync_InvalidTokenizedPaymentSource_ThrowsBadRequestException( - User user, - Provider provider, - string key, - TaxInfo taxInfo, - TokenizedPaymentSource tokenizedPaymentSource, - [ProviderUser] ProviderUser providerUser, - SutProvider sutProvider) - { - providerUser.ProviderId = provider.Id; - providerUser.UserId = user.Id; - var userService = sutProvider.GetDependency(); - userService.GetUserByIdAsync(user.Id).Returns(user); - - var providerUserRepository = sutProvider.GetDependency(); - providerUserRepository.GetByProviderUserAsync(provider.Id, user.Id).Returns(providerUser); - - var dataProtectionProvider = DataProtectionProvider.Create("ApplicationName"); - var protector = dataProtectionProvider.CreateProtector("ProviderServiceDataProtector"); - sutProvider.GetDependency().CreateProtector("ProviderServiceDataProtector") - .Returns(protector); - - sutProvider.Create(); - - var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); - - - tokenizedPaymentSource = tokenizedPaymentSource with { Type = PaymentMethodType.BitPay }; - - var exception = await Assert.ThrowsAsync(() => - sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, taxInfo, tokenizedPaymentSource)); - - Assert.Equal("A payment method is required to set up your provider.", exception.Message); - } - - [Theory, BitAutoData] - public async Task CompleteSetupAsync_Success(User user, Provider provider, string key, TaxInfo taxInfo, TokenizedPaymentSource tokenizedPaymentSource, + public async Task CompleteSetupAsync_Success(User user, Provider provider, string key, TokenizedPaymentMethod tokenizedPaymentMethod, BillingAddress billingAddress, [ProviderUser] ProviderUser providerUser, SutProvider sutProvider) { @@ -149,7 +78,7 @@ public class ProviderServiceTests var providerBillingService = sutProvider.GetDependency(); var customer = new Customer { Id = "customer_id" }; - providerBillingService.SetupCustomer(provider, taxInfo, tokenizedPaymentSource).Returns(customer); + providerBillingService.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress).Returns(customer); var subscription = new Subscription { Id = "subscription_id" }; providerBillingService.SetupSubscription(provider).Returns(subscription); @@ -158,7 +87,7 @@ public class ProviderServiceTests var token = protector.Protect($"ProviderSetupInvite {provider.Id} {user.Email} {CoreHelpers.ToEpocMilliseconds(DateTime.UtcNow)}"); - await sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, taxInfo, tokenizedPaymentSource); + await sutProvider.Sut.CompleteSetupAsync(provider, user.Id, token, key, tokenizedPaymentMethod, billingAddress); await sutProvider.GetDependency().Received().UpsertAsync(Arg.Is( p => diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs index 54c0b82aa9..18c71364e6 100644 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/Providers/Services/ProviderBillingServiceTests.cs @@ -1,5 +1,4 @@ using System.Globalization; -using System.Net; using Bit.Commercial.Core.Billing.Providers.Models; using Bit.Commercial.Core.Billing.Providers.Services; using Bit.Core.AdminConsole.Entities; @@ -10,18 +9,16 @@ using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models; +using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Providers.Entities; using Bit.Core.Billing.Providers.Models; using Bit.Core.Billing.Providers.Repositories; using Bit.Core.Billing.Providers.Services; using Bit.Core.Billing.Services; -using Bit.Core.Billing.Tax.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; -using Bit.Core.Models.Business; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; @@ -895,118 +892,53 @@ public class ProviderBillingServiceTests #region SetupCustomer [Theory, BitAutoData] - public async Task SetupCustomer_MissingCountry_ContactSupport( + public async Task SetupCustomer_NullPaymentMethod_ThrowsNullReferenceException( SutProvider sutProvider, Provider provider, - TaxInfo taxInfo, - TokenizedPaymentSource tokenizedPaymentSource) + BillingAddress billingAddress) { - taxInfo.BillingAddressCountry = null; - - await ThrowsBillingExceptionAsync(() => sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource)); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .CustomerGetAsync(Arg.Any(), Arg.Any()); - } - - [Theory, BitAutoData] - public async Task SetupCustomer_MissingPostalCode_ContactSupport( - SutProvider sutProvider, - Provider provider, - TaxInfo taxInfo, - TokenizedPaymentSource tokenizedPaymentSource) - { - taxInfo.BillingAddressCountry = null; - - await ThrowsBillingExceptionAsync(() => sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource)); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .CustomerGetAsync(Arg.Any(), Arg.Any()); - } - - - [Theory, BitAutoData] - public async Task SetupCustomer_NullPaymentSource_ThrowsArgumentNullException( - SutProvider sutProvider, - Provider provider, - TaxInfo taxInfo) - { - await Assert.ThrowsAsync(() => - sutProvider.Sut.SetupCustomer(provider, taxInfo, null)); - } - - [Theory, BitAutoData] - public async Task SetupCustomer_InvalidRequiredPaymentMethod_ThrowsBillingException( - SutProvider sutProvider, - Provider provider, - TaxInfo taxInfo, - TokenizedPaymentSource tokenizedPaymentSource) - { - provider.Name = "MSP"; - - sutProvider.GetDependency() - .GetStripeTaxCode(Arg.Is( - p => p == taxInfo.BillingAddressCountry), - Arg.Is(p => p == taxInfo.TaxIdNumber)) - .Returns(taxInfo.TaxIdType); - - taxInfo.BillingAddressCountry = "AD"; - - - tokenizedPaymentSource = tokenizedPaymentSource with { Type = PaymentMethodType.BitPay }; - - await ThrowsBillingExceptionAsync(() => - sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource)); + await Assert.ThrowsAsync(() => + sutProvider.Sut.SetupCustomer(provider, null, billingAddress)); } [Theory, BitAutoData] public async Task SetupCustomer_WithBankAccount_Error_Reverts( SutProvider sutProvider, Provider provider, - TaxInfo taxInfo) + BillingAddress billingAddress) { provider.Name = "MSP"; - - sutProvider.GetDependency() - .GetStripeTaxCode(Arg.Is( - p => p == taxInfo.BillingAddressCountry), - Arg.Is(p => p == taxInfo.TaxIdNumber)) - .Returns(taxInfo.TaxIdType); - - taxInfo.BillingAddressCountry = "AD"; + billingAddress.Country = "AD"; + billingAddress.TaxId = new TaxID("es_nif", "12345678Z"); var stripeAdapter = sutProvider.GetDependency(); - - var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.BankAccount, "token"); - + var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.BankAccount, Token = "token" }; stripeAdapter.SetupIntentList(Arg.Is(options => - options.PaymentMethod == tokenizedPaymentSource.Token)).Returns([ + options.PaymentMethod == tokenizedPaymentMethod.Token)).Returns([ new SetupIntent { Id = "setup_intent_id" } ]); stripeAdapter.CustomerCreateAsync(Arg.Is(o => - o.Address.Country == taxInfo.BillingAddressCountry && - o.Address.PostalCode == taxInfo.BillingAddressPostalCode && - o.Address.Line1 == taxInfo.BillingAddressLine1 && - o.Address.Line2 == taxInfo.BillingAddressLine2 && - o.Address.City == taxInfo.BillingAddressCity && - o.Address.State == taxInfo.BillingAddressState && - o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Address.Country == billingAddress.Country && + o.Address.PostalCode == billingAddress.PostalCode && + o.Address.Line1 == billingAddress.Line1 && + o.Address.Line2 == billingAddress.Line2 && + o.Address.City == billingAddress.City && + o.Address.State == billingAddress.State && + o.Description == provider.DisplayBusinessName() && o.Email == provider.BillingEmail && - o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && - o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == provider.SubscriberType() && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == provider.DisplayName() && o.Metadata["region"] == "" && - o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && - o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) + o.TaxIdData.FirstOrDefault().Type == billingAddress.TaxId.Code && + o.TaxIdData.FirstOrDefault().Value == billingAddress.TaxId.Value)) .Throws(); sutProvider.GetDependency().GetSetupIntentIdForSubscriber(provider.Id).Returns("setup_intent_id"); await Assert.ThrowsAsync(() => - sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource)); + sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress)); await sutProvider.GetDependency().Received(1).Set(provider.Id, "setup_intent_id"); @@ -1020,45 +952,37 @@ public class ProviderBillingServiceTests public async Task SetupCustomer_WithPayPal_Error_Reverts( SutProvider sutProvider, Provider provider, - TaxInfo taxInfo) + BillingAddress billingAddress) { provider.Name = "MSP"; - - sutProvider.GetDependency() - .GetStripeTaxCode(Arg.Is( - p => p == taxInfo.BillingAddressCountry), - Arg.Is(p => p == taxInfo.TaxIdNumber)) - .Returns(taxInfo.TaxIdType); - - taxInfo.BillingAddressCountry = "AD"; + billingAddress.Country = "AD"; + billingAddress.TaxId = new TaxID("es_nif", "12345678Z"); var stripeAdapter = sutProvider.GetDependency(); + var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.PayPal, Token = "token" }; - var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.PayPal, "token"); - - - sutProvider.GetDependency().CreateBraintreeCustomer(provider, tokenizedPaymentSource.Token) + sutProvider.GetDependency().CreateBraintreeCustomer(provider, tokenizedPaymentMethod.Token) .Returns("braintree_customer_id"); stripeAdapter.CustomerCreateAsync(Arg.Is(o => - o.Address.Country == taxInfo.BillingAddressCountry && - o.Address.PostalCode == taxInfo.BillingAddressPostalCode && - o.Address.Line1 == taxInfo.BillingAddressLine1 && - o.Address.Line2 == taxInfo.BillingAddressLine2 && - o.Address.City == taxInfo.BillingAddressCity && - o.Address.State == taxInfo.BillingAddressState && - o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Address.Country == billingAddress.Country && + o.Address.PostalCode == billingAddress.PostalCode && + o.Address.Line1 == billingAddress.Line1 && + o.Address.Line2 == billingAddress.Line2 && + o.Address.City == billingAddress.City && + o.Address.State == billingAddress.State && + o.Description == provider.DisplayBusinessName() && o.Email == provider.BillingEmail && - o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && - o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == provider.SubscriberType() && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == provider.DisplayName() && o.Metadata["region"] == "" && o.Metadata["btCustomerId"] == "braintree_customer_id" && - o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && - o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) + o.TaxIdData.FirstOrDefault().Type == billingAddress.TaxId.Code && + o.TaxIdData.FirstOrDefault().Value == billingAddress.TaxId.Value)) .Throws(); await Assert.ThrowsAsync(() => - sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource)); + sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress)); await sutProvider.GetDependency().Customer.Received(1).DeleteAsync("braintree_customer_id"); } @@ -1067,17 +991,11 @@ public class ProviderBillingServiceTests public async Task SetupCustomer_WithBankAccount_Success( SutProvider sutProvider, Provider provider, - TaxInfo taxInfo) + BillingAddress billingAddress) { provider.Name = "MSP"; - - sutProvider.GetDependency() - .GetStripeTaxCode(Arg.Is( - p => p == taxInfo.BillingAddressCountry), - Arg.Is(p => p == taxInfo.TaxIdNumber)) - .Returns(taxInfo.TaxIdType); - - taxInfo.BillingAddressCountry = "AD"; + billingAddress.Country = "AD"; + billingAddress.TaxId = new TaxID("es_nif", "12345678Z"); var stripeAdapter = sutProvider.GetDependency(); @@ -1087,31 +1005,30 @@ public class ProviderBillingServiceTests Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } }; - var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.BankAccount, "token"); - + var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.BankAccount, Token = "token" }; stripeAdapter.SetupIntentList(Arg.Is(options => - options.PaymentMethod == tokenizedPaymentSource.Token)).Returns([ + options.PaymentMethod == tokenizedPaymentMethod.Token)).Returns([ new SetupIntent { Id = "setup_intent_id" } ]); stripeAdapter.CustomerCreateAsync(Arg.Is(o => - o.Address.Country == taxInfo.BillingAddressCountry && - o.Address.PostalCode == taxInfo.BillingAddressPostalCode && - o.Address.Line1 == taxInfo.BillingAddressLine1 && - o.Address.Line2 == taxInfo.BillingAddressLine2 && - o.Address.City == taxInfo.BillingAddressCity && - o.Address.State == taxInfo.BillingAddressState && - o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Address.Country == billingAddress.Country && + o.Address.PostalCode == billingAddress.PostalCode && + o.Address.Line1 == billingAddress.Line1 && + o.Address.Line2 == billingAddress.Line2 && + o.Address.City == billingAddress.City && + o.Address.State == billingAddress.State && + o.Description == provider.DisplayBusinessName() && o.Email == provider.BillingEmail && - o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && - o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == provider.SubscriberType() && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == provider.DisplayName() && o.Metadata["region"] == "" && - o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && - o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) + o.TaxIdData.FirstOrDefault().Type == billingAddress.TaxId.Code && + o.TaxIdData.FirstOrDefault().Value == billingAddress.TaxId.Value)) .Returns(expected); - var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource); + var actual = await sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress); Assert.Equivalent(expected, actual); @@ -1122,17 +1039,11 @@ public class ProviderBillingServiceTests public async Task SetupCustomer_WithPayPal_Success( SutProvider sutProvider, Provider provider, - TaxInfo taxInfo) + BillingAddress billingAddress) { provider.Name = "MSP"; - - sutProvider.GetDependency() - .GetStripeTaxCode(Arg.Is( - p => p == taxInfo.BillingAddressCountry), - Arg.Is(p => p == taxInfo.TaxIdNumber)) - .Returns(taxInfo.TaxIdType); - - taxInfo.BillingAddressCountry = "AD"; + billingAddress.Country = "AD"; + billingAddress.TaxId = new TaxID("es_nif", "12345678Z"); var stripeAdapter = sutProvider.GetDependency(); @@ -1142,30 +1053,29 @@ public class ProviderBillingServiceTests Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } }; - var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.PayPal, "token"); + var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.PayPal, Token = "token" }; - - sutProvider.GetDependency().CreateBraintreeCustomer(provider, tokenizedPaymentSource.Token) + sutProvider.GetDependency().CreateBraintreeCustomer(provider, tokenizedPaymentMethod.Token) .Returns("braintree_customer_id"); stripeAdapter.CustomerCreateAsync(Arg.Is(o => - o.Address.Country == taxInfo.BillingAddressCountry && - o.Address.PostalCode == taxInfo.BillingAddressPostalCode && - o.Address.Line1 == taxInfo.BillingAddressLine1 && - o.Address.Line2 == taxInfo.BillingAddressLine2 && - o.Address.City == taxInfo.BillingAddressCity && - o.Address.State == taxInfo.BillingAddressState && - o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Address.Country == billingAddress.Country && + o.Address.PostalCode == billingAddress.PostalCode && + o.Address.Line1 == billingAddress.Line1 && + o.Address.Line2 == billingAddress.Line2 && + o.Address.City == billingAddress.City && + o.Address.State == billingAddress.State && + o.Description == provider.DisplayBusinessName() && o.Email == provider.BillingEmail && - o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && - o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == provider.SubscriberType() && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == provider.DisplayName() && o.Metadata["region"] == "" && o.Metadata["btCustomerId"] == "braintree_customer_id" && - o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && - o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) + o.TaxIdData.FirstOrDefault().Type == billingAddress.TaxId.Code && + o.TaxIdData.FirstOrDefault().Value == billingAddress.TaxId.Value)) .Returns(expected); - var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource); + var actual = await sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress); Assert.Equivalent(expected, actual); } @@ -1174,17 +1084,11 @@ public class ProviderBillingServiceTests public async Task SetupCustomer_WithCard_Success( SutProvider sutProvider, Provider provider, - TaxInfo taxInfo) + BillingAddress billingAddress) { provider.Name = "MSP"; - - sutProvider.GetDependency() - .GetStripeTaxCode(Arg.Is( - p => p == taxInfo.BillingAddressCountry), - Arg.Is(p => p == taxInfo.TaxIdNumber)) - .Returns(taxInfo.TaxIdType); - - taxInfo.BillingAddressCountry = "AD"; + billingAddress.Country = "AD"; + billingAddress.TaxId = new TaxID("es_nif", "12345678Z"); var stripeAdapter = sutProvider.GetDependency(); @@ -1194,28 +1098,26 @@ public class ProviderBillingServiceTests Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } }; - var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.Card, "token"); - + var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = "token" }; stripeAdapter.CustomerCreateAsync(Arg.Is(o => - o.Address.Country == taxInfo.BillingAddressCountry && - o.Address.PostalCode == taxInfo.BillingAddressPostalCode && - o.Address.Line1 == taxInfo.BillingAddressLine1 && - o.Address.Line2 == taxInfo.BillingAddressLine2 && - o.Address.City == taxInfo.BillingAddressCity && - o.Address.State == taxInfo.BillingAddressState && - o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Address.Country == billingAddress.Country && + o.Address.PostalCode == billingAddress.PostalCode && + o.Address.Line1 == billingAddress.Line1 && + o.Address.Line2 == billingAddress.Line2 && + o.Address.City == billingAddress.City && + o.Address.State == billingAddress.State && + o.Description == provider.DisplayBusinessName() && o.Email == provider.BillingEmail && - o.PaymentMethod == tokenizedPaymentSource.Token && - o.InvoiceSettings.DefaultPaymentMethod == tokenizedPaymentSource.Token && - o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && - o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.InvoiceSettings.DefaultPaymentMethod == tokenizedPaymentMethod.Token && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == provider.SubscriberType() && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == provider.DisplayName() && o.Metadata["region"] == "" && - o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && - o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber)) + o.TaxIdData.FirstOrDefault().Type == billingAddress.TaxId.Code && + o.TaxIdData.FirstOrDefault().Value == billingAddress.TaxId.Value)) .Returns(expected); - var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource); + var actual = await sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress); Assert.Equivalent(expected, actual); } @@ -1224,17 +1126,11 @@ public class ProviderBillingServiceTests public async Task SetupCustomer_WithCard_ReverseCharge_Success( SutProvider sutProvider, Provider provider, - TaxInfo taxInfo) + BillingAddress billingAddress) { provider.Name = "MSP"; - - sutProvider.GetDependency() - .GetStripeTaxCode(Arg.Is( - p => p == taxInfo.BillingAddressCountry), - Arg.Is(p => p == taxInfo.TaxIdNumber)) - .Returns(taxInfo.TaxIdType); - - taxInfo.BillingAddressCountry = "AD"; + billingAddress.Country = "FR"; // Non-US country to trigger reverse charge + billingAddress.TaxId = new TaxID("fr_siren", "123456789"); var stripeAdapter = sutProvider.GetDependency(); @@ -1244,55 +1140,51 @@ public class ProviderBillingServiceTests Tax = new CustomerTax { AutomaticTax = StripeConstants.AutomaticTaxStatus.Supported } }; - var tokenizedPaymentSource = new TokenizedPaymentSource(PaymentMethodType.Card, "token"); - + var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = "token" }; stripeAdapter.CustomerCreateAsync(Arg.Is(o => - o.Address.Country == taxInfo.BillingAddressCountry && - o.Address.PostalCode == taxInfo.BillingAddressPostalCode && - o.Address.Line1 == taxInfo.BillingAddressLine1 && - o.Address.Line2 == taxInfo.BillingAddressLine2 && - o.Address.City == taxInfo.BillingAddressCity && - o.Address.State == taxInfo.BillingAddressState && - o.Description == WebUtility.HtmlDecode(provider.BusinessName) && + o.Address.Country == billingAddress.Country && + o.Address.PostalCode == billingAddress.PostalCode && + o.Address.Line1 == billingAddress.Line1 && + o.Address.Line2 == billingAddress.Line2 && + o.Address.City == billingAddress.City && + o.Address.State == billingAddress.State && + o.Description == provider.DisplayBusinessName() && o.Email == provider.BillingEmail && - o.PaymentMethod == tokenizedPaymentSource.Token && - o.InvoiceSettings.DefaultPaymentMethod == tokenizedPaymentSource.Token && - o.InvoiceSettings.CustomFields.FirstOrDefault().Name == "Provider" && - o.InvoiceSettings.CustomFields.FirstOrDefault().Value == "MSP" && + o.InvoiceSettings.DefaultPaymentMethod == tokenizedPaymentMethod.Token && + o.InvoiceSettings.CustomFields.FirstOrDefault().Name == provider.SubscriberType() && + o.InvoiceSettings.CustomFields.FirstOrDefault().Value == provider.DisplayName() && o.Metadata["region"] == "" && - o.TaxIdData.FirstOrDefault().Type == taxInfo.TaxIdType && - o.TaxIdData.FirstOrDefault().Value == taxInfo.TaxIdNumber && + o.TaxIdData.FirstOrDefault().Type == billingAddress.TaxId.Code && + o.TaxIdData.FirstOrDefault().Value == billingAddress.TaxId.Value && o.TaxExempt == StripeConstants.TaxExempt.Reverse)) .Returns(expected); - var actual = await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource); + var actual = await sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress); Assert.Equivalent(expected, actual); } [Theory, BitAutoData] - public async Task SetupCustomer_Throws_BadRequestException_WhenTaxIdIsInvalid( + public async Task SetupCustomer_WithInvalidTaxId_ThrowsBadRequestException( SutProvider sutProvider, Provider provider, - TaxInfo taxInfo, - TokenizedPaymentSource tokenizedPaymentSource) + BillingAddress billingAddress) { provider.Name = "MSP"; + billingAddress.Country = "AD"; + billingAddress.TaxId = new TaxID("es_nif", "invalid_tax_id"); - taxInfo.BillingAddressCountry = "AD"; + var stripeAdapter = sutProvider.GetDependency(); + var tokenizedPaymentMethod = new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.Card, Token = "token" }; - sutProvider.GetDependency() - .GetStripeTaxCode(Arg.Is( - p => p == taxInfo.BillingAddressCountry), - Arg.Is(p => p == taxInfo.TaxIdNumber)) - .Returns((string)null); + stripeAdapter.CustomerCreateAsync(Arg.Any()) + .Throws(new StripeException("Invalid tax ID") { StripeError = new StripeError { Code = "tax_id_invalid" } }); var actual = await Assert.ThrowsAsync(async () => - await sutProvider.Sut.SetupCustomer(provider, taxInfo, tokenizedPaymentSource)); + await sutProvider.Sut.SetupCustomer(provider, tokenizedPaymentMethod, billingAddress)); - Assert.IsType(actual); - Assert.Equal("billingTaxIdTypeInferenceError", actual.Message); + Assert.Equal("Your tax ID wasn't recognized for your selected country. Please ensure your country and tax ID are valid.", actual.Message); } #endregion diff --git a/bitwarden_license/test/SSO.Test/Controllers/AccountControllerTest.cs b/bitwarden_license/test/SSO.Test/Controllers/AccountControllerTest.cs new file mode 100644 index 0000000000..0fe37d89fd --- /dev/null +++ b/bitwarden_license/test/SSO.Test/Controllers/AccountControllerTest.cs @@ -0,0 +1,1011 @@ +using System.Reflection; +using System.Security.Claims; +using Bit.Core; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Auth.Entities; +using Bit.Core.Auth.Models.Data; +using Bit.Core.Auth.Repositories; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Sso.Controllers; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Duende.IdentityModel; +using Duende.IdentityServer.Configuration; +using Duende.IdentityServer.Models; +using Duende.IdentityServer.Services; +using Microsoft.AspNetCore.Authentication; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.Extensions.DependencyInjection; +using NSubstitute; +using Xunit.Abstractions; +using AuthenticationOptions = Duende.IdentityServer.Configuration.AuthenticationOptions; + +namespace Bit.SSO.Test.Controllers; + +[ControllerCustomize(typeof(AccountController)), SutProviderCustomize] +public class AccountControllerTest +{ + private readonly ITestOutputHelper _output; + + public AccountControllerTest(ITestOutputHelper output) + { + _output = output; + } + + private static IAuthenticationService SetupHttpContextWithAuth( + SutProvider sutProvider, + AuthenticateResult authResult, + IAuthenticationService? authService = null) + { + var schemeProvider = Substitute.For(); + schemeProvider.GetDefaultAuthenticateSchemeAsync() + .Returns(new AuthenticationScheme("idsrv", "idsrv", typeof(IAuthenticationHandler))); + + var resolvedAuthService = authService ?? Substitute.For(); + resolvedAuthService.AuthenticateAsync( + Arg.Any(), + AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme) + .Returns(authResult); + + var services = new ServiceCollection(); + services.AddSingleton(resolvedAuthService); + services.AddSingleton(schemeProvider); + services.AddSingleton(new IdentityServerOptions + { + Authentication = new AuthenticationOptions + { + CookieAuthenticationScheme = "idsrv" + } + }); + var sp = services.BuildServiceProvider(); + + sutProvider.Sut.ControllerContext = new ControllerContext + { + HttpContext = new DefaultHttpContext + { + RequestServices = sp + } + }; + + return resolvedAuthService; + } + + private static AuthenticateResult BuildSuccessfulExternalAuth(Guid orgId, string providerUserId, string email) + { + var claims = new[] + { + new Claim(JwtClaimTypes.Subject, providerUserId), + new Claim(JwtClaimTypes.Email, email) + }; + var principal = new ClaimsPrincipal(new ClaimsIdentity(claims, "External")); + var properties = new AuthenticationProperties + { + Items = + { + ["scheme"] = orgId.ToString(), + ["return_url"] = "~/", + ["state"] = "state", + ["user_identifier"] = string.Empty + } + }; + var ticket = new AuthenticationTicket(principal, properties, AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme); + return AuthenticateResult.Success(ticket); + } + + private static void ConfigureSsoAndUser( + SutProvider sutProvider, + Guid orgId, + string providerUserId, + User user, + Organization? organization = null, + OrganizationUser? orgUser = null) + { + var ssoConfigRepository = sutProvider.GetDependency(); + var userRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + + var ssoConfig = new SsoConfig { OrganizationId = orgId, Enabled = true }; + var ssoData = new SsoConfigurationData(); + ssoConfig.SetData(ssoData); + ssoConfigRepository.GetByOrganizationIdAsync(orgId).Returns(ssoConfig); + + userRepository.GetBySsoUserAsync(providerUserId, orgId).Returns(user); + + if (organization != null) + { + organizationRepository.GetByIdAsync(orgId).Returns(organization); + } + if (organization != null && orgUser != null) + { + organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id).Returns(orgUser); + organizationUserRepository.GetManyByUserAsync(user.Id).Returns([orgUser]); + } + } + + private enum MeasurementScenario + { + ExistingSsoLinkedAccepted, + ExistingUserNoOrgUser, + JitProvision + } + + private sealed class LookupCounts + { + public int UserGetBySso { get; init; } + public int UserGetByEmail { get; init; } + public int OrgGetById { get; init; } + public int OrgUserGetByOrg { get; init; } + public int OrgUserGetByEmail { get; init; } + } + + private async Task MeasureCountsForScenarioAsync( + SutProvider sutProvider, + MeasurementScenario scenario, + bool preventNonCompliant) + { + var orgId = Guid.NewGuid(); + var providerUserId = $"meas-{scenario}-{(preventNonCompliant ? "on" : "off")}"; + var email = scenario == MeasurementScenario.JitProvision + ? "jit.compare@example.com" + : "existing.compare@example.com"; + + var organization = new Organization { Id = orgId, Name = "Org" }; + var user = new User { Id = Guid.NewGuid(), Email = email }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, email); + SetupHttpContextWithAuth(sutProvider, authResult); + + // SSO config present + var ssoConfigRepository = sutProvider.GetDependency(); + var userRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + var featureService = sutProvider.GetDependency(); + var interactionService = sutProvider.GetDependency(); + + var ssoConfig = new SsoConfig { OrganizationId = orgId, Enabled = true }; + var ssoData = new SsoConfigurationData(); + ssoConfig.SetData(ssoData); + ssoConfigRepository.GetByOrganizationIdAsync(orgId).Returns(ssoConfig); + + switch (scenario) + { + case MeasurementScenario.ExistingSsoLinkedAccepted: + userRepository.GetBySsoUserAsync(providerUserId, orgId).Returns(user); + organizationRepository.GetByIdAsync(orgId).Returns(organization); + organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id) + .Returns(new OrganizationUser + { + OrganizationId = orgId, + UserId = user.Id, + Status = OrganizationUserStatusType.Accepted, + Type = OrganizationUserType.User + }); + break; + case MeasurementScenario.ExistingUserNoOrgUser: + userRepository.GetBySsoUserAsync(providerUserId, orgId).Returns(user); + organizationRepository.GetByIdAsync(orgId).Returns(organization); + organizationUserRepository.GetByOrganizationAsync(organization.Id, user.Id) + .Returns((OrganizationUser?)null); + break; + case MeasurementScenario.JitProvision: + userRepository.GetBySsoUserAsync(providerUserId, orgId).Returns((User?)null); + userRepository.GetByEmailAsync(email).Returns((User?)null); + organizationRepository.GetByIdAsync(orgId).Returns(organization); + organizationUserRepository.GetByOrganizationEmailAsync(orgId, email) + .Returns((OrganizationUser?)null); + break; + } + + featureService.IsEnabled(Arg.Any()).Returns(preventNonCompliant); + interactionService.GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + try + { + _ = await sutProvider.Sut.ExternalCallback(); + } + catch + { + // Ignore exceptions for measurement; some flows can throw based on status enforcement + } + + var counts = new LookupCounts + { + UserGetBySso = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetBySsoUserAsync)), + UserGetByEmail = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetByEmailAsync)), + OrgGetById = organizationRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationRepository.GetByIdAsync)), + OrgUserGetByOrg = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationAsync)), + OrgUserGetByEmail = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationEmailAsync)), + }; + + userRepository.ClearReceivedCalls(); + organizationRepository.ClearReceivedCalls(); + organizationUserRepository.ClearReceivedCalls(); + + return counts; + } + + [Theory, BitAutoData] + public async Task ExternalCallback_PreventNonCompliantTrue_ExistingUser_NoOrgUser_ThrowsCouldNotFindOrganizationUser( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-missing-orguser"; + var user = new User { Id = Guid.NewGuid(), Email = "missing.orguser@example.com" }; + var organization = new Organization { Id = orgId, Name = "Org" }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, user.Email!); + SetupHttpContextWithAuth(sutProvider, authResult); + + // i18n returns the key so we can assert on message contents + sutProvider.GetDependency() + .T(Arg.Any(), Arg.Any()) + .Returns(ci => (string)ci[0]!); + + // SSO config + user link exists, but no org user membership + ConfigureSsoAndUser( + sutProvider, + orgId, + providerUserId, + user, + organization, + orgUser: null); + + sutProvider.GetDependency() + .GetByOrganizationAsync(organization.Id, user.Id).Returns((OrganizationUser?)null); + + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(true); + sutProvider.GetDependency() + .GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + // Act + Assert + var ex = await Assert.ThrowsAsync(() => sutProvider.Sut.ExternalCallback()); + Assert.Equal("CouldNotFindOrganizationUser", ex.Message); + } + + [Theory, BitAutoData] + public async Task ExternalCallback_PreventNonCompliantTrue_ExistingUser_OrgUserInvited_AllowsLogin( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-invited-orguser"; + var user = new User { Id = Guid.NewGuid(), Email = "invited.orguser@example.com" }; + var organization = new Organization { Id = orgId, Name = "Org" }; + var orgUser = new OrganizationUser + { + OrganizationId = orgId, + UserId = user.Id, + Status = OrganizationUserStatusType.Invited, + Type = OrganizationUserType.User + }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, user.Email!); + var authService = SetupHttpContextWithAuth(sutProvider, authResult); + + sutProvider.GetDependency() + .T(Arg.Any(), Arg.Any()) + .Returns(ci => (string)ci[0]!); + + ConfigureSsoAndUser( + sutProvider, + orgId, + providerUserId, + user, + organization, + orgUser); + + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(true); + sutProvider.GetDependency() + .GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + // Act + var result = await sutProvider.Sut.ExternalCallback(); + + // Assert + var redirect = Assert.IsType(result); + Assert.Equal("~/", redirect.Url); + + await authService.Received().SignInAsync( + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()); + + await authService.Received().SignOutAsync( + Arg.Any(), + AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme, + Arg.Any()); + } + + [Theory, BitAutoData] + public async Task ExternalCallback_PreventNonCompliantTrue_ExistingUser_OrgUserRevoked_ThrowsAccessRevoked( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-revoked-orguser"; + var user = new User { Id = Guid.NewGuid(), Email = "revoked.orguser@example.com" }; + var organization = new Organization { Id = orgId, Name = "Org" }; + var orgUser = new OrganizationUser + { + OrganizationId = orgId, + UserId = user.Id, + Status = OrganizationUserStatusType.Revoked, + Type = OrganizationUserType.User + }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, user.Email!); + SetupHttpContextWithAuth(sutProvider, authResult); + + sutProvider.GetDependency() + .T(Arg.Any(), Arg.Any()) + .Returns(ci => (string)ci[0]!); + + ConfigureSsoAndUser( + sutProvider, + orgId, + providerUserId, + user, + organization, + orgUser); + + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(true); + sutProvider.GetDependency() + .GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + // Act + Assert + var ex = await Assert.ThrowsAsync(() => sutProvider.Sut.ExternalCallback()); + Assert.Equal("OrganizationUserAccessRevoked", ex.Message); + } + + [Theory, BitAutoData] + public async Task ExternalCallback_PreventNonCompliantTrue_ExistingUser_OrgUserUnknown_ThrowsUnknown( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-unknown-orguser"; + var user = new User { Id = Guid.NewGuid(), Email = "unknown.orguser@example.com" }; + var organization = new Organization { Id = orgId, Name = "Org" }; + var unknownStatus = (OrganizationUserStatusType)999; + var orgUser = new OrganizationUser + { + OrganizationId = orgId, + UserId = user.Id, + Status = unknownStatus, + Type = OrganizationUserType.User + }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, user.Email!); + SetupHttpContextWithAuth(sutProvider, authResult); + + sutProvider.GetDependency() + .T(Arg.Any(), Arg.Any()) + .Returns(ci => (string)ci[0]!); + + ConfigureSsoAndUser( + sutProvider, + orgId, + providerUserId, + user, + organization, + orgUser); + + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(true); + sutProvider.GetDependency() + .GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + // Act + Assert + var ex = await Assert.ThrowsAsync(() => sutProvider.Sut.ExternalCallback()); + Assert.Equal("OrganizationUserUnknownStatus", ex.Message); + } + + [Theory, BitAutoData] + public async Task ExternalCallback_WithExistingUserAndAcceptedMembership_RedirectsToReturnUrl( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-123"; + var user = new User { Id = Guid.NewGuid(), Email = "user@example.com" }; + var organization = new Organization { Id = orgId, Name = "Test Org" }; + var orgUser = new OrganizationUser + { + OrganizationId = orgId, + UserId = user.Id, + Status = OrganizationUserStatusType.Accepted, + Type = OrganizationUserType.User + }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, user.Email!); + var authService = SetupHttpContextWithAuth(sutProvider, authResult); + + ConfigureSsoAndUser( + sutProvider, + orgId, + providerUserId, + user, + organization, + orgUser); + + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(true); + sutProvider.GetDependency() + .GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + // Act + var result = await sutProvider.Sut.ExternalCallback(); + + // Assert + var redirect = Assert.IsType(result); + Assert.Equal("~/", redirect.Url); + + await authService.Received().SignInAsync( + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()); + + await authService.Received().SignOutAsync( + Arg.Any(), + AuthenticationSchemes.BitwardenExternalCookieAuthenticationScheme, + Arg.Any()); + } + + /// + /// PM-24579: Temporary test, remove with feature flag. + /// + [Theory, BitAutoData] + public async Task ExternalCallback_PreventNonCompliantFalse_SkipsOrgLookupAndSignsIn( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-flag-off"; + var user = new User { Id = Guid.NewGuid(), Email = "flagoff@example.com" }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, user.Email!); + var authService = SetupHttpContextWithAuth(sutProvider, authResult); + + ConfigureSsoAndUser( + sutProvider, + orgId, + providerUserId, + user); + + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(false); + sutProvider.GetDependency() + .GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + // Act + var result = await sutProvider.Sut.ExternalCallback(); + + // Assert + var redirect = Assert.IsType(result); + Assert.Equal("~/", redirect.Url); + + await authService.Received().SignInAsync( + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()); + + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .GetByOrganizationAsync(Guid.Empty, Guid.Empty); + } + + /// + /// PM-24579: Permanent test, remove the True in PreventNonCompliantTrue and remove the configure for the feature + /// flag. + /// + [Theory, BitAutoData] + public async Task ExternalCallback_PreventNonCompliantTrue_ExistingSsoLinkedAccepted_MeasureLookups( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-measure-existing"; + var user = new User { Id = Guid.NewGuid(), Email = "existing@example.com" }; + var organization = new Organization { Id = orgId, Name = "Org" }; + var orgUser = new OrganizationUser + { + OrganizationId = orgId, + UserId = user.Id, + Status = OrganizationUserStatusType.Accepted, + Type = OrganizationUserType.User + }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, user.Email); + SetupHttpContextWithAuth(sutProvider, authResult); + + ConfigureSsoAndUser( + sutProvider, + orgId, + providerUserId, + user, + organization, + orgUser); + + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(true); + sutProvider.GetDependency() + .GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + // Act + try + { + _ = await sutProvider.Sut.ExternalCallback(); + } + catch + { + // ignore for measurement only + } + + // Assert (measurement only - no asserts on counts) + var userRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + + var userGetBySso = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetBySsoUserAsync)); + var userGetByEmail = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetByEmailAsync)); + var orgGet = organizationRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationRepository.GetByIdAsync)); + var orgUserGetByOrg = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationAsync)) + + organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetManyByUserAsync)); + var orgUserGetByEmail = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationEmailAsync)); + + _output.WriteLine($"GetBySsoUserAsync: {userGetBySso}"); + _output.WriteLine($"GetByEmailAsync: {userGetByEmail}"); + _output.WriteLine($"GetByIdAsync (Org): {orgGet}"); + _output.WriteLine($"GetByOrganizationAsync (OrgUser): {orgUserGetByOrg}"); + _output.WriteLine($"GetByOrganizationEmailAsync (OrgUser): {orgUserGetByEmail}"); + + // Snapshot assertions + Assert.Equal(1, userGetBySso); + Assert.Equal(0, userGetByEmail); + Assert.Equal(1, orgGet); + Assert.Equal(1, orgUserGetByOrg); + Assert.Equal(0, orgUserGetByEmail); + } + + /// + /// PM-24579: Permanent test, remove the True in PreventNonCompliantTrue and remove the configure for the feature + /// flag. + /// + [Theory, BitAutoData] + public async Task ExternalCallback_PreventNonCompliantTrue_JitProvision_MeasureLookups( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-measure-jit"; + var email = "jit.measure@example.com"; + var organization = new Organization { Id = orgId, Name = "Org", Seats = null }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, email); + SetupHttpContextWithAuth(sutProvider, authResult); + + var ssoConfigRepository = sutProvider.GetDependency(); + var userRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + + var ssoConfig = new SsoConfig { OrganizationId = orgId, Enabled = true }; + var ssoData = new SsoConfigurationData(); + ssoConfig.SetData(ssoData); + ssoConfigRepository.GetByOrganizationIdAsync(orgId).Returns(ssoConfig); + + // JIT (no existing user or sso link) + userRepository.GetBySsoUserAsync(providerUserId, orgId).Returns((User?)null); + userRepository.GetByEmailAsync(email).Returns((User?)null); + organizationRepository.GetByIdAsync(orgId).Returns(organization); + organizationUserRepository.GetByOrganizationEmailAsync(orgId, email).Returns((OrganizationUser?)null); + + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(true); + sutProvider.GetDependency() + .GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + // Act + try + { + _ = await sutProvider.Sut.ExternalCallback(); + } + catch + { + // JIT path may throw due to Invited status under enforcement; ignore for measurement + } + + // Assert (measurement only - no asserts on counts) + var userGetBySso = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetBySsoUserAsync)); + var userGetByEmail = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetByEmailAsync)); + var orgGet = organizationRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationRepository.GetByIdAsync)); + var orgUserGetByOrg = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationAsync)); + var orgUserGetByEmail = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationEmailAsync)); + + _output.WriteLine($"GetBySsoUserAsync: {userGetBySso}"); + _output.WriteLine($"GetByEmailAsync: {userGetByEmail}"); + _output.WriteLine($"GetByIdAsync (Org): {orgGet}"); + _output.WriteLine($"GetByOrganizationAsync (OrgUser): {orgUserGetByOrg}"); + _output.WriteLine($"GetByOrganizationEmailAsync (OrgUser): {orgUserGetByEmail}"); + + // Snapshot assertions + Assert.Equal(1, userGetBySso); + Assert.Equal(1, userGetByEmail); + Assert.Equal(1, orgGet); + Assert.Equal(0, orgUserGetByOrg); + Assert.Equal(1, orgUserGetByEmail); + } + + /// + /// PM-24579: Permanent test, remove the True in PreventNonCompliantTrue and remove the configure for the feature + /// flag. + /// + /// This test will trigger both the GetByOrganizationAsync and the fallback attempt to get by email + /// GetByOrganizationEmailAsync. + /// + [Theory, BitAutoData] + public async Task ExternalCallback_PreventNonCompliantTrue_ExistingUser_NoOrgUser_MeasureLookups( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-measure-existing-no-orguser"; + var user = new User { Id = Guid.NewGuid(), Email = "existing2@example.com" }; + var organization = new Organization { Id = orgId, Name = "Org" }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, user.Email!); + SetupHttpContextWithAuth(sutProvider, authResult); + + ConfigureSsoAndUser( + sutProvider, + orgId, + providerUserId, + user, + organization, + orgUser: null); + + // Ensure orgUser lookup returns null + sutProvider.GetDependency() + .GetByOrganizationAsync(organization.Id, user.Id).Returns((OrganizationUser?)null); + + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(true); + sutProvider.GetDependency() + .GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + // Act + try + { + _ = await sutProvider.Sut.ExternalCallback(); + } + catch + { + // ignore for measurement only + } + + // Assert (measurement only - no asserts on counts) + var userRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + + var userGetBySso = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetBySsoUserAsync)); + var userGetByEmail = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetByEmailAsync)); + var orgGet = organizationRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationRepository.GetByIdAsync)); + var orgUserGetByOrg = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationAsync)) + + organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetManyByUserAsync)); + var orgUserGetByEmail = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationEmailAsync)); + + _output.WriteLine($"GetBySsoUserAsync: {userGetBySso}"); + _output.WriteLine($"GetByEmailAsync: {userGetByEmail}"); + _output.WriteLine($"GetByIdAsync (Org): {orgGet}"); + _output.WriteLine($"GetByOrganizationAsync (OrgUser): {orgUserGetByOrg}"); + _output.WriteLine($"GetByOrganizationEmailAsync (OrgUser): {orgUserGetByEmail}"); + + // Snapshot assertions + Assert.Equal(1, userGetBySso); + Assert.Equal(0, userGetByEmail); + Assert.Equal(1, orgGet); + Assert.Equal(1, orgUserGetByOrg); + Assert.Equal(1, orgUserGetByEmail); + } + + /// + /// PM-24579: Temporary test, remove with feature flag. + /// + [Theory, BitAutoData] + public async Task ExternalCallback_PreventNonCompliantFalse_ExistingSsoLinkedAccepted_MeasureLookups( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-measure-existing-flagoff"; + var user = new User { Id = Guid.NewGuid(), Email = "existing.flagoff@example.com" }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, user.Email!); + SetupHttpContextWithAuth(sutProvider, authResult); + + var ssoConfig = new SsoConfig { OrganizationId = orgId, Enabled = true }; + var ssoData = new SsoConfigurationData(); + ssoConfig.SetData(ssoData); + sutProvider.GetDependency().GetByOrganizationIdAsync(orgId).Returns(ssoConfig); + sutProvider.GetDependency().GetBySsoUserAsync(providerUserId, orgId).Returns(user); + + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(false); + sutProvider.GetDependency() + .GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + // Act + try { _ = await sutProvider.Sut.ExternalCallback(); } catch { } + + // Assert (measurement) + var userRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + + var userGetBySso = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetBySsoUserAsync)); + var userGetByEmail = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetByEmailAsync)); + var orgGet = organizationRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationRepository.GetByIdAsync)); + var orgUserGetByOrg = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationAsync)); + var orgUserGetByEmail = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationEmailAsync)); + + _output.WriteLine($"[flag off] GetBySsoUserAsync: {userGetBySso}"); + _output.WriteLine($"[flag off] GetByEmailAsync: {userGetByEmail}"); + _output.WriteLine($"[flag off] GetByIdAsync (Org): {orgGet}"); + _output.WriteLine($"[flag off] GetByOrganizationAsync (OrgUser): {orgUserGetByOrg}"); + _output.WriteLine($"[flag off] GetByOrganizationEmailAsync (OrgUser): {orgUserGetByEmail}"); + } + + /// + /// PM-24579: Temporary test, remove with feature flag. + /// + [Theory, BitAutoData] + public async Task ExternalCallback_PreventNonCompliantFalse_ExistingUser_NoOrgUser_MeasureLookups( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-measure-existing-no-orguser-flagoff"; + var user = new User { Id = Guid.NewGuid(), Email = "existing2.flagoff@example.com" }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, user.Email!); + SetupHttpContextWithAuth(sutProvider, authResult); + + var ssoConfig = new SsoConfig { OrganizationId = orgId, Enabled = true }; + var ssoData = new SsoConfigurationData(); + ssoConfig.SetData(ssoData); + sutProvider.GetDependency().GetByOrganizationIdAsync(orgId).Returns(ssoConfig); + sutProvider.GetDependency().GetBySsoUserAsync(providerUserId, orgId).Returns(user); + + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(false); + sutProvider.GetDependency() + .GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + // Act + try { _ = await sutProvider.Sut.ExternalCallback(); } catch { } + + // Assert (measurement) + var userRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + + var userGetBySso = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetBySsoUserAsync)); + var userGetByEmail = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetByEmailAsync)); + var orgGet = organizationRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationRepository.GetByIdAsync)); + var orgUserGetByOrg = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationAsync)); + var orgUserGetByEmail = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationEmailAsync)); + + _output.WriteLine($"[flag off] GetBySsoUserAsync: {userGetBySso}"); + _output.WriteLine($"[flag off] GetByEmailAsync: {userGetByEmail}"); + _output.WriteLine($"[flag off] GetByIdAsync (Org): {orgGet}"); + _output.WriteLine($"[flag off] GetByOrganizationAsync (OrgUser): {orgUserGetByOrg}"); + _output.WriteLine($"[flag off] GetByOrganizationEmailAsync (OrgUser): {orgUserGetByEmail}"); + } + + /// + /// PM-24579: Temporary test, remove with feature flag. + /// + [Theory, BitAutoData] + public async Task ExternalCallback_PreventNonCompliantFalse_JitProvision_MeasureLookups( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "ext-measure-jit-flagoff"; + var email = "jit.flagoff@example.com"; + var organization = new Organization { Id = orgId, Name = "Org", Seats = null }; + + var authResult = BuildSuccessfulExternalAuth(orgId, providerUserId, email); + SetupHttpContextWithAuth(sutProvider, authResult); + + var ssoConfig = new SsoConfig { OrganizationId = orgId, Enabled = true }; + var ssoData = new SsoConfigurationData(); + ssoConfig.SetData(ssoData); + sutProvider.GetDependency().GetByOrganizationIdAsync(orgId).Returns(ssoConfig); + + // JIT (no existing user or sso link) + sutProvider.GetDependency().GetBySsoUserAsync(providerUserId, orgId).Returns((User?)null); + sutProvider.GetDependency().GetByEmailAsync(email).Returns((User?)null); + sutProvider.GetDependency().GetByIdAsync(orgId).Returns(organization); + sutProvider.GetDependency().GetByOrganizationEmailAsync(orgId, email).Returns((OrganizationUser?)null); + + sutProvider.GetDependency().IsEnabled(Arg.Any()).Returns(false); + sutProvider.GetDependency() + .GetAuthorizationContextAsync("~/").Returns((AuthorizationRequest?)null); + + // Act + try { _ = await sutProvider.Sut.ExternalCallback(); } catch { } + + // Assert (measurement) + var userRepository = sutProvider.GetDependency(); + var organizationRepository = sutProvider.GetDependency(); + var organizationUserRepository = sutProvider.GetDependency(); + + var userGetBySso = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetBySsoUserAsync)); + var userGetByEmail = userRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IUserRepository.GetByEmailAsync)); + var orgGet = organizationRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationRepository.GetByIdAsync)); + var orgUserGetByOrg = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationAsync)); + var orgUserGetByEmail = organizationUserRepository.ReceivedCalls().Count(c => c.GetMethodInfo().Name == nameof(IOrganizationUserRepository.GetByOrganizationEmailAsync)); + + _output.WriteLine($"[flag off] GetBySsoUserAsync: {userGetBySso}"); + _output.WriteLine($"[flag off] GetByEmailAsync: {userGetByEmail}"); + _output.WriteLine($"[flag off] GetByIdAsync (Org): {orgGet}"); + _output.WriteLine($"[flag off] GetByOrganizationAsync (OrgUser): {orgUserGetByOrg}"); + _output.WriteLine($"[flag off] GetByOrganizationEmailAsync (OrgUser): {orgUserGetByEmail}"); + } + + [Theory, BitAutoData] + public async Task CreateUserAndOrgUserConditionallyAsync_WithExistingAcceptedUser_CreatesSsoLinkAndReturnsUser( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "provider-user-id"; + var email = "user@example.com"; + var existingUser = new User { Id = Guid.NewGuid(), Email = email }; + var organization = new Organization { Id = orgId, Name = "Org" }; + var orgUser = new OrganizationUser + { + OrganizationId = orgId, + UserId = existingUser.Id, + Status = OrganizationUserStatusType.Accepted, + Type = OrganizationUserType.User + }; + + // Arrange repository expectations for the flow + sutProvider.GetDependency().GetByEmailAsync(email).Returns(existingUser); + sutProvider.GetDependency().GetByIdAsync(orgId).Returns(organization); + sutProvider.GetDependency().GetManyByUserAsync(existingUser.Id) + .Returns(new List { orgUser }); + sutProvider.GetDependency().GetByOrganizationEmailAsync(orgId, email).Returns(orgUser); + + // No existing SSO link so first SSO login event is logged + sutProvider.GetDependency().GetByUserIdOrganizationIdAsync(orgId, existingUser.Id).Returns((SsoUser?)null); + + var claims = new[] + { + new Claim(JwtClaimTypes.Email, email), + new Claim(JwtClaimTypes.Name, "Jit User") + } as IEnumerable; + var config = new SsoConfigurationData(); + + var method = typeof(AccountController).GetMethod( + "CreateUserAndOrgUserConditionallyAsync", + BindingFlags.Instance | BindingFlags.NonPublic); + Assert.NotNull(method); + + // Act + var task = (Task<(User user, Organization organization, OrganizationUser orgUser)>)method.Invoke(sutProvider.Sut, new object[] + { + orgId.ToString(), + providerUserId, + claims, + null!, + config + })!; + + var returned = await task; + + // Assert + Assert.Equal(existingUser.Id, returned.user.Id); + + await sutProvider.GetDependency().Received().CreateAsync(Arg.Is(s => + s.OrganizationId == orgId && s.UserId == existingUser.Id && s.ExternalId == providerUserId)); + + await sutProvider.GetDependency().Received().LogOrganizationUserEventAsync( + orgUser, + EventType.OrganizationUser_FirstSsoLogin); + } + + [Theory, BitAutoData] + public async Task CreateUserAndOrgUserConditionallyAsync_WithExistingInvitedUser_ThrowsAcceptInviteBeforeUsingSSO( + SutProvider sutProvider) + { + // Arrange + var orgId = Guid.NewGuid(); + var providerUserId = "provider-user-id"; + var email = "user@example.com"; + var existingUser = new User { Id = Guid.NewGuid(), Email = email, UsesKeyConnector = false }; + var organization = new Organization { Id = orgId, Name = "Org" }; + var orgUser = new OrganizationUser + { + OrganizationId = orgId, + UserId = existingUser.Id, + Status = OrganizationUserStatusType.Invited, + Type = OrganizationUserType.User + }; + + // i18n returns the key so we can assert on message contents + sutProvider.GetDependency() + .T(Arg.Any(), Arg.Any()) + .Returns(ci => (string)ci[0]!); + + // Arrange repository expectations for the flow + sutProvider.GetDependency().GetByEmailAsync(email).Returns(existingUser); + sutProvider.GetDependency().GetByIdAsync(orgId).Returns(organization); + sutProvider.GetDependency().GetManyByUserAsync(existingUser.Id) + .Returns(new List { orgUser }); + + var claims = new[] + { + new Claim(JwtClaimTypes.Email, email), + new Claim(JwtClaimTypes.Name, "Invited User") + } as IEnumerable; + var config = new SsoConfigurationData(); + + var method = typeof(AccountController).GetMethod( + "CreateUserAndOrgUserConditionallyAsync", + BindingFlags.Instance | BindingFlags.NonPublic); + Assert.NotNull(method); + + // Act + Assert + var task = (Task<(User user, Organization organization, OrganizationUser orgUser)>)method.Invoke(sutProvider.Sut, new object[] + { + orgId.ToString(), + providerUserId, + claims, + null!, + config + })!; + + var ex = await Assert.ThrowsAsync(async () => await task); + Assert.Equal("AcceptInviteBeforeUsingSSO", ex.Message); + } + + /// + /// PM-24579: Temporary comparison test to ensure the feature flag ON does not + /// regress lookup counts compared to OFF. When removing the flag, delete this + /// comparison test and keep the specific scenario snapshot tests if desired. + /// + [Theory, BitAutoData] + public async Task ExternalCallback_Measurements_FlagOnVsOff_Comparisons( + SutProvider sutProvider) + { + // Arrange + var scenarios = new[] + { + MeasurementScenario.ExistingSsoLinkedAccepted, + MeasurementScenario.ExistingUserNoOrgUser, + MeasurementScenario.JitProvision + }; + + foreach (var scenario in scenarios) + { + // Act + var onCounts = await MeasureCountsForScenarioAsync(sutProvider, scenario, preventNonCompliant: true); + var offCounts = await MeasureCountsForScenarioAsync(sutProvider, scenario, preventNonCompliant: false); + + // Assert: off should not exceed on in any measured lookup type + Assert.True(offCounts.UserGetBySso <= onCounts.UserGetBySso, $"{scenario}: off UserGetBySso={offCounts.UserGetBySso} > on {onCounts.UserGetBySso}"); + Assert.True(offCounts.UserGetByEmail <= onCounts.UserGetByEmail, $"{scenario}: off UserGetByEmail={offCounts.UserGetByEmail} > on {onCounts.UserGetByEmail}"); + Assert.True(offCounts.OrgGetById <= onCounts.OrgGetById, $"{scenario}: off OrgGetById={offCounts.OrgGetById} > on {onCounts.OrgGetById}"); + Assert.True(offCounts.OrgUserGetByOrg <= onCounts.OrgUserGetByOrg, $"{scenario}: off OrgUserGetByOrg={offCounts.OrgUserGetByOrg} > on {onCounts.OrgUserGetByOrg}"); + Assert.True(offCounts.OrgUserGetByEmail <= onCounts.OrgUserGetByEmail, $"{scenario}: off OrgUserGetByEmail={offCounts.OrgUserGetByEmail} > on {onCounts.OrgUserGetByEmail}"); + + _output.WriteLine($"Scenario={scenario} | ON: SSO={onCounts.UserGetBySso}, Email={onCounts.UserGetByEmail}, Org={onCounts.OrgGetById}, OrgUserByOrg={onCounts.OrgUserGetByOrg}, OrgUserByEmail={onCounts.OrgUserGetByEmail}"); + _output.WriteLine($"Scenario={scenario} | OFF: SSO={offCounts.UserGetBySso}, Email={offCounts.UserGetByEmail}, Org={offCounts.OrgGetById}, OrgUserByOrg={offCounts.OrgUserGetByOrg}, OrgUserByEmail={offCounts.OrgUserGetByEmail}"); + } + } +} diff --git a/bitwarden_license/test/SSO.Test/SSO.Test.csproj b/bitwarden_license/test/SSO.Test/SSO.Test.csproj new file mode 100644 index 0000000000..4b509c9a50 --- /dev/null +++ b/bitwarden_license/test/SSO.Test/SSO.Test.csproj @@ -0,0 +1,35 @@ + + + + net8.0 + enable + enable + + false + true + + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + + + + diff --git a/bitwarden_license/test/Scim.Test/Groups/PatchGroupCommandTests.cs b/bitwarden_license/test/Scim.Test/Groups/PatchGroupCommandTests.cs index 1b02e62970..8816885ea7 100644 --- a/bitwarden_license/test/Scim.Test/Groups/PatchGroupCommandTests.cs +++ b/bitwarden_license/test/Scim.Test/Groups/PatchGroupCommandTests.cs @@ -436,7 +436,7 @@ public class PatchGroupCommandTests await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().DeleteUserAsync(default, default); // Assert: logging - sutProvider.GetDependency>().ReceivedWithAnyArgs().LogWarning(default); + sutProvider.GetDependency>().ReceivedWithAnyArgs().LogWarning(""); } [Theory] diff --git a/dev/docker-compose.yml b/dev/docker-compose.yml index 0ee4aa53a9..c5e42cf9e3 100644 --- a/dev/docker-compose.yml +++ b/dev/docker-compose.yml @@ -53,6 +53,7 @@ services: - ./.data/postgres/log:/var/log/postgresql profiles: - postgres + - ef mysql: image: mysql:8.0 @@ -69,6 +70,7 @@ services: - mysql_dev_data:/var/lib/mysql profiles: - mysql + - ef mariadb: image: mariadb:10 @@ -76,13 +78,13 @@ services: - 4306:3306 environment: MARIADB_USER: maria - MARIADB_PASSWORD: ${MARIADB_ROOT_PASSWORD} MARIADB_DATABASE: vault_dev MARIADB_RANDOM_ROOT_PASSWORD: "true" volumes: - mariadb_dev_data:/var/lib/mysql profiles: - mariadb + - ef idp: image: kenchan0130/simplesamlphp:1.19.8 @@ -99,7 +101,7 @@ services: - idp rabbitmq: - image: rabbitmq:4.1.0-management + image: rabbitmq:4.1.3-management container_name: rabbitmq ports: - "5672:5672" @@ -153,5 +155,6 @@ volumes: mssql_dev_data: postgres_dev_data: mysql_dev_data: + mariadb_dev_data: rabbitmq_data: redis_data: diff --git a/dev/migrate.ps1 b/dev/migrate.ps1 index 287a2d18ee..26caa87efd 100755 --- a/dev/migrate.ps1 +++ b/dev/migrate.ps1 @@ -70,7 +70,7 @@ Foreach ($item in @( @($mysql, "MySQL", "MySqlMigrations", "mySql", 2), # MariaDB shares the MySQL connection string in the server config so they are mutually exclusive in that context. # However they can still be run independently for integration tests. - @($mariadb, "MariaDB", "MySqlMigrations", "mySql", 3) + @($mariadb, "MariaDB", "MySqlMigrations", "mySql", 4) )) { if (!$item[0] -and !$all) { continue diff --git a/dev/secrets.json.example b/dev/secrets.json.example index 7c91669b39..c6a16846e9 100644 --- a/dev/secrets.json.example +++ b/dev/secrets.json.example @@ -33,6 +33,8 @@ "id": "", "key": "" }, - "licenseDirectory": "" + "licenseDirectory": "", + "enableNewDeviceVerification": true, + "enableEmailVerification": true } } diff --git a/dev/servicebusemulator_config.json b/dev/servicebusemulator_config.json index 294efc1897..bb50c0b1ee 100644 --- a/dev/servicebusemulator_config.json +++ b/dev/servicebusemulator_config.json @@ -3,22 +3,6 @@ "Namespaces": [ { "Name": "sbemulatorns", - "Queues": [ - { - "Name": "queue.1", - "Properties": { - "DeadLetteringOnMessageExpiration": false, - "DefaultMessageTimeToLive": "PT1H", - "DuplicateDetectionHistoryTimeWindow": "PT20S", - "ForwardDeadLetteredMessagesTo": "", - "ForwardTo": "", - "LockDuration": "PT1M", - "MaxDeliveryCount": 3, - "RequiresDuplicateDetection": false, - "RequiresSession": false - } - } - ], "Topics": [ { "Name": "event-logging", @@ -37,6 +21,9 @@ }, { "Name": "events-datadog-subscription" + }, + { + "Name": "events-teams-subscription" } ] }, @@ -98,6 +85,20 @@ } } ] + }, + { + "Name": "integration-teams-subscription", + "Rules": [ + { + "Name": "teams-integration-filter", + "Properties": { + "FilterType": "Correlation", + "CorrelationFilter": { + "Label": "teams" + } + } + } + ] } ] } diff --git a/perf/MicroBenchmarks/MicroBenchmarks.csproj b/perf/MicroBenchmarks/MicroBenchmarks.csproj index 82c526a7d2..a13792b2d6 100644 --- a/perf/MicroBenchmarks/MicroBenchmarks.csproj +++ b/perf/MicroBenchmarks/MicroBenchmarks.csproj @@ -7,7 +7,7 @@ - + diff --git a/src/Admin/AdminConsole/Controllers/OrganizationsController.cs b/src/Admin/AdminConsole/Controllers/OrganizationsController.cs index 2417bf610d..0d992cb96a 100644 --- a/src/Admin/AdminConsole/Controllers/OrganizationsController.cs +++ b/src/Admin/AdminConsole/Controllers/OrganizationsController.cs @@ -472,6 +472,7 @@ public class OrganizationsController : Controller organization.UseRiskInsights = model.UseRiskInsights; organization.UseOrganizationDomains = model.UseOrganizationDomains; organization.UseAdminSponsoredFamilies = model.UseAdminSponsoredFamilies; + organization.UseAutomaticUserConfirmation = model.UseAutomaticUserConfirmation; //secrets organization.SmSeats = model.SmSeats; diff --git a/src/Admin/AdminConsole/Models/OrganizationEditModel.cs b/src/Admin/AdminConsole/Models/OrganizationEditModel.cs index b64af3135f..6059a003b6 100644 --- a/src/Admin/AdminConsole/Models/OrganizationEditModel.cs +++ b/src/Admin/AdminConsole/Models/OrganizationEditModel.cs @@ -106,6 +106,8 @@ public class OrganizationEditModel : OrganizationViewModel SmServiceAccounts = org.SmServiceAccounts; MaxAutoscaleSmServiceAccounts = org.MaxAutoscaleSmServiceAccounts; UseOrganizationDomains = org.UseOrganizationDomains; + UseAutomaticUserConfirmation = org.UseAutomaticUserConfirmation; + _plans = plans; } @@ -192,6 +194,8 @@ public class OrganizationEditModel : OrganizationViewModel [Display(Name = "Use Organization Domains")] public bool UseOrganizationDomains { get; set; } + [Display(Name = "Automatic User Confirmation")] + public bool UseAutomaticUserConfirmation { get; set; } /** * Creates a Plan[] object for use in Javascript * This is mapped manually below to provide some type safety in case the plan objects change @@ -231,6 +235,7 @@ public class OrganizationEditModel : OrganizationViewModel LegacyYear = p.LegacyYear, Disabled = p.Disabled, SupportsSecretsManager = p.SupportsSecretsManager, + AutomaticUserConfirmation = p.AutomaticUserConfirmation, PasswordManager = new { diff --git a/src/Admin/AdminConsole/Views/Shared/_OrganizationForm.cshtml b/src/Admin/AdminConsole/Views/Shared/_OrganizationForm.cshtml index 267264a38f..cb71c0fc78 100644 --- a/src/Admin/AdminConsole/Views/Shared/_OrganizationForm.cshtml +++ b/src/Admin/AdminConsole/Views/Shared/_OrganizationForm.cshtml @@ -152,11 +152,15 @@ - @if(FeatureService.IsEnabled(FeatureFlagKeys.PM17772_AdminInitiatedSponsorships)) +
+ + +
+ @if(FeatureService.IsEnabled(FeatureFlagKeys.AutomaticConfirmUsers)) {
- - + +
} diff --git a/src/Admin/Billing/Controllers/MigrateProvidersController.cs b/src/Admin/Billing/Controllers/MigrateProvidersController.cs deleted file mode 100644 index ef5ea2312e..0000000000 --- a/src/Admin/Billing/Controllers/MigrateProvidersController.cs +++ /dev/null @@ -1,83 +0,0 @@ -using Bit.Admin.Billing.Models; -using Bit.Admin.Enums; -using Bit.Admin.Utilities; -using Bit.Core.Billing.Providers.Migration.Models; -using Bit.Core.Billing.Providers.Migration.Services; -using Bit.Core.Utilities; -using Microsoft.AspNetCore.Authorization; -using Microsoft.AspNetCore.Mvc; - -namespace Bit.Admin.Billing.Controllers; - -[Authorize] -[Route("migrate-providers")] -[SelfHosted(NotSelfHostedOnly = true)] -public class MigrateProvidersController( - IProviderMigrator providerMigrator) : Controller -{ - [HttpGet] - [RequirePermission(Permission.Tools_MigrateProviders)] - public IActionResult Index() - { - return View(new MigrateProvidersRequestModel()); - } - - [HttpPost] - [RequirePermission(Permission.Tools_MigrateProviders)] - [ValidateAntiForgeryToken] - public async Task PostAsync(MigrateProvidersRequestModel request) - { - var providerIds = GetProviderIdsFromInput(request.ProviderIds); - - if (providerIds.Count == 0) - { - return RedirectToAction("Index"); - } - - foreach (var providerId in providerIds) - { - await providerMigrator.Migrate(providerId); - } - - return RedirectToAction("Results", new { ProviderIds = string.Join("\r\n", providerIds) }); - } - - [HttpGet("results")] - [RequirePermission(Permission.Tools_MigrateProviders)] - public async Task ResultsAsync(MigrateProvidersRequestModel request) - { - var providerIds = GetProviderIdsFromInput(request.ProviderIds); - - if (providerIds.Count == 0) - { - return View(Array.Empty()); - } - - var results = await Task.WhenAll(providerIds.Select(providerMigrator.GetResult)); - - return View(results); - } - - [HttpGet("results/{providerId:guid}")] - [RequirePermission(Permission.Tools_MigrateProviders)] - public async Task DetailsAsync([FromRoute] Guid providerId) - { - var result = await providerMigrator.GetResult(providerId); - - if (result == null) - { - return RedirectToAction("Index"); - } - - return View(result); - } - - private static List GetProviderIdsFromInput(string text) => !string.IsNullOrEmpty(text) - ? text.Split( - ["\r\n", "\r", "\n"], - StringSplitOptions.TrimEntries - ) - .Select(id => new Guid(id)) - .ToList() - : []; -} diff --git a/src/Admin/Billing/Models/MigrateProvidersRequestModel.cs b/src/Admin/Billing/Models/MigrateProvidersRequestModel.cs deleted file mode 100644 index 273f934eba..0000000000 --- a/src/Admin/Billing/Models/MigrateProvidersRequestModel.cs +++ /dev/null @@ -1,13 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.ComponentModel.DataAnnotations; - -namespace Bit.Admin.Billing.Models; - -public class MigrateProvidersRequestModel -{ - [Required] - [Display(Name = "Provider IDs")] - public string ProviderIds { get; set; } -} diff --git a/src/Admin/Billing/Views/MigrateProviders/Details.cshtml b/src/Admin/Billing/Views/MigrateProviders/Details.cshtml deleted file mode 100644 index 6ee0344057..0000000000 --- a/src/Admin/Billing/Views/MigrateProviders/Details.cshtml +++ /dev/null @@ -1,39 +0,0 @@ -@using System.Text.Json -@model Bit.Core.Billing.Providers.Migration.Models.ProviderMigrationResult -@{ - ViewData["Title"] = "Results"; -} - -

Migrate Providers

-

Migration Details: @Model.ProviderName

-
-
Id
-
@Model.ProviderId
- -
Result
-
@Model.Result
-
-

Client Organizations

-
- - - - - - - - - - - @foreach (var clientResult in Model.Clients) - { - - - - - - - } - -
IDNameResultPrevious State
@clientResult.OrganizationId@clientResult.OrganizationName@clientResult.Result
@Html.Raw(JsonSerializer.Serialize(clientResult.PreviousState))
-
diff --git a/src/Admin/Billing/Views/MigrateProviders/Index.cshtml b/src/Admin/Billing/Views/MigrateProviders/Index.cshtml deleted file mode 100644 index 0aed94c25d..0000000000 --- a/src/Admin/Billing/Views/MigrateProviders/Index.cshtml +++ /dev/null @@ -1,46 +0,0 @@ -@model Bit.Admin.Billing.Models.MigrateProvidersRequestModel; -@{ - ViewData["Title"] = "Migrate Providers"; -} - -

Migrate Providers

-

Bulk Consolidated Billing Migration Tool

-
-

- This tool allows you to provide a list of IDs for Providers that you would like to migrate to Consolidated Billing. - Because of the expensive nature of the operation, you can only migrate 10 Providers at a time. -

-

- Updates made through this tool are irreversible without manual intervention. -

-

Example Input (Please enter each Provider ID separated by a new line):

-
-
-
f513affc-2290-4336-879e-21ec3ecf3e78
-f7a5cb0d-4b74-445c-8d8c-232d1d32bbe2
-bf82d3cf-0e21-4f39-b81b-ef52b2fc6a3a
-174e82fc-70c3-448d-9fe7-00bad2a3ab00
-22a4bbbf-58e3-4e4c-a86a-a0d7caf4ff14
-
-
-
-
-
- - -
-
- -
-
-
-
-
- - -
-
- -
-
-
diff --git a/src/Admin/Billing/Views/MigrateProviders/Results.cshtml b/src/Admin/Billing/Views/MigrateProviders/Results.cshtml deleted file mode 100644 index 94db08db3d..0000000000 --- a/src/Admin/Billing/Views/MigrateProviders/Results.cshtml +++ /dev/null @@ -1,28 +0,0 @@ -@model Bit.Core.Billing.Providers.Migration.Models.ProviderMigrationResult[] -@{ - ViewData["Title"] = "Results"; -} - -

Migrate Providers

-

Results

-
- - - - - - - - - - @foreach (var result in Model) - { - - - - - - } - -
IDNameResult
@result.ProviderId@result.ProviderName@result.Result
-
diff --git a/src/Admin/Controllers/HomeController.cs b/src/Admin/Controllers/HomeController.cs index debe5979f5..5b36032ec9 100644 --- a/src/Admin/Controllers/HomeController.cs +++ b/src/Admin/Controllers/HomeController.cs @@ -61,7 +61,7 @@ public class HomeController : Controller } catch (HttpRequestException e) { - _logger.LogError(e, $"Error encountered while sending GET request to {requestUri}"); + _logger.LogError(e, "Error encountered while sending GET request to {RequestUri}", requestUri); return new JsonResult("Unable to fetch latest version") { StatusCode = StatusCodes.Status500InternalServerError }; } @@ -83,7 +83,7 @@ public class HomeController : Controller } catch (HttpRequestException e) { - _logger.LogError(e, $"Error encountered while sending GET request to {requestUri}"); + _logger.LogError(e, "Error encountered while sending GET request to {RequestUri}", requestUri); return new JsonResult("Unable to fetch installed version") { StatusCode = StatusCodes.Status500InternalServerError }; } diff --git a/src/Admin/Controllers/ToolsController.cs b/src/Admin/Controllers/ToolsController.cs index b754b1f968..46dafd65e7 100644 --- a/src/Admin/Controllers/ToolsController.cs +++ b/src/Admin/Controllers/ToolsController.cs @@ -1,7 +1,6 @@ // FIXME: Update this file to be null safe and then delete the line below #nullable disable -using System.Text; using System.Text.Json; using Bit.Admin.Enums; using Bit.Admin.Models; @@ -10,7 +9,6 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Organizations.Queries; using Bit.Core.Entities; -using Bit.Core.Models.BitStripe; using Bit.Core.Platform.Installations; using Bit.Core.Repositories; using Bit.Core.Services; @@ -33,7 +31,6 @@ public class ToolsController : Controller private readonly IInstallationRepository _installationRepository; private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IProviderUserRepository _providerUserRepository; - private readonly IPaymentService _paymentService; private readonly IStripeAdapter _stripeAdapter; private readonly IWebHostEnvironment _environment; @@ -46,7 +43,6 @@ public class ToolsController : Controller IInstallationRepository installationRepository, IOrganizationUserRepository organizationUserRepository, IProviderUserRepository providerUserRepository, - IPaymentService paymentService, IStripeAdapter stripeAdapter, IWebHostEnvironment environment) { @@ -58,7 +54,6 @@ public class ToolsController : Controller _installationRepository = installationRepository; _organizationUserRepository = organizationUserRepository; _providerUserRepository = providerUserRepository; - _paymentService = paymentService; _stripeAdapter = stripeAdapter; _environment = environment; } @@ -341,138 +336,4 @@ public class ToolsController : Controller throw new Exception("No license to generate."); } } - - [RequirePermission(Permission.Tools_ManageStripeSubscriptions)] - public async Task StripeSubscriptions(StripeSubscriptionListOptions options) - { - options = options ?? new StripeSubscriptionListOptions(); - options.Limit = 10; - options.Expand = new List() { "data.customer", "data.latest_invoice" }; - options.SelectAll = false; - - var subscriptions = await _stripeAdapter.SubscriptionListAsync(options); - - options.StartingAfter = subscriptions.LastOrDefault()?.Id; - options.EndingBefore = await StripeSubscriptionsGetHasPreviousPage(subscriptions, options) ? - subscriptions.FirstOrDefault()?.Id : - null; - - var isProduction = _environment.IsProduction(); - var model = new StripeSubscriptionsModel() - { - Items = subscriptions.Select(s => new StripeSubscriptionRowModel(s)).ToList(), - Prices = (await _stripeAdapter.PriceListAsync(new Stripe.PriceListOptions() { Limit = 100 })).Data, - TestClocks = isProduction ? new List() : await _stripeAdapter.TestClockListAsync(), - Filter = options - }; - return View(model); - } - - [HttpPost] - [RequirePermission(Permission.Tools_ManageStripeSubscriptions)] - public async Task StripeSubscriptions([FromForm] StripeSubscriptionsModel model) - { - if (!ModelState.IsValid) - { - var isProduction = _environment.IsProduction(); - model.Prices = (await _stripeAdapter.PriceListAsync(new Stripe.PriceListOptions() { Limit = 100 })).Data; - model.TestClocks = isProduction ? new List() : await _stripeAdapter.TestClockListAsync(); - return View(model); - } - - if (model.Action == StripeSubscriptionsAction.Export || model.Action == StripeSubscriptionsAction.BulkCancel) - { - var subscriptions = model.Filter.SelectAll ? - await _stripeAdapter.SubscriptionListAsync(model.Filter) : - model.Items.Where(x => x.Selected).Select(x => x.Subscription); - - if (model.Action == StripeSubscriptionsAction.Export) - { - return StripeSubscriptionsExport(subscriptions); - } - - if (model.Action == StripeSubscriptionsAction.BulkCancel) - { - await StripeSubscriptionsCancel(subscriptions); - } - } - else - { - if (model.Action == StripeSubscriptionsAction.PreviousPage || model.Action == StripeSubscriptionsAction.Search) - { - model.Filter.StartingAfter = null; - } - - if (model.Action == StripeSubscriptionsAction.NextPage || model.Action == StripeSubscriptionsAction.Search) - { - if (!string.IsNullOrEmpty(model.Filter.StartingAfter)) - { - var subscription = await _stripeAdapter.SubscriptionGetAsync(model.Filter.StartingAfter); - if (subscription.Status == "canceled") - { - model.Filter.StartingAfter = null; - } - } - model.Filter.EndingBefore = null; - } - } - - - return RedirectToAction("StripeSubscriptions", model.Filter); - } - - // This requires a redundant API call to Stripe because of the way they handle pagination. - // The StartingBefore value has to be inferred from the list we get, and isn't supplied by Stripe. - private async Task StripeSubscriptionsGetHasPreviousPage(List subscriptions, StripeSubscriptionListOptions options) - { - var hasPreviousPage = false; - if (subscriptions.FirstOrDefault()?.Id != null) - { - var previousPageSearchOptions = new StripeSubscriptionListOptions() - { - EndingBefore = subscriptions.FirstOrDefault().Id, - Limit = 1, - Status = options.Status, - CurrentPeriodEndDate = options.CurrentPeriodEndDate, - CurrentPeriodEndRange = options.CurrentPeriodEndRange, - Price = options.Price - }; - hasPreviousPage = (await _stripeAdapter.SubscriptionListAsync(previousPageSearchOptions)).Count > 0; - } - return hasPreviousPage; - } - - private async Task StripeSubscriptionsCancel(IEnumerable subscriptions) - { - foreach (var s in subscriptions) - { - await _stripeAdapter.SubscriptionCancelAsync(s.Id); - if (s.LatestInvoice?.Status == "open") - { - await _stripeAdapter.InvoiceVoidInvoiceAsync(s.LatestInvoiceId); - } - } - } - - private FileResult StripeSubscriptionsExport(IEnumerable subscriptions) - { - var fieldsToExport = subscriptions.Select(s => new - { - StripeId = s.Id, - CustomerEmail = s.Customer?.Email, - SubscriptionStatus = s.Status, - InvoiceDueDate = s.CurrentPeriodEnd, - SubscriptionProducts = s.Items?.Data.Select(p => p.Plan.Id) - }); - - var options = new JsonSerializerOptions - { - PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - WriteIndented = true - }; - - var result = System.Text.Json.JsonSerializer.Serialize(fieldsToExport, options); - var bytes = Encoding.UTF8.GetBytes(result); - return File(bytes, "application/json", "StripeSubscriptionsSearch.json"); - } } diff --git a/src/Admin/Enums/Permissions.cs b/src/Admin/Enums/Permissions.cs index 14b255b2b6..34d975226e 100644 --- a/src/Admin/Enums/Permissions.cs +++ b/src/Admin/Enums/Permissions.cs @@ -52,8 +52,6 @@ public enum Permission Tools_PromoteProviderServiceUser, Tools_GenerateLicenseFile, Tools_ManageTaxRates, - Tools_ManageStripeSubscriptions, Tools_CreateEditTransaction, - Tools_ProcessStripeEvents, - Tools_MigrateProviders + Tools_ProcessStripeEvents } diff --git a/src/Admin/Jobs/AliveJob.cs b/src/Admin/Jobs/AliveJob.cs index b97d597e58..d62f4cc2cc 100644 --- a/src/Admin/Jobs/AliveJob.cs +++ b/src/Admin/Jobs/AliveJob.cs @@ -22,7 +22,7 @@ public class AliveJob : BaseJob { _logger.LogInformation(Constants.BypassFiltersEventId, "Execute job task: Keep alive"); var response = await _httpClient.GetAsync(_globalSettings.BaseServiceUri.Admin); - _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: Keep alive, " + + _logger.LogInformation(Constants.BypassFiltersEventId, "Finished job task: Keep alive, {StatusCode}", response.StatusCode); } } diff --git a/src/Admin/Models/StripeSubscriptionsModel.cs b/src/Admin/Models/StripeSubscriptionsModel.cs deleted file mode 100644 index 36e1f099e1..0000000000 --- a/src/Admin/Models/StripeSubscriptionsModel.cs +++ /dev/null @@ -1,45 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.ComponentModel.DataAnnotations; -using Bit.Core.Models.BitStripe; - -namespace Bit.Admin.Models; - -public class StripeSubscriptionRowModel -{ - public Stripe.Subscription Subscription { get; set; } - public bool Selected { get; set; } - - public StripeSubscriptionRowModel() { } - public StripeSubscriptionRowModel(Stripe.Subscription subscription) - { - Subscription = subscription; - } -} - -public enum StripeSubscriptionsAction -{ - Search, - PreviousPage, - NextPage, - Export, - BulkCancel -} - -public class StripeSubscriptionsModel : IValidatableObject -{ - public List Items { get; set; } - public StripeSubscriptionsAction Action { get; set; } = StripeSubscriptionsAction.Search; - public string Message { get; set; } - public List Prices { get; set; } - public List TestClocks { get; set; } - public StripeSubscriptionListOptions Filter { get; set; } = new StripeSubscriptionListOptions(); - public IEnumerable Validate(ValidationContext validationContext) - { - if (Action == StripeSubscriptionsAction.BulkCancel && Filter.Status != "unpaid") - { - yield return new ValidationResult("Bulk cancel is currently only supported for unpaid subscriptions"); - } - } -} diff --git a/src/Admin/Startup.cs b/src/Admin/Startup.cs index 5b34e13f6c..5ecbdc899c 100644 --- a/src/Admin/Startup.cs +++ b/src/Admin/Startup.cs @@ -10,7 +10,6 @@ using Microsoft.AspNetCore.Mvc.Razor; using Microsoft.Extensions.DependencyInjection.Extensions; using Bit.Admin.Services; using Bit.Core.Billing.Extensions; -using Bit.Core.Billing.Providers.Migration; #if !OSS using Bit.Commercial.Core.Utilities; @@ -92,7 +91,6 @@ public class Startup services.AddDistributedCache(globalSettings); services.AddBillingOperations(); services.AddHttpClient(); - services.AddProviderMigration(); #if OSS services.AddOosServices(); diff --git a/src/Admin/Utilities/RolePermissionMapping.cs b/src/Admin/Utilities/RolePermissionMapping.cs index b60cf895a1..6dddc4ffeb 100644 --- a/src/Admin/Utilities/RolePermissionMapping.cs +++ b/src/Admin/Utilities/RolePermissionMapping.cs @@ -52,8 +52,7 @@ public static class RolePermissionMapping Permission.Tools_PromoteAdmin, Permission.Tools_PromoteProviderServiceUser, Permission.Tools_GenerateLicenseFile, - Permission.Tools_ManageTaxRates, - Permission.Tools_ManageStripeSubscriptions + Permission.Tools_ManageTaxRates } }, { "admin", new List @@ -105,7 +104,6 @@ public static class RolePermissionMapping Permission.Tools_PromoteProviderServiceUser, Permission.Tools_GenerateLicenseFile, Permission.Tools_ManageTaxRates, - Permission.Tools_ManageStripeSubscriptions, Permission.Tools_CreateEditTransaction } }, @@ -180,10 +178,8 @@ public static class RolePermissionMapping Permission.Tools_ChargeBrainTreeCustomer, Permission.Tools_GenerateLicenseFile, Permission.Tools_ManageTaxRates, - Permission.Tools_ManageStripeSubscriptions, Permission.Tools_CreateEditTransaction, - Permission.Tools_ProcessStripeEvents, - Permission.Tools_MigrateProviders + Permission.Tools_ProcessStripeEvents } }, { "sales", new List diff --git a/src/Admin/Views/Shared/_Layout.cshtml b/src/Admin/Views/Shared/_Layout.cshtml index 1661a8bbc3..c13be428b4 100644 --- a/src/Admin/Views/Shared/_Layout.cshtml +++ b/src/Admin/Views/Shared/_Layout.cshtml @@ -13,12 +13,10 @@ var canPromoteAdmin = AccessControlService.UserHasPermission(Permission.Tools_PromoteAdmin); var canPromoteProviderServiceUser = AccessControlService.UserHasPermission(Permission.Tools_PromoteProviderServiceUser); var canGenerateLicense = AccessControlService.UserHasPermission(Permission.Tools_GenerateLicenseFile); - var canManageStripeSubscriptions = AccessControlService.UserHasPermission(Permission.Tools_ManageStripeSubscriptions); var canProcessStripeEvents = AccessControlService.UserHasPermission(Permission.Tools_ProcessStripeEvents); - var canMigrateProviders = AccessControlService.UserHasPermission(Permission.Tools_MigrateProviders); var canViewTools = canChargeBraintree || canCreateTransaction || canPromoteAdmin || canPromoteProviderServiceUser || - canGenerateLicense || canManageStripeSubscriptions; + canGenerateLicense; } @@ -102,12 +100,6 @@ Generate License - } - @if (canManageStripeSubscriptions) - { - - Manage Stripe Subscriptions - } @if (canProcessStripeEvents) { @@ -115,12 +107,6 @@ Process Stripe Events } - @if (canMigrateProviders) - { - - Migrate Providers - - } } diff --git a/src/Admin/Views/Tools/StripeSubscriptions.cshtml b/src/Admin/Views/Tools/StripeSubscriptions.cshtml deleted file mode 100644 index d8c168b3b0..0000000000 --- a/src/Admin/Views/Tools/StripeSubscriptions.cshtml +++ /dev/null @@ -1,277 +0,0 @@ -@model StripeSubscriptionsModel -@{ - ViewData["Title"] = "Stripe Subscriptions"; -} - -@section Scripts { - -} - -

Manage Stripe Subscriptions

-@if (!string.IsNullOrWhiteSpace(Model.Message)) -{ -
-} -
-
-
-
- - -
-
- -
-
-
- - -
-
- - -
-
- @{ - var date = @Model.Filter.CurrentPeriodEndDate.HasValue ? @Model.Filter.CurrentPeriodEndDate.Value.ToString("yyyy-MM-dd") : string.Empty; - } - -
-
-
- - -
-
- - -
-
- -
-
-
- -
-
- All @Model.Items.Count subscriptions on this page are selected.
- - - All subscriptions for this search are selected. - - -
-
-
- - - - - - - - - - - - - @if (!Model.Items.Any()) - { - - - - } - else - { - @for (var i = 0; i < Model.Items.Count; i++) - { - - - - - - - - - } - } - -
-
- -
-
IdCustomer EmailStatusProduct TierCurrent Period End
No results to list.
- - @{ - var i0 = i; - } - - - - - - - - @for (var j = 0; j < Model.Items[i].Subscription.Items.Data.Count; j++) - { - var i1 = i; - var j1 = j; - - } -
- - @{ - var i2 = i; - } - -
-
- @Model.Items[i].Subscription.Id - - @Model.Items[i].Subscription.Customer?.Email - - @Model.Items[i].Subscription.Status - - @string.Join(",", Model.Items[i].Subscription.Items.Data.Select(product => product.Plan.Id).ToArray()) - - @Model.Items[i].Subscription.CurrentPeriodEnd.ToShortDateString() -
-
- -
diff --git a/src/Admin/package-lock.json b/src/Admin/package-lock.json index 2e3a335598..6e0f78e1e6 100644 --- a/src/Admin/package-lock.json +++ b/src/Admin/package-lock.json @@ -18,9 +18,9 @@ "css-loader": "7.1.2", "expose-loader": "5.0.1", "mini-css-extract-plugin": "2.9.2", - "sass": "1.91.0", + "sass": "1.93.2", "sass-loader": "16.0.5", - "webpack": "5.101.3", + "webpack": "5.102.1", "webpack-cli": "5.1.4" } }, @@ -679,6 +679,7 @@ "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", "dev": true, "license": "MIT", + "peer": true, "bin": { "acorn": "bin/acorn" }, @@ -705,6 +706,7 @@ "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "fast-deep-equal": "^3.1.3", "fast-uri": "^3.0.1", @@ -747,6 +749,16 @@ "ajv": "^8.8.2" } }, + "node_modules/baseline-browser-mapping": { + "version": "2.8.18", + "resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.8.18.tgz", + "integrity": "sha512-UYmTpOBwgPScZpS4A+YbapwWuBwasxvO/2IOHArSsAhL/+ZdmATBXTex3t+l2hXwLVYK382ibr/nKoY9GKe86w==", + "dev": true, + "license": "Apache-2.0", + "bin": { + "baseline-browser-mapping": "dist/cli.js" + } + }, "node_modules/bootstrap": { "version": "5.3.6", "resolved": "https://registry.npmjs.org/bootstrap/-/bootstrap-5.3.6.tgz", @@ -781,9 +793,9 @@ } }, "node_modules/browserslist": { - "version": "4.25.4", - "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.25.4.tgz", - "integrity": "sha512-4jYpcjabC606xJ3kw2QwGEZKX0Aw7sgQdZCvIK9dhVSPh76BKo+C+btT1RRofH7B+8iNpEbgGNVWiLki5q93yg==", + "version": "4.26.3", + "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.26.3.tgz", + "integrity": "sha512-lAUU+02RFBuCKQPj/P6NgjlbCnLBMp4UtgTx7vNHd3XSIJF87s9a5rA3aH2yw3GS9DqZAUbOtZdCCiZeVRqt0w==", "dev": true, "funding": [ { @@ -800,10 +812,12 @@ } ], "license": "MIT", + "peer": true, "dependencies": { - "caniuse-lite": "^1.0.30001737", - "electron-to-chromium": "^1.5.211", - "node-releases": "^2.0.19", + "baseline-browser-mapping": "^2.8.9", + "caniuse-lite": "^1.0.30001746", + "electron-to-chromium": "^1.5.227", + "node-releases": "^2.0.21", "update-browserslist-db": "^1.1.3" }, "bin": { @@ -821,9 +835,9 @@ "license": "MIT" }, "node_modules/caniuse-lite": { - "version": "1.0.30001741", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001741.tgz", - "integrity": "sha512-QGUGitqsc8ARjLdgAfxETDhRbJ0REsP6O3I96TAth/mVjh2cYzN2u+3AzPP3aVSm2FehEItaJw1xd+IGBXWeSw==", + "version": "1.0.30001751", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001751.tgz", + "integrity": "sha512-A0QJhug0Ly64Ii3eIqHu5X51ebln3k4yTUkY1j8drqpWHVreg/VLijN48cZ1bYPiqOQuqpkIKnzr/Ul8V+p6Cw==", "dev": true, "funding": [ { @@ -975,9 +989,9 @@ } }, "node_modules/electron-to-chromium": { - "version": "1.5.215", - "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.215.tgz", - "integrity": "sha512-TIvGp57UpeNetj/wV/xpFNpWGb0b/ROw372lHPx5Aafx02gjTBtWnEEcaSX3W2dLM3OSdGGyHX/cHl01JQsLaQ==", + "version": "1.5.237", + "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.237.tgz", + "integrity": "sha512-icUt1NvfhGLar5lSWH3tHNzablaA5js3HVHacQimfP8ViEBOQv+L7DKEuHdbTZ0SKCO1ogTJTIL1Gwk9S6Qvcg==", "dev": true, "license": "ISC" }, @@ -1528,9 +1542,9 @@ "optional": true }, "node_modules/node-releases": { - "version": "2.0.20", - "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.20.tgz", - "integrity": "sha512-7gK6zSXEH6neM212JgfYFXe+GmZQM+fia5SsusuBIUgnPheLFBmIPhtFoAQRj8/7wASYQnbDlHPVwY0BefoFgA==", + "version": "2.0.26", + "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.26.tgz", + "integrity": "sha512-S2M9YimhSjBSvYnlr5/+umAnPHE++ODwt5e2Ij6FoX45HA/s4vHdkDx1eax2pAPeAOqu4s9b7ppahsyEFdVqQA==", "dev": true, "license": "MIT" }, @@ -1654,6 +1668,7 @@ } ], "license": "MIT", + "peer": true, "dependencies": { "nanoid": "^3.3.11", "picocolors": "^1.1.1", @@ -1860,11 +1875,12 @@ "license": "MIT" }, "node_modules/sass": { - "version": "1.91.0", - "resolved": "https://registry.npmjs.org/sass/-/sass-1.91.0.tgz", - "integrity": "sha512-aFOZHGf+ur+bp1bCHZ+u8otKGh77ZtmFyXDo4tlYvT7PWql41Kwd8wdkPqhhT+h2879IVblcHFglIMofsFd1EA==", + "version": "1.93.2", + "resolved": "https://registry.npmjs.org/sass/-/sass-1.93.2.tgz", + "integrity": "sha512-t+YPtOQHpGW1QWsh1CHQ5cPIr9lbbGZLZnbihP/D/qZj/yuV68m8qarcV17nvkOX81BCrvzAlq2klCQFZghyTg==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "chokidar": "^4.0.0", "immutable": "^5.0.2", @@ -1922,9 +1938,9 @@ } }, "node_modules/schema-utils": { - "version": "4.3.2", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.2.tgz", - "integrity": "sha512-Gn/JaSk/Mt9gYubxTtSn/QCV4em9mpAPiR1rqy/Ocu19u/G9J5WWdNoUT4SiV6mFC3y6cxyFcFwdzPM3FgxGAQ==", + "version": "4.3.3", + "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.3.tgz", + "integrity": "sha512-eflK8wEtyOE6+hsaRVPxvUKYCpRgzLqDTb8krvAsRIwOGlHoSgYLgBXoubGgLd2fT41/OUYdb48v4k4WWHQurA==", "dev": true, "license": "MIT", "dependencies": { @@ -2061,9 +2077,9 @@ } }, "node_modules/tapable": { - "version": "2.2.3", - "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.2.3.tgz", - "integrity": "sha512-ZL6DDuAlRlLGghwcfmSn9sK3Hr6ArtyudlSAiCqQ6IfE+b+HHbydbYDIG15IfS5do+7XQQBdBiubF/cV2dnDzg==", + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/tapable/-/tapable-2.3.0.tgz", + "integrity": "sha512-g9ljZiwki/LfxmQADO3dEY1CbpmXT5Hm2fJ+QaGKwSXUylMybePR7/67YW7jOrrvjEgL1Fmz5kzyAjWVWLlucg==", "dev": true, "license": "MIT", "engines": { @@ -2210,11 +2226,12 @@ } }, "node_modules/webpack": { - "version": "5.101.3", - "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.101.3.tgz", - "integrity": "sha512-7b0dTKR3Ed//AD/6kkx/o7duS8H3f1a4w3BYpIriX4BzIhjkn4teo05cptsxvLesHFKK5KObnadmCHBwGc+51A==", + "version": "5.102.1", + "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.102.1.tgz", + "integrity": "sha512-7h/weGm9d/ywQ6qzJ+Xy+r9n/3qgp/thalBbpOi5i223dPXKi04IBtqPN9nTd+jBc7QKfvDbaBnFipYp4sJAUQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@types/eslint-scope": "^3.7.7", "@types/estree": "^1.0.8", @@ -2224,7 +2241,7 @@ "@webassemblyjs/wasm-parser": "^1.14.1", "acorn": "^8.15.0", "acorn-import-phases": "^1.0.3", - "browserslist": "^4.24.0", + "browserslist": "^4.26.3", "chrome-trace-event": "^1.0.2", "enhanced-resolve": "^5.17.3", "es-module-lexer": "^1.2.1", @@ -2236,10 +2253,10 @@ "loader-runner": "^4.2.0", "mime-types": "^2.1.27", "neo-async": "^2.6.2", - "schema-utils": "^4.3.2", - "tapable": "^2.1.1", + "schema-utils": "^4.3.3", + "tapable": "^2.3.0", "terser-webpack-plugin": "^5.3.11", - "watchpack": "^2.4.1", + "watchpack": "^2.4.4", "webpack-sources": "^3.3.3" }, "bin": { @@ -2264,6 +2281,7 @@ "integrity": "sha512-pIDJHIEI9LR0yxHXQ+Qh95k2EvXpWzZ5l+d+jIo+RdSm9MiHfzazIxwwni/p7+x4eJZuvG1AJwgC4TNQ7NRgsg==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@discoveryjs/json-ext": "^0.5.0", "@webpack-cli/configtest": "^2.1.1", diff --git a/src/Admin/package.json b/src/Admin/package.json index 89ee1c5358..f6f21e2cf9 100644 --- a/src/Admin/package.json +++ b/src/Admin/package.json @@ -17,9 +17,9 @@ "css-loader": "7.1.2", "expose-loader": "5.0.1", "mini-css-extract-plugin": "2.9.2", - "sass": "1.91.0", + "sass": "1.93.2", "sass-loader": "16.0.5", - "webpack": "5.101.3", + "webpack": "5.102.1", "webpack-cli": "5.1.4" } } diff --git a/src/Api/AdminConsole/Authorization/AuthorizationHandlerCollectionExtensions.cs b/src/Api/AdminConsole/Authorization/AuthorizationHandlerCollectionExtensions.cs index ed628105e0..a3234f61d7 100644 --- a/src/Api/AdminConsole/Authorization/AuthorizationHandlerCollectionExtensions.cs +++ b/src/Api/AdminConsole/Authorization/AuthorizationHandlerCollectionExtensions.cs @@ -12,10 +12,11 @@ public static class AuthorizationHandlerCollectionExtensions services.TryAddScoped(); services.TryAddEnumerable([ - ServiceDescriptor.Scoped(), + ServiceDescriptor.Scoped(), ServiceDescriptor.Scoped(), ServiceDescriptor.Scoped(), ServiceDescriptor.Scoped(), + ServiceDescriptor.Scoped(), ]); } } diff --git a/src/Api/AdminConsole/Authorization/RecoverAccountAuthorizationHandler.cs b/src/Api/AdminConsole/Authorization/RecoverAccountAuthorizationHandler.cs new file mode 100644 index 0000000000..239148ab25 --- /dev/null +++ b/src/Api/AdminConsole/Authorization/RecoverAccountAuthorizationHandler.cs @@ -0,0 +1,110 @@ +using System.Security.Claims; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Context; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Microsoft.AspNetCore.Authorization; + +namespace Bit.Api.AdminConsole.Authorization; + +/// +/// An authorization requirement for recovering an organization member's account. +/// +/// +/// Note: this is different to simply being able to manage account recovery. The user must be recovering +/// a member who has equal or lesser permissions than them. +/// +public class RecoverAccountAuthorizationRequirement : IAuthorizationRequirement; + +/// +/// Authorizes members and providers to recover a target OrganizationUser's account. +/// +/// +/// This prevents privilege escalation by ensuring that a user cannot recover the account of +/// another user with a higher role or with provider membership. +/// +public class RecoverAccountAuthorizationHandler( + IOrganizationContext organizationContext, + ICurrentContext currentContext, + IProviderUserRepository providerUserRepository) + : AuthorizationHandler +{ + public const string FailureReason = "You are not permitted to recover this user's account."; + public const string ProviderFailureReason = "You are not permitted to recover a Provider member's account."; + + protected override async Task HandleRequirementAsync(AuthorizationHandlerContext context, + RecoverAccountAuthorizationRequirement requirement, + OrganizationUser targetOrganizationUser) + { + // Step 1: check that the User has permissions with respect to the organization. + // This may come from their role in the organization or their provider relationship. + var canRecoverOrganizationMember = + AuthorizeMember(context.User, targetOrganizationUser) || + await AuthorizeProviderAsync(context.User, targetOrganizationUser); + + if (!canRecoverOrganizationMember) + { + context.Fail(new AuthorizationFailureReason(this, FailureReason)); + return; + } + + // Step 2: check that the User has permissions with respect to any provider the target user is a member of. + // This prevents an organization admin performing privilege escalation into an unrelated provider. + var canRecoverProviderMember = await CanRecoverProviderAsync(targetOrganizationUser); + if (!canRecoverProviderMember) + { + context.Fail(new AuthorizationFailureReason(this, ProviderFailureReason)); + return; + } + + context.Succeed(requirement); + } + + private async Task AuthorizeProviderAsync(ClaimsPrincipal currentUser, OrganizationUser targetOrganizationUser) + { + return await organizationContext.IsProviderUserForOrganization(currentUser, targetOrganizationUser.OrganizationId); + } + + private bool AuthorizeMember(ClaimsPrincipal currentUser, OrganizationUser targetOrganizationUser) + { + var currentContextOrganization = organizationContext.GetOrganizationClaims(currentUser, targetOrganizationUser.OrganizationId); + if (currentContextOrganization == null) + { + return false; + } + + // Current user must have equal or greater permissions than the user account being recovered + var authorized = targetOrganizationUser.Type switch + { + OrganizationUserType.Owner => currentContextOrganization.Type is OrganizationUserType.Owner, + OrganizationUserType.Admin => currentContextOrganization.Type is OrganizationUserType.Owner or OrganizationUserType.Admin, + _ => currentContextOrganization is + { Type: OrganizationUserType.Owner or OrganizationUserType.Admin } + or { Type: OrganizationUserType.Custom, Permissions.ManageResetPassword: true } + }; + + return authorized; + } + + private async Task CanRecoverProviderAsync(OrganizationUser targetOrganizationUser) + { + if (!targetOrganizationUser.UserId.HasValue) + { + // If an OrganizationUser is not linked to a User then it can't be linked to a Provider either. + // This is invalid but does not pose a privilege escalation risk. Return early and let the command + // handle the invalid input. + return true; + } + + var targetUserProviderUsers = + await providerUserRepository.GetManyByUserAsync(targetOrganizationUser.UserId.Value); + + // If the target user belongs to any provider that the current user is not a member of, + // deny the action to prevent privilege escalation from organization to provider. + // Note: we do not expect that a user is a member of more than 1 provider, but there is also no guarantee + // against it; this returns a sequence, so we handle the possibility. + var authorized = targetUserProviderUsers.All(providerUser => currentContext.ProviderUser(providerUser.ProviderId)); + return authorized; + } +} + diff --git a/src/Api/AdminConsole/Controllers/EventsController.cs b/src/Api/AdminConsole/Controllers/EventsController.cs index 18199ad8f2..7e058a7870 100644 --- a/src/Api/AdminConsole/Controllers/EventsController.cs +++ b/src/Api/AdminConsole/Controllers/EventsController.cs @@ -3,6 +3,7 @@ using Bit.Api.Models.Response; using Bit.Api.Utilities; +using Bit.Api.Utilities.DiagnosticTools; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Context; using Bit.Core.Enums; @@ -30,16 +31,22 @@ public class EventsController : Controller private readonly ICurrentContext _currentContext; private readonly ISecretRepository _secretRepository; private readonly IProjectRepository _projectRepository; + private readonly IServiceAccountRepository _serviceAccountRepository; + private readonly ILogger _logger; + private readonly IFeatureService _featureService; - public EventsController( - IUserService userService, + + public EventsController(IUserService userService, ICipherRepository cipherRepository, IOrganizationUserRepository organizationUserRepository, IProviderUserRepository providerUserRepository, IEventRepository eventRepository, ICurrentContext currentContext, ISecretRepository secretRepository, - IProjectRepository projectRepository) + IProjectRepository projectRepository, + IServiceAccountRepository serviceAccountRepository, + ILogger logger, + IFeatureService featureService) { _userService = userService; _cipherRepository = cipherRepository; @@ -49,6 +56,9 @@ public class EventsController : Controller _currentContext = currentContext; _secretRepository = secretRepository; _projectRepository = projectRepository; + _serviceAccountRepository = serviceAccountRepository; + _logger = logger; + _featureService = featureService; } [HttpGet("")] @@ -110,6 +120,9 @@ public class EventsController : Controller var result = await _eventRepository.GetManyByOrganizationAsync(orgId, dateRange.Item1, dateRange.Item2, new PageOptions { ContinuationToken = continuationToken }); var responses = result.Data.Select(e => new EventResponseModel(e)); + + _logger.LogAggregateData(_featureService, orgId, responses, continuationToken, start, end); + return new ListResponseModel(responses, result.ContinuationToken); } @@ -184,6 +197,57 @@ public class EventsController : Controller return new ListResponseModel(responses, result.ContinuationToken); } + [HttpGet("~/organization/{orgId}/service-account/{id}/events")] + public async Task> GetServiceAccounts( + Guid orgId, + Guid id, + [FromQuery] DateTime? start = null, + [FromQuery] DateTime? end = null, + [FromQuery] string continuationToken = null) + { + if (id == Guid.Empty || orgId == Guid.Empty) + { + throw new NotFoundException(); + } + + var serviceAccount = await GetServiceAccount(id, orgId); + var org = _currentContext.GetOrganization(orgId); + + if (org == null || !await _currentContext.AccessEventLogs(org.Id)) + { + throw new NotFoundException(); + } + + var (fromDate, toDate) = ApiHelpers.GetDateRange(start, end); + var result = await _eventRepository.GetManyByOrganizationServiceAccountAsync( + serviceAccount.OrganizationId, + serviceAccount.Id, + fromDate, + toDate, + new PageOptions { ContinuationToken = continuationToken }); + + var responses = result.Data.Select(e => new EventResponseModel(e)); + return new ListResponseModel(responses, result.ContinuationToken); + } + + [ApiExplorerSettings(IgnoreApi = true)] + private async Task GetServiceAccount(Guid serviceAccountId, Guid orgId) + { + var serviceAccount = await _serviceAccountRepository.GetByIdAsync(serviceAccountId); + if (serviceAccount != null) + { + return serviceAccount; + } + + var fallbackServiceAccount = new ServiceAccount + { + Id = serviceAccountId, + OrganizationId = orgId + }; + + return fallbackServiceAccount; + } + [HttpGet("~/organizations/{orgId}/users/{id}/events")] public async Task> GetOrganizationUser(string orgId, string id, [FromQuery] DateTime? start = null, [FromQuery] DateTime? end = null, [FromQuery] string continuationToken = null) diff --git a/src/Api/AdminConsole/Controllers/OrganizationIntegrationConfigurationController.cs b/src/Api/AdminConsole/Controllers/OrganizationIntegrationConfigurationController.cs index ae0f91d355..0b7fe8dffe 100644 --- a/src/Api/AdminConsole/Controllers/OrganizationIntegrationConfigurationController.cs +++ b/src/Api/AdminConsole/Controllers/OrganizationIntegrationConfigurationController.cs @@ -1,16 +1,13 @@ using Bit.Api.AdminConsole.Models.Request.Organizations; using Bit.Api.AdminConsole.Models.Response.Organizations; -using Bit.Core; using Bit.Core.Context; using Bit.Core.Exceptions; using Bit.Core.Repositories; -using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; namespace Bit.Api.AdminConsole.Controllers; -[RequireFeature(FeatureFlagKeys.EventBasedOrganizationIntegrations)] [Route("organizations/{organizationId:guid}/integrations/{integrationId:guid}/configurations")] [Authorize("Application")] public class OrganizationIntegrationConfigurationController( diff --git a/src/Api/AdminConsole/Controllers/OrganizationIntegrationController.cs b/src/Api/AdminConsole/Controllers/OrganizationIntegrationController.cs index a12492949d..181811e892 100644 --- a/src/Api/AdminConsole/Controllers/OrganizationIntegrationController.cs +++ b/src/Api/AdminConsole/Controllers/OrganizationIntegrationController.cs @@ -1,18 +1,13 @@ using Bit.Api.AdminConsole.Models.Request.Organizations; using Bit.Api.AdminConsole.Models.Response.Organizations; -using Bit.Core; using Bit.Core.Context; using Bit.Core.Exceptions; using Bit.Core.Repositories; -using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -#nullable enable - namespace Bit.Api.AdminConsole.Controllers; -[RequireFeature(FeatureFlagKeys.EventBasedOrganizationIntegrations)] [Route("organizations/{organizationId:guid}/integrations")] [Authorize("Application")] public class OrganizationIntegrationController( diff --git a/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs b/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs index 74ac9b1255..4b9f7e5d71 100644 --- a/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs +++ b/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs @@ -1,4 +1,5 @@ // FIXME: Update this file to be null safe and then delete the line below +// NOTE: This file is partially migrated to nullable reference types. Remove inline #nullable directives when addressing the FIXME. #nullable disable using Bit.Api.AdminConsole.Authorization; @@ -11,6 +12,7 @@ using Bit.Api.Vault.AuthorizationHandlers.Collections; using Bit.Core; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimedAccount; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; @@ -70,6 +72,7 @@ public class OrganizationUsersController : Controller private readonly IRestoreOrganizationUserCommand _restoreOrganizationUserCommand; private readonly IInitPendingOrganizationCommand _initPendingOrganizationCommand; private readonly IRevokeOrganizationUserCommand _revokeOrganizationUserCommand; + private readonly IAdminRecoverAccountCommand _adminRecoverAccountCommand; public OrganizationUsersController(IOrganizationRepository organizationRepository, IOrganizationUserRepository organizationUserRepository, @@ -97,7 +100,8 @@ public class OrganizationUsersController : Controller IRestoreOrganizationUserCommand restoreOrganizationUserCommand, IInitPendingOrganizationCommand initPendingOrganizationCommand, IRevokeOrganizationUserCommand revokeOrganizationUserCommand, - IResendOrganizationInviteCommand resendOrganizationInviteCommand) + IResendOrganizationInviteCommand resendOrganizationInviteCommand, + IAdminRecoverAccountCommand adminRecoverAccountCommand) { _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; @@ -126,6 +130,7 @@ public class OrganizationUsersController : Controller _restoreOrganizationUserCommand = restoreOrganizationUserCommand; _initPendingOrganizationCommand = initPendingOrganizationCommand; _revokeOrganizationUserCommand = revokeOrganizationUserCommand; + _adminRecoverAccountCommand = adminRecoverAccountCommand; } [HttpGet("{id}")] @@ -474,21 +479,27 @@ public class OrganizationUsersController : Controller [HttpPut("{id}/reset-password")] [Authorize] - public async Task PutResetPassword(Guid orgId, Guid id, [FromBody] OrganizationUserResetPasswordRequestModel model) + public async Task PutResetPassword(Guid orgId, Guid id, [FromBody] OrganizationUserResetPasswordRequestModel model) { + if (_featureService.IsEnabled(FeatureFlagKeys.AccountRecoveryCommand)) + { + // TODO: remove legacy implementation after feature flag is enabled. + return await PutResetPasswordNew(orgId, id, model); + } + // Get the users role, since provider users aren't a member of the organization we use the owner check var orgUserType = await _currentContext.OrganizationOwner(orgId) ? OrganizationUserType.Owner : _currentContext.Organizations?.FirstOrDefault(o => o.Id == orgId)?.Type; if (orgUserType == null) { - throw new NotFoundException(); + return TypedResults.NotFound(); } var result = await _userService.AdminResetPasswordAsync(orgUserType.Value, orgId, id, model.NewMasterPasswordHash, model.Key); if (result.Succeeded) { - return; + return TypedResults.Ok(); } foreach (var error in result.Errors) @@ -497,9 +508,45 @@ public class OrganizationUsersController : Controller } await Task.Delay(2000); - throw new BadRequestException(ModelState); + return TypedResults.BadRequest(ModelState); } +#nullable enable + // TODO: make sure the route and authorize attributes are maintained when the legacy implementation is removed. + private async Task PutResetPasswordNew(Guid orgId, Guid id, [FromBody] OrganizationUserResetPasswordRequestModel model) + { + var targetOrganizationUser = await _organizationUserRepository.GetByIdAsync(id); + if (targetOrganizationUser == null || targetOrganizationUser.OrganizationId != orgId) + { + return TypedResults.NotFound(); + } + + var authorizationResult = await _authorizationService.AuthorizeAsync(User, targetOrganizationUser, new RecoverAccountAuthorizationRequirement()); + if (!authorizationResult.Succeeded) + { + // Return an informative error to show in the UI. + // The Authorize attribute already prevents enumeration by users outside the organization, so this can be specific. + var failureReason = authorizationResult.Failure?.FailureReasons.FirstOrDefault()?.Message ?? RecoverAccountAuthorizationHandler.FailureReason; + // This should be a 403 Forbidden, but that causes a logout on our client apps so we're using 400 Bad Request instead + return TypedResults.BadRequest(new ErrorResponseModel(failureReason)); + } + + var result = await _adminRecoverAccountCommand.RecoverAccountAsync(orgId, targetOrganizationUser, model.NewMasterPasswordHash, model.Key); + if (result.Succeeded) + { + return TypedResults.Ok(); + } + + foreach (var error in result.Errors) + { + ModelState.AddModelError(string.Empty, error.Description); + } + + await Task.Delay(2000); + return TypedResults.BadRequest(ModelState); + } +#nullable disable + [HttpDelete("{id}")] [Authorize] public async Task Remove(Guid orgId, Guid id) diff --git a/src/Api/AdminConsole/Controllers/PoliciesController.cs b/src/Api/AdminConsole/Controllers/PoliciesController.cs index ce92321833..a5272413e2 100644 --- a/src/Api/AdminConsole/Controllers/PoliciesController.cs +++ b/src/Api/AdminConsole/Controllers/PoliciesController.cs @@ -12,6 +12,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Context; @@ -41,8 +42,9 @@ public class PoliciesController : Controller private readonly IDataProtectorTokenFactory _orgUserInviteTokenDataFactory; private readonly IPolicyRepository _policyRepository; private readonly IUserService _userService; - + private readonly IFeatureService _featureService; private readonly ISavePolicyCommand _savePolicyCommand; + private readonly IVNextSavePolicyCommand _vNextSavePolicyCommand; public PoliciesController(IPolicyRepository policyRepository, IOrganizationUserRepository organizationUserRepository, @@ -53,7 +55,9 @@ public class PoliciesController : Controller IDataProtectorTokenFactory orgUserInviteTokenDataFactory, IOrganizationHasVerifiedDomainsQuery organizationHasVerifiedDomainsQuery, IOrganizationRepository organizationRepository, - ISavePolicyCommand savePolicyCommand) + IFeatureService featureService, + ISavePolicyCommand savePolicyCommand, + IVNextSavePolicyCommand vNextSavePolicyCommand) { _policyRepository = policyRepository; _organizationUserRepository = organizationUserRepository; @@ -65,7 +69,9 @@ public class PoliciesController : Controller _organizationRepository = organizationRepository; _orgUserInviteTokenDataFactory = orgUserInviteTokenDataFactory; _organizationHasVerifiedDomainsQuery = organizationHasVerifiedDomainsQuery; + _featureService = featureService; _savePolicyCommand = savePolicyCommand; + _vNextSavePolicyCommand = vNextSavePolicyCommand; } [HttpGet("{type}")] @@ -203,27 +209,22 @@ public class PoliciesController : Controller throw new NotFoundException(); } - if (type != model.Type) - { - throw new BadRequestException("Mismatched policy type"); - } - - var policyUpdate = await model.ToPolicyUpdateAsync(orgId, _currentContext); + var policyUpdate = await model.ToPolicyUpdateAsync(orgId, type, _currentContext); var policy = await _savePolicyCommand.SaveAsync(policyUpdate); return new PolicyResponseModel(policy); } - [HttpPut("{type}/vnext")] [RequireFeatureAttribute(FeatureFlagKeys.CreateDefaultLocation)] [Authorize] - public async Task PutVNext(Guid orgId, [FromBody] SavePolicyRequest model) + public async Task PutVNext(Guid orgId, PolicyType type, [FromBody] SavePolicyRequest model) { - var savePolicyRequest = await model.ToSavePolicyModelAsync(orgId, _currentContext); + var savePolicyRequest = await model.ToSavePolicyModelAsync(orgId, type, _currentContext); - var policy = await _savePolicyCommand.VNextSaveAsync(savePolicyRequest); + var policy = _featureService.IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor) ? + await _vNextSavePolicyCommand.SaveAsync(savePolicyRequest) : + await _savePolicyCommand.VNextSaveAsync(savePolicyRequest); return new PolicyResponseModel(policy); } - } diff --git a/src/Api/AdminConsole/Controllers/ProvidersController.cs b/src/Api/AdminConsole/Controllers/ProvidersController.cs index a1815fd3bf..aa87bf9c74 100644 --- a/src/Api/AdminConsole/Controllers/ProvidersController.cs +++ b/src/Api/AdminConsole/Controllers/ProvidersController.cs @@ -7,7 +7,6 @@ using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; using Bit.Core.Context; using Bit.Core.Exceptions; -using Bit.Core.Models.Business; using Bit.Core.Services; using Bit.Core.Settings; using Microsoft.AspNetCore.Authorization; @@ -93,22 +92,12 @@ public class ProvidersController : Controller var userId = _userService.GetProperUserId(User).Value; - var taxInfo = new TaxInfo - { - BillingAddressCountry = model.TaxInfo.Country, - BillingAddressPostalCode = model.TaxInfo.PostalCode, - TaxIdNumber = model.TaxInfo.TaxId, - BillingAddressLine1 = model.TaxInfo.Line1, - BillingAddressLine2 = model.TaxInfo.Line2, - BillingAddressCity = model.TaxInfo.City, - BillingAddressState = model.TaxInfo.State - }; - - var tokenizedPaymentSource = model.PaymentSource?.ToDomain(); + var paymentMethod = model.PaymentMethod.ToDomain(); + var billingAddress = model.BillingAddress.ToDomain(); var response = await _providerService.CompleteSetupAsync(model.ToProvider(provider), userId, model.Token, model.Key, - taxInfo, tokenizedPaymentSource); + paymentMethod, billingAddress); return new ProviderResponseModel(response); } diff --git a/src/Api/AdminConsole/Controllers/SlackIntegrationController.cs b/src/Api/AdminConsole/Controllers/SlackIntegrationController.cs index 6e3751c6f6..7b53f73f81 100644 --- a/src/Api/AdminConsole/Controllers/SlackIntegrationController.cs +++ b/src/Api/AdminConsole/Controllers/SlackIntegrationController.cs @@ -1,9 +1,5 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Text.Json; +using System.Text.Json; using Bit.Api.AdminConsole.Models.Response.Organizations; -using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Models.Data.EventIntegrations; using Bit.Core.Context; @@ -11,32 +7,63 @@ using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; -using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; namespace Bit.Api.AdminConsole.Controllers; -[RequireFeature(FeatureFlagKeys.EventBasedOrganizationIntegrations)] -[Route("organizations/{organizationId:guid}/integrations/slack")] +[Route("organizations")] [Authorize("Application")] public class SlackIntegrationController( ICurrentContext currentContext, IOrganizationIntegrationRepository integrationRepository, - ISlackService slackService) : Controller + ISlackService slackService, + TimeProvider timeProvider) : Controller { - [HttpGet("redirect")] + [HttpGet("{organizationId:guid}/integrations/slack/redirect")] public async Task RedirectAsync(Guid organizationId) { if (!await currentContext.OrganizationOwner(organizationId)) { throw new NotFoundException(); } - string callbackUrl = Url.RouteUrl( - nameof(CreateAsync), - new { organizationId }, - currentContext.HttpContext.Request.Scheme); - var redirectUrl = slackService.GetRedirectUrl(callbackUrl); + + string? callbackUrl = Url.RouteUrl( + routeName: "SlackIntegration_Create", + values: null, + protocol: currentContext.HttpContext.Request.Scheme, + host: currentContext.HttpContext.Request.Host.ToUriComponent() + ); + if (string.IsNullOrEmpty(callbackUrl)) + { + throw new BadRequestException("Unable to build callback Url"); + } + + var integrations = await integrationRepository.GetManyByOrganizationAsync(organizationId); + var integration = integrations.FirstOrDefault(i => i.Type == IntegrationType.Slack); + + if (integration is null) + { + // No slack integration exists, create Initiated version + integration = await integrationRepository.CreateAsync(new OrganizationIntegration + { + OrganizationId = organizationId, + Type = IntegrationType.Slack, + Configuration = null, + }); + } + else if (integration.Configuration is not null) + { + // A Completed (fully configured) Slack integration already exists, throw to prevent overriding + throw new BadRequestException("There already exists a Slack integration for this organization"); + + } // An Initiated slack integration exits, re-use it and kick off a new OAuth flow + + var state = IntegrationOAuthState.FromIntegration(integration, timeProvider); + var redirectUrl = slackService.GetRedirectUrl( + callbackUrl: callbackUrl, + state: state.ToString() + ); if (string.IsNullOrEmpty(redirectUrl)) { @@ -46,23 +73,42 @@ public class SlackIntegrationController( return Redirect(redirectUrl); } - [HttpGet("create", Name = nameof(CreateAsync))] - public async Task CreateAsync(Guid organizationId, [FromQuery] string code) + [HttpGet("integrations/slack/create", Name = "SlackIntegration_Create")] + [AllowAnonymous] + public async Task CreateAsync([FromQuery] string code, [FromQuery] string state) { - if (!await currentContext.OrganizationOwner(organizationId)) + var oAuthState = IntegrationOAuthState.FromString(state: state, timeProvider: timeProvider); + if (oAuthState is null) { throw new NotFoundException(); } - if (string.IsNullOrEmpty(code)) + // Fetch existing Initiated record + var integration = await integrationRepository.GetByIdAsync(oAuthState.IntegrationId); + if (integration is null || + integration.Type != IntegrationType.Slack || + integration.Configuration is not null) { - throw new BadRequestException("Missing code from Slack."); + throw new NotFoundException(); } - string callbackUrl = Url.RouteUrl( - nameof(CreateAsync), - new { organizationId }, - currentContext.HttpContext.Request.Scheme); + // Verify Organization matches hash + if (!oAuthState.ValidateOrg(integration.OrganizationId)) + { + throw new NotFoundException(); + } + + // Fetch token from Slack and store to DB + string? callbackUrl = Url.RouteUrl( + routeName: "SlackIntegration_Create", + values: null, + protocol: currentContext.HttpContext.Request.Scheme, + host: currentContext.HttpContext.Request.Host.ToUriComponent() + ); + if (string.IsNullOrEmpty(callbackUrl)) + { + throw new BadRequestException("Unable to build callback Url"); + } var token = await slackService.ObtainTokenViaOAuth(code, callbackUrl); if (string.IsNullOrEmpty(token)) @@ -70,14 +116,10 @@ public class SlackIntegrationController( throw new BadRequestException("Invalid response from Slack."); } - var integration = await integrationRepository.CreateAsync(new OrganizationIntegration - { - OrganizationId = organizationId, - Type = IntegrationType.Slack, - Configuration = JsonSerializer.Serialize(new SlackIntegration(token)), - }); - var location = $"/organizations/{organizationId}/integrations/{integration.Id}"; + integration.Configuration = JsonSerializer.Serialize(new SlackIntegration(token)); + await integrationRepository.UpsertAsync(integration); + var location = $"/organizations/{integration.OrganizationId}/integrations/{integration.Id}"; return Created(location, new OrganizationIntegrationResponseModel(integration)); } } diff --git a/src/Api/AdminConsole/Controllers/TeamsIntegrationController.cs b/src/Api/AdminConsole/Controllers/TeamsIntegrationController.cs new file mode 100644 index 0000000000..36d107bbcc --- /dev/null +++ b/src/Api/AdminConsole/Controllers/TeamsIntegrationController.cs @@ -0,0 +1,144 @@ +using System.Text.Json; +using Bit.Api.AdminConsole.Models.Response.Organizations; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.Context; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; +using Microsoft.Bot.Builder; +using Microsoft.Bot.Builder.Integration.AspNet.Core; + +namespace Bit.Api.AdminConsole.Controllers; + +[Route("organizations")] +[Authorize("Application")] +public class TeamsIntegrationController( + ICurrentContext currentContext, + IOrganizationIntegrationRepository integrationRepository, + IBot bot, + IBotFrameworkHttpAdapter adapter, + ITeamsService teamsService, + TimeProvider timeProvider) : Controller +{ + [HttpGet("{organizationId:guid}/integrations/teams/redirect")] + public async Task RedirectAsync(Guid organizationId) + { + if (!await currentContext.OrganizationOwner(organizationId)) + { + throw new NotFoundException(); + } + + var callbackUrl = Url.RouteUrl( + routeName: "TeamsIntegration_Create", + values: null, + protocol: currentContext.HttpContext.Request.Scheme, + host: currentContext.HttpContext.Request.Host.ToUriComponent() + ); + if (string.IsNullOrEmpty(callbackUrl)) + { + throw new BadRequestException("Unable to build callback Url"); + } + + var integrations = await integrationRepository.GetManyByOrganizationAsync(organizationId); + var integration = integrations.FirstOrDefault(i => i.Type == IntegrationType.Teams); + + if (integration is null) + { + // No teams integration exists, create Initiated version + integration = await integrationRepository.CreateAsync(new OrganizationIntegration + { + OrganizationId = organizationId, + Type = IntegrationType.Teams, + Configuration = null, + }); + } + else if (integration.Configuration is not null) + { + // A Completed (fully configured) Teams integration already exists, throw to prevent overriding + throw new BadRequestException("There already exists a Teams integration for this organization"); + + } // An Initiated teams integration exits, re-use it and kick off a new OAuth flow + + var state = IntegrationOAuthState.FromIntegration(integration, timeProvider); + var redirectUrl = teamsService.GetRedirectUrl( + callbackUrl: callbackUrl, + state: state.ToString() + ); + + if (string.IsNullOrEmpty(redirectUrl)) + { + throw new NotFoundException(); + } + + return Redirect(redirectUrl); + } + + [HttpGet("integrations/teams/create", Name = "TeamsIntegration_Create")] + [AllowAnonymous] + public async Task CreateAsync([FromQuery] string code, [FromQuery] string state) + { + var oAuthState = IntegrationOAuthState.FromString(state: state, timeProvider: timeProvider); + if (oAuthState is null) + { + throw new NotFoundException(); + } + + // Fetch existing Initiated record + var integration = await integrationRepository.GetByIdAsync(oAuthState.IntegrationId); + if (integration is null || + integration.Type != IntegrationType.Teams || + integration.Configuration is not null) + { + throw new NotFoundException(); + } + + // Verify Organization matches hash + if (!oAuthState.ValidateOrg(integration.OrganizationId)) + { + throw new NotFoundException(); + } + + var callbackUrl = Url.RouteUrl( + routeName: "TeamsIntegration_Create", + values: null, + protocol: currentContext.HttpContext.Request.Scheme, + host: currentContext.HttpContext.Request.Host.ToUriComponent() + ); + if (string.IsNullOrEmpty(callbackUrl)) + { + throw new BadRequestException("Unable to build callback Url"); + } + + var token = await teamsService.ObtainTokenViaOAuth(code, callbackUrl); + if (string.IsNullOrEmpty(token)) + { + throw new BadRequestException("Invalid response from Teams."); + } + + var teams = await teamsService.GetJoinedTeamsAsync(token); + + if (!teams.Any()) + { + throw new BadRequestException("No teams were found."); + } + + var teamsIntegration = new TeamsIntegration(TenantId: teams[0].TenantId, Teams: teams); + integration.Configuration = JsonSerializer.Serialize(teamsIntegration); + await integrationRepository.UpsertAsync(integration); + + var location = $"/organizations/{integration.OrganizationId}/integrations/{integration.Id}"; + return Created(location, new OrganizationIntegrationResponseModel(integration)); + } + + [Route("integrations/teams/incoming")] + [AllowAnonymous] + [HttpPost] + public async Task IncomingPostAsync() + { + await adapter.ProcessAsync(Request, Response, bot); + } +} diff --git a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModel.cs b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModel.cs index 7d1efe2315..8581c4ae1f 100644 --- a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModel.cs +++ b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModel.cs @@ -38,6 +38,10 @@ public class OrganizationIntegrationConfigurationRequestModel return !string.IsNullOrWhiteSpace(Template) && Configuration is null && IsFiltersValid(); + case IntegrationType.Teams: + return !string.IsNullOrWhiteSpace(Template) && + Configuration is null && + IsFiltersValid(); default: return false; diff --git a/src/Api/AdminConsole/Models/Request/Organizations/OrgnizationIntegrationRequestModel.cs b/src/Api/AdminConsole/Models/Request/Organizations/OrgnizationIntegrationRequestModel.cs index 92d65ab8fe..668afe70bf 100644 --- a/src/Api/AdminConsole/Models/Request/Organizations/OrgnizationIntegrationRequestModel.cs +++ b/src/Api/AdminConsole/Models/Request/Organizations/OrgnizationIntegrationRequestModel.cs @@ -35,7 +35,7 @@ public class OrganizationIntegrationRequestModel : IValidatableObject case IntegrationType.CloudBillingSync or IntegrationType.Scim: yield return new ValidationResult($"{nameof(Type)} integrations are not yet supported.", [nameof(Type)]); break; - case IntegrationType.Slack: + case IntegrationType.Slack or IntegrationType.Teams: yield return new ValidationResult($"{nameof(Type)} integrations cannot be created directly.", [nameof(Type)]); break; case IntegrationType.Webhook: diff --git a/src/Api/AdminConsole/Models/Request/PolicyRequestModel.cs b/src/Api/AdminConsole/Models/Request/PolicyRequestModel.cs index 0e31deacd1..2dc7dfa7cd 100644 --- a/src/Api/AdminConsole/Models/Request/PolicyRequestModel.cs +++ b/src/Api/AdminConsole/Models/Request/PolicyRequestModel.cs @@ -1,29 +1,30 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.ComponentModel.DataAnnotations; -using System.Text.Json; +using System.ComponentModel.DataAnnotations; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.Utilities; using Bit.Core.Context; namespace Bit.Api.AdminConsole.Models.Request; public class PolicyRequestModel { - [Required] - public PolicyType? Type { get; set; } [Required] public bool? Enabled { get; set; } - public Dictionary Data { get; set; } + public Dictionary? Data { get; set; } - public async Task ToPolicyUpdateAsync(Guid organizationId, ICurrentContext currentContext) => new() + public async Task ToPolicyUpdateAsync(Guid organizationId, PolicyType type, ICurrentContext currentContext) { - Type = Type!.Value, - OrganizationId = organizationId, - Data = Data != null ? JsonSerializer.Serialize(Data) : null, - Enabled = Enabled.GetValueOrDefault(), - PerformedBy = new StandardUser(currentContext.UserId!.Value, await currentContext.OrganizationOwner(organizationId)) - }; + var serializedData = PolicyDataValidator.ValidateAndSerialize(Data, type); + var performedBy = new StandardUser(currentContext.UserId!.Value, await currentContext.OrganizationOwner(organizationId)); + + return new() + { + Type = type, + OrganizationId = organizationId, + Data = serializedData, + Enabled = Enabled.GetValueOrDefault(), + PerformedBy = performedBy + }; + } } diff --git a/src/Api/AdminConsole/Models/Request/Providers/ProviderSetupRequestModel.cs b/src/Api/AdminConsole/Models/Request/Providers/ProviderSetupRequestModel.cs index 1f50c384a3..41cebe8b9b 100644 --- a/src/Api/AdminConsole/Models/Request/Providers/ProviderSetupRequestModel.cs +++ b/src/Api/AdminConsole/Models/Request/Providers/ProviderSetupRequestModel.cs @@ -3,8 +3,7 @@ using System.ComponentModel.DataAnnotations; using System.Text.Json.Serialization; -using Bit.Api.Billing.Models.Requests; -using Bit.Api.Models.Request; +using Bit.Api.Billing.Models.Requests.Payment; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Utilities; @@ -28,8 +27,9 @@ public class ProviderSetupRequestModel [Required] public string Key { get; set; } [Required] - public ExpandedTaxInfoUpdateRequestModel TaxInfo { get; set; } - public TokenizedPaymentSourceRequestBody PaymentSource { get; set; } + public MinimalTokenizedPaymentMethodRequest PaymentMethod { get; set; } + [Required] + public BillingAddressRequest BillingAddress { get; set; } public virtual Provider ToProvider(Provider provider) { diff --git a/src/Api/AdminConsole/Models/Request/SavePolicyRequest.cs b/src/Api/AdminConsole/Models/Request/SavePolicyRequest.cs index fcdc49882b..2e2868a78a 100644 --- a/src/Api/AdminConsole/Models/Request/SavePolicyRequest.cs +++ b/src/Api/AdminConsole/Models/Request/SavePolicyRequest.cs @@ -1,10 +1,9 @@ using System.ComponentModel.DataAnnotations; -using System.Text.Json; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.Utilities; using Bit.Core.Context; -using Bit.Core.Utilities; namespace Bit.Api.AdminConsole.Models.Request; @@ -15,47 +14,12 @@ public class SavePolicyRequest public Dictionary? Metadata { get; set; } - public async Task ToSavePolicyModelAsync(Guid organizationId, ICurrentContext currentContext) + public async Task ToSavePolicyModelAsync(Guid organizationId, PolicyType type, ICurrentContext currentContext) { + var policyUpdate = await Policy.ToPolicyUpdateAsync(organizationId, type, currentContext); + var metadata = PolicyDataValidator.ValidateAndDeserializeMetadata(Metadata, type); var performedBy = new StandardUser(currentContext.UserId!.Value, await currentContext.OrganizationOwner(organizationId)); - var updatedPolicy = new PolicyUpdate() - { - Type = Policy.Type!.Value, - OrganizationId = organizationId, - Data = Policy.Data != null ? JsonSerializer.Serialize(Policy.Data) : null, - Enabled = Policy.Enabled.GetValueOrDefault(), - }; - - var metadata = MapToPolicyMetadata(); - - return new SavePolicyModel(updatedPolicy, performedBy, metadata); - } - - private IPolicyMetadataModel MapToPolicyMetadata() - { - if (Metadata == null) - { - return new EmptyMetadataModel(); - } - - return Policy?.Type switch - { - PolicyType.OrganizationDataOwnership => MapToPolicyMetadata(), - _ => new EmptyMetadataModel() - }; - } - - private IPolicyMetadataModel MapToPolicyMetadata() where T : IPolicyMetadataModel, new() - { - try - { - var json = JsonSerializer.Serialize(Metadata); - return CoreHelpers.LoadClassFromJsonData(json); - } - catch - { - return new EmptyMetadataModel(); - } + return new SavePolicyModel(policyUpdate, performedBy, metadata); } } diff --git a/src/Api/AdminConsole/Models/Response/BaseProfileOrganizationResponseModel.cs b/src/Api/AdminConsole/Models/Response/BaseProfileOrganizationResponseModel.cs new file mode 100644 index 0000000000..c172c45e94 --- /dev/null +++ b/src/Api/AdminConsole/Models/Response/BaseProfileOrganizationResponseModel.cs @@ -0,0 +1,127 @@ +using System.Text.Json.Serialization; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.Auth.Enums; +using Bit.Core.Auth.Models.Data; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Extensions; +using Bit.Core.Enums; +using Bit.Core.Models.Api; +using Bit.Core.Models.Data; +using Bit.Core.Utilities; + +namespace Bit.Api.AdminConsole.Models.Response; + +/// +/// Contains organization properties for both OrganizationUsers and ProviderUsers. +/// Any organization properties in sync data should be added to this class so they are populated for both +/// members and providers. +/// +public abstract class BaseProfileOrganizationResponseModel : ResponseModel +{ + protected BaseProfileOrganizationResponseModel( + string type, IProfileOrganizationDetails organizationDetails) : base(type) + { + Id = organizationDetails.OrganizationId; + UserId = organizationDetails.UserId; + Name = organizationDetails.Name; + Enabled = organizationDetails.Enabled; + Identifier = organizationDetails.Identifier; + ProductTierType = organizationDetails.PlanType.GetProductTier(); + UsePolicies = organizationDetails.UsePolicies; + UseSso = organizationDetails.UseSso; + UseKeyConnector = organizationDetails.UseKeyConnector; + UseScim = organizationDetails.UseScim; + UseGroups = organizationDetails.UseGroups; + UseDirectory = organizationDetails.UseDirectory; + UseEvents = organizationDetails.UseEvents; + UseTotp = organizationDetails.UseTotp; + Use2fa = organizationDetails.Use2fa; + UseApi = organizationDetails.UseApi; + UseResetPassword = organizationDetails.UseResetPassword; + UsersGetPremium = organizationDetails.UsersGetPremium; + UseCustomPermissions = organizationDetails.UseCustomPermissions; + UseActivateAutofillPolicy = organizationDetails.PlanType.GetProductTier() == ProductTierType.Enterprise; + UseRiskInsights = organizationDetails.UseRiskInsights; + UseOrganizationDomains = organizationDetails.UseOrganizationDomains; + UseAdminSponsoredFamilies = organizationDetails.UseAdminSponsoredFamilies; + UseAutomaticUserConfirmation = organizationDetails.UseAutomaticUserConfirmation; + UseSecretsManager = organizationDetails.UseSecretsManager; + UsePasswordManager = organizationDetails.UsePasswordManager; + SelfHost = organizationDetails.SelfHost; + Seats = organizationDetails.Seats; + MaxCollections = organizationDetails.MaxCollections; + MaxStorageGb = organizationDetails.MaxStorageGb; + Key = organizationDetails.Key; + HasPublicAndPrivateKeys = organizationDetails.PublicKey != null && organizationDetails.PrivateKey != null; + SsoBound = !string.IsNullOrWhiteSpace(organizationDetails.SsoExternalId); + ResetPasswordEnrolled = !string.IsNullOrWhiteSpace(organizationDetails.ResetPasswordKey); + ProviderId = organizationDetails.ProviderId; + ProviderName = organizationDetails.ProviderName; + ProviderType = organizationDetails.ProviderType; + LimitCollectionCreation = organizationDetails.LimitCollectionCreation; + LimitCollectionDeletion = organizationDetails.LimitCollectionDeletion; + LimitItemDeletion = organizationDetails.LimitItemDeletion; + AllowAdminAccessToAllCollectionItems = organizationDetails.AllowAdminAccessToAllCollectionItems; + SsoEnabled = organizationDetails.SsoEnabled ?? false; + if (organizationDetails.SsoConfig != null) + { + var ssoConfigData = SsoConfigurationData.Deserialize(organizationDetails.SsoConfig); + KeyConnectorEnabled = ssoConfigData.MemberDecryptionType == MemberDecryptionType.KeyConnector && !string.IsNullOrEmpty(ssoConfigData.KeyConnectorUrl); + KeyConnectorUrl = ssoConfigData.KeyConnectorUrl; + SsoMemberDecryptionType = ssoConfigData.MemberDecryptionType; + } + } + + public Guid Id { get; set; } + [JsonConverter(typeof(HtmlEncodingStringConverter))] + public string Name { get; set; } = null!; + public bool Enabled { get; set; } + public string? Identifier { get; set; } + public ProductTierType ProductTierType { get; set; } + public bool UsePolicies { get; set; } + public bool UseSso { get; set; } + public bool UseKeyConnector { get; set; } + public bool UseScim { get; set; } + public bool UseGroups { get; set; } + public bool UseDirectory { get; set; } + public bool UseEvents { get; set; } + public bool UseTotp { get; set; } + public bool Use2fa { get; set; } + public bool UseApi { get; set; } + public bool UseResetPassword { get; set; } + public bool UseSecretsManager { get; set; } + public bool UsePasswordManager { get; set; } + public bool UsersGetPremium { get; set; } + public bool UseCustomPermissions { get; set; } + public bool UseActivateAutofillPolicy { get; set; } + public bool UseRiskInsights { get; set; } + public bool UseOrganizationDomains { get; set; } + public bool UseAdminSponsoredFamilies { get; set; } + public bool UseAutomaticUserConfirmation { get; set; } + public bool SelfHost { get; set; } + public int? Seats { get; set; } + public short? MaxCollections { get; set; } + public short? MaxStorageGb { get; set; } + public string? Key { get; set; } + public bool HasPublicAndPrivateKeys { get; set; } + public bool SsoBound { get; set; } + public bool ResetPasswordEnrolled { get; set; } + public bool LimitCollectionCreation { get; set; } + public bool LimitCollectionDeletion { get; set; } + public bool LimitItemDeletion { get; set; } + public bool AllowAdminAccessToAllCollectionItems { get; set; } + public Guid? ProviderId { get; set; } + [JsonConverter(typeof(HtmlEncodingStringConverter))] + public string? ProviderName { get; set; } + public ProviderType? ProviderType { get; set; } + public bool SsoEnabled { get; set; } + public bool KeyConnectorEnabled { get; set; } + public string? KeyConnectorUrl { get; set; } + public MemberDecryptionType? SsoMemberDecryptionType { get; set; } + public bool AccessSecretsManager { get; set; } + public Guid? UserId { get; set; } + public OrganizationUserStatusType Status { get; set; } + public OrganizationUserType Type { get; set; } + public Permissions? Permissions { get; set; } +} diff --git a/src/Api/AdminConsole/Models/Response/EventResponseModel.cs b/src/Api/AdminConsole/Models/Response/EventResponseModel.cs index bf02d8b00f..c259bc3bc4 100644 --- a/src/Api/AdminConsole/Models/Response/EventResponseModel.cs +++ b/src/Api/AdminConsole/Models/Response/EventResponseModel.cs @@ -35,6 +35,7 @@ public class EventResponseModel : ResponseModel SecretId = ev.SecretId; ProjectId = ev.ProjectId; ServiceAccountId = ev.ServiceAccountId; + GrantedServiceAccountId = ev.GrantedServiceAccountId; } public EventType Type { get; set; } @@ -58,4 +59,5 @@ public class EventResponseModel : ResponseModel public Guid? SecretId { get; set; } public Guid? ProjectId { get; set; } public Guid? ServiceAccountId { get; set; } + public Guid? GrantedServiceAccountId { get; set; } } diff --git a/src/Api/AdminConsole/Models/Response/Organizations/OrganizationIntegrationConfigurationResponseModel.cs b/src/Api/AdminConsole/Models/Response/Organizations/OrganizationIntegrationConfigurationResponseModel.cs index c7906318e8..d070375d88 100644 --- a/src/Api/AdminConsole/Models/Response/Organizations/OrganizationIntegrationConfigurationResponseModel.cs +++ b/src/Api/AdminConsole/Models/Response/Organizations/OrganizationIntegrationConfigurationResponseModel.cs @@ -2,8 +2,6 @@ using Bit.Core.Enums; using Bit.Core.Models.Api; -#nullable enable - namespace Bit.Api.AdminConsole.Models.Response.Organizations; public class OrganizationIntegrationConfigurationResponseModel : ResponseModel @@ -11,8 +9,6 @@ public class OrganizationIntegrationConfigurationResponseModel : ResponseModel public OrganizationIntegrationConfigurationResponseModel(OrganizationIntegrationConfiguration organizationIntegrationConfiguration, string obj = "organizationIntegrationConfiguration") : base(obj) { - ArgumentNullException.ThrowIfNull(organizationIntegrationConfiguration); - Id = organizationIntegrationConfiguration.Id; Configuration = organizationIntegrationConfiguration.Configuration; CreationDate = organizationIntegrationConfiguration.CreationDate; diff --git a/src/Api/AdminConsole/Models/Response/Organizations/OrganizationIntegrationResponseModel.cs b/src/Api/AdminConsole/Models/Response/Organizations/OrganizationIntegrationResponseModel.cs index f062ff46a2..0c31e07bef 100644 --- a/src/Api/AdminConsole/Models/Response/Organizations/OrganizationIntegrationResponseModel.cs +++ b/src/Api/AdminConsole/Models/Response/Organizations/OrganizationIntegrationResponseModel.cs @@ -1,9 +1,9 @@ -using Bit.Core.AdminConsole.Entities; +using System.Text.Json; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Models.Data.EventIntegrations; using Bit.Core.Enums; using Bit.Core.Models.Api; -#nullable enable - namespace Bit.Api.AdminConsole.Models.Response.Organizations; public class OrganizationIntegrationResponseModel : ResponseModel @@ -21,4 +21,39 @@ public class OrganizationIntegrationResponseModel : ResponseModel public Guid Id { get; set; } public IntegrationType Type { get; set; } public string? Configuration { get; set; } + + public OrganizationIntegrationStatus Status => Type switch + { + // Not yet implemented, shouldn't be present, NotApplicable + IntegrationType.CloudBillingSync => OrganizationIntegrationStatus.NotApplicable, + IntegrationType.Scim => OrganizationIntegrationStatus.NotApplicable, + + // Webhook is allowed to be null. If it's present, it's Completed + IntegrationType.Webhook => OrganizationIntegrationStatus.Completed, + + // If present and the configuration is null, OAuth has been initiated, and we are + // waiting on the return call + IntegrationType.Slack => string.IsNullOrWhiteSpace(Configuration) + ? OrganizationIntegrationStatus.Initiated + : OrganizationIntegrationStatus.Completed, + + // If present and the configuration is null, OAuth has been initiated, and we are + // waiting on the return OAuth call. If Configuration is not null and IsCompleted is true, + // then we've received the app install bot callback, and it's Completed. Otherwise, + // it is In Progress while we await the app install bot callback. + IntegrationType.Teams => string.IsNullOrWhiteSpace(Configuration) + ? OrganizationIntegrationStatus.Initiated + : (JsonSerializer.Deserialize(Configuration)?.IsCompleted ?? false) + ? OrganizationIntegrationStatus.Completed + : OrganizationIntegrationStatus.InProgress, + + // HEC and Datadog should only be allowed to be created non-null. + // If they are null, they are Invalid + IntegrationType.Hec => string.IsNullOrWhiteSpace(Configuration) + ? OrganizationIntegrationStatus.Invalid + : OrganizationIntegrationStatus.Completed, + IntegrationType.Datadog => string.IsNullOrWhiteSpace(Configuration) + ? OrganizationIntegrationStatus.Invalid + : OrganizationIntegrationStatus.Completed, + }; } diff --git a/src/Api/AdminConsole/Models/Response/Organizations/OrganizationResponseModel.cs b/src/Api/AdminConsole/Models/Response/Organizations/OrganizationResponseModel.cs index b34765fb19..8006a85734 100644 --- a/src/Api/AdminConsole/Models/Response/Organizations/OrganizationResponseModel.cs +++ b/src/Api/AdminConsole/Models/Response/Organizations/OrganizationResponseModel.cs @@ -70,6 +70,7 @@ public class OrganizationResponseModel : ResponseModel UseRiskInsights = organization.UseRiskInsights; UseOrganizationDomains = organization.UseOrganizationDomains; UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies; + UseAutomaticUserConfirmation = organization.UseAutomaticUserConfirmation; } public Guid Id { get; set; } @@ -118,6 +119,7 @@ public class OrganizationResponseModel : ResponseModel public bool UseRiskInsights { get; set; } public bool UseOrganizationDomains { get; set; } public bool UseAdminSponsoredFamilies { get; set; } + public bool UseAutomaticUserConfirmation { get; set; } } public class OrganizationSubscriptionResponseModel : OrganizationResponseModel diff --git a/src/Api/AdminConsole/Models/Response/ProfileOrganizationResponseModel.cs b/src/Api/AdminConsole/Models/Response/ProfileOrganizationResponseModel.cs index fd2bfe06dc..97a58d038a 100644 --- a/src/Api/AdminConsole/Models/Response/ProfileOrganizationResponseModel.cs +++ b/src/Api/AdminConsole/Models/Response/ProfileOrganizationResponseModel.cs @@ -1,148 +1,47 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Text.Json.Serialization; -using Bit.Core.AdminConsole.Enums.Provider; -using Bit.Core.Auth.Enums; -using Bit.Core.Auth.Models.Data; -using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Extensions; -using Bit.Core.Enums; -using Bit.Core.Models.Api; +using Bit.Core.Enums; using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Utilities; namespace Bit.Api.AdminConsole.Models.Response; -public class ProfileOrganizationResponseModel : ResponseModel +/// +/// Sync data for organization members and their organization. +/// Note: see for organization sync data received by provider users. +/// +public class ProfileOrganizationResponseModel : BaseProfileOrganizationResponseModel { - public ProfileOrganizationResponseModel(string str) : base(str) { } - public ProfileOrganizationResponseModel( - OrganizationUserOrganizationDetails organization, + OrganizationUserOrganizationDetails organizationDetails, IEnumerable organizationIdsClaimingUser) - : this("profileOrganization") + : base("profileOrganization", organizationDetails) { - Id = organization.OrganizationId; - Name = organization.Name; - UsePolicies = organization.UsePolicies; - UseSso = organization.UseSso; - UseKeyConnector = organization.UseKeyConnector; - UseScim = organization.UseScim; - UseGroups = organization.UseGroups; - UseDirectory = organization.UseDirectory; - UseEvents = organization.UseEvents; - UseTotp = organization.UseTotp; - Use2fa = organization.Use2fa; - UseApi = organization.UseApi; - UseResetPassword = organization.UseResetPassword; - UseSecretsManager = organization.UseSecretsManager; - UsePasswordManager = organization.UsePasswordManager; - UsersGetPremium = organization.UsersGetPremium; - UseCustomPermissions = organization.UseCustomPermissions; - UseActivateAutofillPolicy = organization.PlanType.GetProductTier() == ProductTierType.Enterprise; - SelfHost = organization.SelfHost; - Seats = organization.Seats; - MaxCollections = organization.MaxCollections; - MaxStorageGb = organization.MaxStorageGb; - Key = organization.Key; - HasPublicAndPrivateKeys = organization.PublicKey != null && organization.PrivateKey != null; - Status = organization.Status; - Type = organization.Type; - Enabled = organization.Enabled; - SsoBound = !string.IsNullOrWhiteSpace(organization.SsoExternalId); - Identifier = organization.Identifier; - Permissions = CoreHelpers.LoadClassFromJsonData(organization.Permissions); - ResetPasswordEnrolled = !string.IsNullOrWhiteSpace(organization.ResetPasswordKey); - UserId = organization.UserId; - OrganizationUserId = organization.OrganizationUserId; - ProviderId = organization.ProviderId; - ProviderName = organization.ProviderName; - ProviderType = organization.ProviderType; - FamilySponsorshipFriendlyName = organization.FamilySponsorshipFriendlyName; - IsAdminInitiated = organization.IsAdminInitiated ?? false; - FamilySponsorshipAvailable = (FamilySponsorshipFriendlyName == null || IsAdminInitiated) && + Status = organizationDetails.Status; + Type = organizationDetails.Type; + OrganizationUserId = organizationDetails.OrganizationUserId; + UserIsClaimedByOrganization = organizationIdsClaimingUser.Contains(organizationDetails.OrganizationId); + Permissions = CoreHelpers.LoadClassFromJsonData(organizationDetails.Permissions); + IsAdminInitiated = organizationDetails.IsAdminInitiated ?? false; + FamilySponsorshipFriendlyName = organizationDetails.FamilySponsorshipFriendlyName; + FamilySponsorshipLastSyncDate = organizationDetails.FamilySponsorshipLastSyncDate; + FamilySponsorshipToDelete = organizationDetails.FamilySponsorshipToDelete; + FamilySponsorshipValidUntil = organizationDetails.FamilySponsorshipValidUntil; + FamilySponsorshipAvailable = (organizationDetails.FamilySponsorshipFriendlyName == null || IsAdminInitiated) && StaticStore.GetSponsoredPlan(PlanSponsorshipType.FamiliesForEnterprise) - .UsersCanSponsor(organization); - ProductTierType = organization.PlanType.GetProductTier(); - FamilySponsorshipLastSyncDate = organization.FamilySponsorshipLastSyncDate; - FamilySponsorshipToDelete = organization.FamilySponsorshipToDelete; - FamilySponsorshipValidUntil = organization.FamilySponsorshipValidUntil; - AccessSecretsManager = organization.AccessSecretsManager; - LimitCollectionCreation = organization.LimitCollectionCreation; - LimitCollectionDeletion = organization.LimitCollectionDeletion; - LimitItemDeletion = organization.LimitItemDeletion; - AllowAdminAccessToAllCollectionItems = organization.AllowAdminAccessToAllCollectionItems; - UserIsClaimedByOrganization = organizationIdsClaimingUser.Contains(organization.OrganizationId); - UseRiskInsights = organization.UseRiskInsights; - UseOrganizationDomains = organization.UseOrganizationDomains; - UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies; - SsoEnabled = organization.SsoEnabled ?? false; - - if (organization.SsoConfig != null) - { - var ssoConfigData = SsoConfigurationData.Deserialize(organization.SsoConfig); - KeyConnectorEnabled = ssoConfigData.MemberDecryptionType == MemberDecryptionType.KeyConnector && !string.IsNullOrEmpty(ssoConfigData.KeyConnectorUrl); - KeyConnectorUrl = ssoConfigData.KeyConnectorUrl; - SsoMemberDecryptionType = ssoConfigData.MemberDecryptionType; - } + .UsersCanSponsor(organizationDetails); + AccessSecretsManager = organizationDetails.AccessSecretsManager; } - public Guid Id { get; set; } - [JsonConverter(typeof(HtmlEncodingStringConverter))] - public string Name { get; set; } - public bool UsePolicies { get; set; } - public bool UseSso { get; set; } - public bool UseKeyConnector { get; set; } - public bool UseScim { get; set; } - public bool UseGroups { get; set; } - public bool UseDirectory { get; set; } - public bool UseEvents { get; set; } - public bool UseTotp { get; set; } - public bool Use2fa { get; set; } - public bool UseApi { get; set; } - public bool UseResetPassword { get; set; } - public bool UseSecretsManager { get; set; } - public bool UsePasswordManager { get; set; } - public bool UsersGetPremium { get; set; } - public bool UseCustomPermissions { get; set; } - public bool UseActivateAutofillPolicy { get; set; } - public bool SelfHost { get; set; } - public int? Seats { get; set; } - public short? MaxCollections { get; set; } - public short? MaxStorageGb { get; set; } - public string Key { get; set; } - public OrganizationUserStatusType Status { get; set; } - public OrganizationUserType Type { get; set; } - public bool Enabled { get; set; } - public bool SsoBound { get; set; } - public string Identifier { get; set; } - public Permissions Permissions { get; set; } - public bool ResetPasswordEnrolled { get; set; } - public Guid? UserId { get; set; } public Guid OrganizationUserId { get; set; } - public bool HasPublicAndPrivateKeys { get; set; } - public Guid? ProviderId { get; set; } - [JsonConverter(typeof(HtmlEncodingStringConverter))] - public string ProviderName { get; set; } - public ProviderType? ProviderType { get; set; } - public string FamilySponsorshipFriendlyName { get; set; } + public bool UserIsClaimedByOrganization { get; set; } + public string? FamilySponsorshipFriendlyName { get; set; } public bool FamilySponsorshipAvailable { get; set; } - public ProductTierType ProductTierType { get; set; } - public bool KeyConnectorEnabled { get; set; } - public string KeyConnectorUrl { get; set; } public DateTime? FamilySponsorshipLastSyncDate { get; set; } public DateTime? FamilySponsorshipValidUntil { get; set; } public bool? FamilySponsorshipToDelete { get; set; } - public bool AccessSecretsManager { get; set; } - public bool LimitCollectionCreation { get; set; } - public bool LimitCollectionDeletion { get; set; } - public bool LimitItemDeletion { get; set; } - public bool AllowAdminAccessToAllCollectionItems { get; set; } + public bool IsAdminInitiated { get; set; } /// - /// Obsolete. - /// See + /// Obsolete property for backward compatibility /// [Obsolete("Please use UserIsClaimedByOrganization instead. This property will be removed in a future version.")] public bool UserIsManagedByOrganization @@ -150,18 +49,4 @@ public class ProfileOrganizationResponseModel : ResponseModel get => UserIsClaimedByOrganization; set => UserIsClaimedByOrganization = value; } - /// - /// Indicates if the user is claimed by the organization. - /// - /// - /// A user is claimed by an organization if the user's email domain is verified by the organization and the user is a member. - /// The organization must be enabled and able to have verified domains. - /// - public bool UserIsClaimedByOrganization { get; set; } - public bool UseRiskInsights { get; set; } - public bool UseOrganizationDomains { get; set; } - public bool UseAdminSponsoredFamilies { get; set; } - public bool IsAdminInitiated { get; set; } - public bool SsoEnabled { get; set; } - public MemberDecryptionType? SsoMemberDecryptionType { get; set; } } diff --git a/src/Api/AdminConsole/Models/Response/ProfileProviderOrganizationResponseModel.cs b/src/Api/AdminConsole/Models/Response/ProfileProviderOrganizationResponseModel.cs index 24b6fed704..fe31b8cb55 100644 --- a/src/Api/AdminConsole/Models/Response/ProfileProviderOrganizationResponseModel.cs +++ b/src/Api/AdminConsole/Models/Response/ProfileProviderOrganizationResponseModel.cs @@ -1,56 +1,24 @@ using Bit.Core.AdminConsole.Models.Data.Provider; -using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Extensions; using Bit.Core.Enums; using Bit.Core.Models.Data; namespace Bit.Api.AdminConsole.Models.Response; -public class ProfileProviderOrganizationResponseModel : ProfileOrganizationResponseModel +/// +/// Sync data for provider users and their managed organizations. +/// Note: see for organization sync data received by organization members. +/// +public class ProfileProviderOrganizationResponseModel : BaseProfileOrganizationResponseModel { - public ProfileProviderOrganizationResponseModel(ProviderUserOrganizationDetails organization) - : base("profileProviderOrganization") + public ProfileProviderOrganizationResponseModel(ProviderUserOrganizationDetails organizationDetails) + : base("profileProviderOrganization", organizationDetails) { - Id = organization.OrganizationId; - Name = organization.Name; - UsePolicies = organization.UsePolicies; - UseSso = organization.UseSso; - UseKeyConnector = organization.UseKeyConnector; - UseScim = organization.UseScim; - UseGroups = organization.UseGroups; - UseDirectory = organization.UseDirectory; - UseEvents = organization.UseEvents; - UseTotp = organization.UseTotp; - Use2fa = organization.Use2fa; - UseApi = organization.UseApi; - UseResetPassword = organization.UseResetPassword; - UsersGetPremium = organization.UsersGetPremium; - UseCustomPermissions = organization.UseCustomPermissions; - UseActivateAutofillPolicy = organization.PlanType.GetProductTier() == ProductTierType.Enterprise; - SelfHost = organization.SelfHost; - Seats = organization.Seats; - MaxCollections = organization.MaxCollections; - MaxStorageGb = organization.MaxStorageGb; - Key = organization.Key; - HasPublicAndPrivateKeys = organization.PublicKey != null && organization.PrivateKey != null; Status = OrganizationUserStatusType.Confirmed; // Provider users are always confirmed Type = OrganizationUserType.Owner; // Provider users behave like Owners - Enabled = organization.Enabled; - SsoBound = false; - Identifier = organization.Identifier; + ProviderId = organizationDetails.ProviderId; + ProviderName = organizationDetails.ProviderName; + ProviderType = organizationDetails.ProviderType; Permissions = new Permissions(); - ResetPasswordEnrolled = false; - UserId = organization.UserId; - ProviderId = organization.ProviderId; - ProviderName = organization.ProviderName; - ProviderType = organization.ProviderType; - ProductTierType = organization.PlanType.GetProductTier(); - LimitCollectionCreation = organization.LimitCollectionCreation; - LimitCollectionDeletion = organization.LimitCollectionDeletion; - LimitItemDeletion = organization.LimitItemDeletion; - AllowAdminAccessToAllCollectionItems = organization.AllowAdminAccessToAllCollectionItems; - UseRiskInsights = organization.UseRiskInsights; - UseOrganizationDomains = organization.UseOrganizationDomains; - UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies; + AccessSecretsManager = false; // Provider users cannot access Secrets Manager } } diff --git a/src/Api/AdminConsole/Public/Controllers/EventsController.cs b/src/Api/AdminConsole/Public/Controllers/EventsController.cs index 3dd55d51e2..19edbdd5a6 100644 --- a/src/Api/AdminConsole/Public/Controllers/EventsController.cs +++ b/src/Api/AdminConsole/Public/Controllers/EventsController.cs @@ -4,9 +4,11 @@ using System.Net; using Bit.Api.Models.Public.Request; using Bit.Api.Models.Public.Response; +using Bit.Api.Utilities.DiagnosticTools; using Bit.Core.Context; using Bit.Core.Models.Data; using Bit.Core.Repositories; +using Bit.Core.Services; using Bit.Core.Vault.Repositories; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; @@ -20,15 +22,21 @@ public class EventsController : Controller private readonly IEventRepository _eventRepository; private readonly ICipherRepository _cipherRepository; private readonly ICurrentContext _currentContext; + private readonly ILogger _logger; + private readonly IFeatureService _featureService; public EventsController( IEventRepository eventRepository, ICipherRepository cipherRepository, - ICurrentContext currentContext) + ICurrentContext currentContext, + ILogger logger, + IFeatureService featureService) { _eventRepository = eventRepository; _cipherRepository = cipherRepository; _currentContext = currentContext; + _logger = logger; + _featureService = featureService; } /// @@ -69,6 +77,8 @@ public class EventsController : Controller var eventResponses = result.Data.Select(e => new EventResponseModel(e)); var response = new PagedListResponseModel(eventResponses, result.ContinuationToken); + + _logger.LogAggregateData(_featureService, _currentContext.OrganizationId!.Value, response, request); return new JsonResult(response); } } diff --git a/src/Api/AdminConsole/Public/Controllers/MembersController.cs b/src/Api/AdminConsole/Public/Controllers/MembersController.cs index 7bfe5648b6..3b2e82121d 100644 --- a/src/Api/AdminConsole/Public/Controllers/MembersController.cs +++ b/src/Api/AdminConsole/Public/Controllers/MembersController.cs @@ -1,7 +1,4 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Net; +using System.Net; using Bit.Api.AdminConsole.Public.Models.Request; using Bit.Api.AdminConsole.Public.Models.Response; using Bit.Api.Models.Public.Response; @@ -24,11 +21,9 @@ public class MembersController : Controller private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IGroupRepository _groupRepository; private readonly IOrganizationService _organizationService; - private readonly IUserService _userService; private readonly ICurrentContext _currentContext; private readonly IUpdateOrganizationUserCommand _updateOrganizationUserCommand; private readonly IUpdateOrganizationUserGroupsCommand _updateOrganizationUserGroupsCommand; - private readonly IApplicationCacheService _applicationCacheService; private readonly IPaymentService _paymentService; private readonly IOrganizationRepository _organizationRepository; private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; @@ -39,11 +34,9 @@ public class MembersController : Controller IOrganizationUserRepository organizationUserRepository, IGroupRepository groupRepository, IOrganizationService organizationService, - IUserService userService, ICurrentContext currentContext, IUpdateOrganizationUserCommand updateOrganizationUserCommand, IUpdateOrganizationUserGroupsCommand updateOrganizationUserGroupsCommand, - IApplicationCacheService applicationCacheService, IPaymentService paymentService, IOrganizationRepository organizationRepository, ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, @@ -53,11 +46,9 @@ public class MembersController : Controller _organizationUserRepository = organizationUserRepository; _groupRepository = groupRepository; _organizationService = organizationService; - _userService = userService; _currentContext = currentContext; _updateOrganizationUserCommand = updateOrganizationUserCommand; _updateOrganizationUserGroupsCommand = updateOrganizationUserGroupsCommand; - _applicationCacheService = applicationCacheService; _paymentService = paymentService; _organizationRepository = organizationRepository; _twoFactorIsEnabledQuery = twoFactorIsEnabledQuery; @@ -115,19 +106,18 @@ public class MembersController : Controller /// /// /// Returns a list of your organization's members. - /// Member objects listed in this call do not include information about their associated collections. + /// Member objects listed in this call include information about their associated collections. /// [HttpGet] [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] public async Task List() { - var organizationUserUserDetails = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(_currentContext.OrganizationId.Value); - // TODO: Get all CollectionUser associations for the organization and marry them up here for the response. + var organizationUserUserDetails = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(_currentContext.OrganizationId!.Value, includeCollections: true); var orgUsersTwoFactorIsEnabled = await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(organizationUserUserDetails); var memberResponses = organizationUserUserDetails.Select(u => { - return new MemberResponseModel(u, orgUsersTwoFactorIsEnabled.FirstOrDefault(tuple => tuple.user == u).twoFactorIsEnabled, null); + return new MemberResponseModel(u, orgUsersTwoFactorIsEnabled.FirstOrDefault(tuple => tuple.user == u).twoFactorIsEnabled, u.Collections); }); var response = new ListResponseModel(memberResponses); return new JsonResult(response); @@ -158,7 +148,7 @@ public class MembersController : Controller invite.AccessSecretsManager = hasStandaloneSecretsManager; - var user = await _organizationService.InviteUserAsync(_currentContext.OrganizationId.Value, null, + var user = await _organizationService.InviteUserAsync(_currentContext.OrganizationId!.Value, null, systemUser: null, invite, model.ExternalId); var response = new MemberResponseModel(user, invite.Collections); return new JsonResult(response); @@ -188,12 +178,12 @@ public class MembersController : Controller var updatedUser = model.ToOrganizationUser(existingUser); var associations = model.Collections?.Select(c => c.ToCollectionAccessSelection()).ToList(); await _updateOrganizationUserCommand.UpdateUserAsync(updatedUser, existingUserType, null, associations, model.Groups); - MemberResponseModel response = null; + MemberResponseModel response; if (existingUser.UserId.HasValue) { var existingUserDetails = await _organizationUserRepository.GetDetailsByIdAsync(id); - response = new MemberResponseModel(existingUserDetails, - await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(existingUserDetails), associations); + response = new MemberResponseModel(existingUserDetails!, + await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(existingUserDetails!), associations); } else { @@ -242,7 +232,7 @@ public class MembersController : Controller { return new NotFoundResult(); } - await _removeOrganizationUserCommand.RemoveUserAsync(_currentContext.OrganizationId.Value, id, null); + await _removeOrganizationUserCommand.RemoveUserAsync(_currentContext.OrganizationId!.Value, id, null); return new OkResult(); } @@ -264,7 +254,7 @@ public class MembersController : Controller { return new NotFoundResult(); } - await _resendOrganizationInviteCommand.ResendInviteAsync(_currentContext.OrganizationId.Value, null, id); + await _resendOrganizationInviteCommand.ResendInviteAsync(_currentContext.OrganizationId!.Value, null, id); return new OkResult(); } } diff --git a/src/Api/AdminConsole/Public/Controllers/PoliciesController.cs b/src/Api/AdminConsole/Public/Controllers/PoliciesController.cs index 1caf9cb068..be0997f271 100644 --- a/src/Api/AdminConsole/Public/Controllers/PoliciesController.cs +++ b/src/Api/AdminConsole/Public/Controllers/PoliciesController.cs @@ -5,11 +5,15 @@ using System.Net; using Bit.Api.AdminConsole.Public.Models.Request; using Bit.Api.AdminConsole.Public.Models.Response; using Bit.Api.Models.Public.Response; +using Bit.Core; +using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; using Bit.Core.Context; +using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; @@ -22,18 +26,24 @@ public class PoliciesController : Controller private readonly IPolicyRepository _policyRepository; private readonly IPolicyService _policyService; private readonly ICurrentContext _currentContext; + private readonly IFeatureService _featureService; private readonly ISavePolicyCommand _savePolicyCommand; + private readonly IVNextSavePolicyCommand _vNextSavePolicyCommand; public PoliciesController( IPolicyRepository policyRepository, IPolicyService policyService, ICurrentContext currentContext, - ISavePolicyCommand savePolicyCommand) + IFeatureService featureService, + ISavePolicyCommand savePolicyCommand, + IVNextSavePolicyCommand vNextSavePolicyCommand) { _policyRepository = policyRepository; _policyService = policyService; _currentContext = currentContext; + _featureService = featureService; _savePolicyCommand = savePolicyCommand; + _vNextSavePolicyCommand = vNextSavePolicyCommand; } /// @@ -87,8 +97,17 @@ public class PoliciesController : Controller [ProducesResponseType((int)HttpStatusCode.NotFound)] public async Task Put(PolicyType type, [FromBody] PolicyUpdateRequestModel model) { - var policyUpdate = model.ToPolicyUpdate(_currentContext.OrganizationId!.Value, type); - var policy = await _savePolicyCommand.SaveAsync(policyUpdate); + Policy policy; + if (_featureService.IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor)) + { + var savePolicyModel = model.ToSavePolicyModel(_currentContext.OrganizationId!.Value, type); + policy = await _vNextSavePolicyCommand.SaveAsync(savePolicyModel); + } + else + { + var policyUpdate = model.ToPolicyUpdate(_currentContext.OrganizationId!.Value, type); + policy = await _savePolicyCommand.SaveAsync(policyUpdate); + } var response = new PolicyResponseModel(policy); return new JsonResult(response); diff --git a/src/Api/AdminConsole/Public/Models/Request/PolicyUpdateRequestModel.cs b/src/Api/AdminConsole/Public/Models/Request/PolicyUpdateRequestModel.cs index eb56690462..f81d9153b2 100644 --- a/src/Api/AdminConsole/Public/Models/Request/PolicyUpdateRequestModel.cs +++ b/src/Api/AdminConsole/Public/Models/Request/PolicyUpdateRequestModel.cs @@ -1,19 +1,44 @@ -using System.Text.Json; -using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.Utilities; using Bit.Core.Enums; namespace Bit.Api.AdminConsole.Public.Models.Request; public class PolicyUpdateRequestModel : PolicyBaseModel { - public PolicyUpdate ToPolicyUpdate(Guid organizationId, PolicyType type) => new() + public Dictionary? Metadata { get; set; } + + public PolicyUpdate ToPolicyUpdate(Guid organizationId, PolicyType type) { - Type = type, - OrganizationId = organizationId, - Data = Data != null ? JsonSerializer.Serialize(Data) : null, - Enabled = Enabled.GetValueOrDefault(), - PerformedBy = new SystemUser(EventSystemUser.PublicApi) - }; + var serializedData = PolicyDataValidator.ValidateAndSerialize(Data, type); + + return new() + { + Type = type, + OrganizationId = organizationId, + Data = serializedData, + Enabled = Enabled.GetValueOrDefault(), + PerformedBy = new SystemUser(EventSystemUser.PublicApi) + }; + } + + public SavePolicyModel ToSavePolicyModel(Guid organizationId, PolicyType type) + { + var serializedData = PolicyDataValidator.ValidateAndSerialize(Data, type); + + var policyUpdate = new PolicyUpdate + { + Type = type, + OrganizationId = organizationId, + Data = serializedData, + Enabled = Enabled.GetValueOrDefault() + }; + + var performedBy = new SystemUser(EventSystemUser.PublicApi); + var metadata = PolicyDataValidator.ValidateAndDeserializeMetadata(Metadata, type); + + return new SavePolicyModel(policyUpdate, performedBy, metadata); + } } diff --git a/src/Api/AdminConsole/Public/Models/Response/AssociationWithPermissionsResponseModel.cs b/src/Api/AdminConsole/Public/Models/Response/AssociationWithPermissionsResponseModel.cs index e319ead8a4..5ff12a2201 100644 --- a/src/Api/AdminConsole/Public/Models/Response/AssociationWithPermissionsResponseModel.cs +++ b/src/Api/AdminConsole/Public/Models/Response/AssociationWithPermissionsResponseModel.cs @@ -1,9 +1,15 @@ -using Bit.Core.Models.Data; +using System.Text.Json.Serialization; +using Bit.Core.Models.Data; namespace Bit.Api.AdminConsole.Public.Models.Response; public class AssociationWithPermissionsResponseModel : AssociationWithPermissionsBaseModel { + [JsonConstructor] + public AssociationWithPermissionsResponseModel() : base() + { + } + public AssociationWithPermissionsResponseModel(CollectionAccessSelection selection) { if (selection == null) diff --git a/src/Api/Auth/Controllers/AccountsController.cs b/src/Api/Auth/Controllers/AccountsController.cs index 19165a5a1c..ecf49c18c8 100644 --- a/src/Api/Auth/Controllers/AccountsController.cs +++ b/src/Api/Auth/Controllers/AccountsController.cs @@ -18,6 +18,7 @@ using Bit.Core.Auth.UserFeatures.UserMasterPassword.Interfaces; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.KeyManagement.Kdf; +using Bit.Core.KeyManagement.Queries.Interfaces; using Bit.Core.Models.Api.Response; using Bit.Core.Repositories; using Bit.Core.Services; @@ -40,6 +41,7 @@ public class AccountsController : Controller private readonly ITdeOffboardingPasswordCommand _tdeOffboardingPasswordCommand; private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; private readonly IFeatureService _featureService; + private readonly IUserAccountKeysQuery _userAccountKeysQuery; private readonly ITwoFactorEmailService _twoFactorEmailService; private readonly IChangeKdfCommand _changeKdfCommand; @@ -53,6 +55,7 @@ public class AccountsController : Controller ITdeOffboardingPasswordCommand tdeOffboardingPasswordCommand, ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, IFeatureService featureService, + IUserAccountKeysQuery userAccountKeysQuery, ITwoFactorEmailService twoFactorEmailService, IChangeKdfCommand changeKdfCommand ) @@ -66,6 +69,7 @@ public class AccountsController : Controller _tdeOffboardingPasswordCommand = tdeOffboardingPasswordCommand; _twoFactorIsEnabledQuery = twoFactorIsEnabledQuery; _featureService = featureService; + _userAccountKeysQuery = userAccountKeysQuery; _twoFactorEmailService = twoFactorEmailService; _changeKdfCommand = changeKdfCommand; } @@ -332,7 +336,9 @@ public class AccountsController : Controller var hasPremiumFromOrg = await _userService.HasPremiumFromOrganization(user); var organizationIdsClaimingActiveUser = await GetOrganizationIdsClaimingUserAsync(user.Id); - var response = new ProfileResponseModel(user, organizationUserDetails, providerUserDetails, + var accountKeys = await _userAccountKeysQuery.Run(user); + + var response = new ProfileResponseModel(user, accountKeys, organizationUserDetails, providerUserDetails, providerUserOrganizationDetails, twoFactorEnabled, hasPremiumFromOrg, organizationIdsClaimingActiveUser); return response; @@ -364,8 +370,9 @@ public class AccountsController : Controller var twoFactorEnabled = await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(user); var hasPremiumFromOrg = await _userService.HasPremiumFromOrganization(user); var organizationIdsClaimingActiveUser = await GetOrganizationIdsClaimingUserAsync(user.Id); + var userAccountKeys = await _userAccountKeysQuery.Run(user); - var response = new ProfileResponseModel(user, null, null, null, twoFactorEnabled, hasPremiumFromOrg, organizationIdsClaimingActiveUser); + var response = new ProfileResponseModel(user, userAccountKeys, null, null, null, twoFactorEnabled, hasPremiumFromOrg, organizationIdsClaimingActiveUser); return response; } @@ -389,8 +396,9 @@ public class AccountsController : Controller var userTwoFactorEnabled = await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(user); var userHasPremiumFromOrganization = await _userService.HasPremiumFromOrganization(user); var organizationIdsClaimingActiveUser = await GetOrganizationIdsClaimingUserAsync(user.Id); + var accountKeys = await _userAccountKeysQuery.Run(user); - var response = new ProfileResponseModel(user, null, null, null, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationIdsClaimingActiveUser); + var response = new ProfileResponseModel(user, accountKeys, null, null, null, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationIdsClaimingActiveUser); return response; } diff --git a/src/Api/Auth/Controllers/AuthRequestsController.cs b/src/Api/Auth/Controllers/AuthRequestsController.cs index 4da3a2f491..e9dfe17c94 100644 --- a/src/Api/Auth/Controllers/AuthRequestsController.cs +++ b/src/Api/Auth/Controllers/AuthRequestsController.cs @@ -3,7 +3,6 @@ using Bit.Api.Auth.Models.Response; using Bit.Api.Models.Response; -using Bit.Core; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Identity; using Bit.Core.Auth.Models.Api.Request.AuthRequest; @@ -12,7 +11,6 @@ using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; -using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; @@ -55,7 +53,6 @@ public class AuthRequestsController( } [HttpGet("pending")] - [RequireFeature(FeatureFlagKeys.BrowserExtensionLoginApproval)] public async Task> GetPendingAuthRequestsAsync() { var userId = _userService.GetProperUserId(User).Value; diff --git a/src/Api/Auth/Models/Request/OrganizationSsoRequestModel.cs b/src/Api/Auth/Models/Request/OrganizationSsoRequestModel.cs index fcf386d7ee..349bdebb88 100644 --- a/src/Api/Auth/Models/Request/OrganizationSsoRequestModel.cs +++ b/src/Api/Auth/Models/Request/OrganizationSsoRequestModel.cs @@ -121,7 +121,7 @@ public class SsoConfigurationDataRequest : IValidatableObject new[] { nameof(IdpEntityId) }); } - if (!Uri.IsWellFormedUriString(IdpEntityId, UriKind.Absolute) && string.IsNullOrWhiteSpace(IdpSingleSignOnServiceUrl)) + if (string.IsNullOrWhiteSpace(IdpSingleSignOnServiceUrl)) { yield return new ValidationResult(i18nService.GetLocalizedHtmlString("IdpSingleSignOnServiceUrlValidationError"), new[] { nameof(IdpSingleSignOnServiceUrl) }); @@ -139,6 +139,7 @@ public class SsoConfigurationDataRequest : IValidatableObject new[] { nameof(IdpSingleLogoutServiceUrl) }); } + // TODO: On server, make public certificate required for SAML2 SSO: https://bitwarden.atlassian.net/browse/PM-26028 if (!string.IsNullOrWhiteSpace(IdpX509PublicCert)) { // Validate the certificate is in a valid format diff --git a/src/Api/Billing/Attributes/NonTokenizedPaymentMethodTypeValidationAttribute.cs b/src/Api/Billing/Attributes/NonTokenizedPaymentMethodTypeValidationAttribute.cs new file mode 100644 index 0000000000..7a906d4838 --- /dev/null +++ b/src/Api/Billing/Attributes/NonTokenizedPaymentMethodTypeValidationAttribute.cs @@ -0,0 +1,13 @@ +using Bit.Api.Utilities; + +namespace Bit.Api.Billing.Attributes; + +public class NonTokenizedPaymentMethodTypeValidationAttribute : StringMatchesAttribute +{ + private static readonly string[] _acceptedValues = ["accountCredit"]; + + public NonTokenizedPaymentMethodTypeValidationAttribute() : base(_acceptedValues) + { + ErrorMessage = $"Payment method type must be one of: {string.Join(", ", _acceptedValues)}"; + } +} diff --git a/src/Api/Billing/Attributes/PaymentMethodTypeValidationAttribute.cs b/src/Api/Billing/Attributes/TokenizedPaymentMethodTypeValidationAttribute.cs similarity index 62% rename from src/Api/Billing/Attributes/PaymentMethodTypeValidationAttribute.cs rename to src/Api/Billing/Attributes/TokenizedPaymentMethodTypeValidationAttribute.cs index 227b454f9f..51e40e9999 100644 --- a/src/Api/Billing/Attributes/PaymentMethodTypeValidationAttribute.cs +++ b/src/Api/Billing/Attributes/TokenizedPaymentMethodTypeValidationAttribute.cs @@ -2,11 +2,11 @@ namespace Bit.Api.Billing.Attributes; -public class PaymentMethodTypeValidationAttribute : StringMatchesAttribute +public class TokenizedPaymentMethodTypeValidationAttribute : StringMatchesAttribute { private static readonly string[] _acceptedValues = ["bankAccount", "card", "payPal"]; - public PaymentMethodTypeValidationAttribute() : base(_acceptedValues) + public TokenizedPaymentMethodTypeValidationAttribute() : base(_acceptedValues) { ErrorMessage = $"Payment method type must be one of: {string.Join(", ", _acceptedValues)}"; } diff --git a/src/Api/Billing/Controllers/AccountsController.cs b/src/Api/Billing/Controllers/AccountsController.cs index 9411d454aa..075218dd74 100644 --- a/src/Api/Billing/Controllers/AccountsController.cs +++ b/src/Api/Billing/Controllers/AccountsController.cs @@ -1,13 +1,16 @@ #nullable enable + using Bit.Api.Models.Request; using Bit.Api.Models.Request.Accounts; using Bit.Api.Models.Response; using Bit.Api.Utilities; +using Bit.Core; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Business; using Bit.Core.Billing.Services; using Bit.Core.Exceptions; +using Bit.Core.KeyManagement.Queries.Interfaces; using Bit.Core.Models.Business; using Bit.Core.Services; using Bit.Core.Settings; @@ -21,7 +24,9 @@ namespace Bit.Api.Billing.Controllers; [Authorize("Application")] public class AccountsController( IUserService userService, - ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery) : Controller + ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, + IUserAccountKeysQuery userAccountKeysQuery, + IFeatureService featureService) : Controller { [HttpPost("premium")] public async Task PostPremiumAsync( @@ -58,8 +63,9 @@ public class AccountsController( var userTwoFactorEnabled = await twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(user); var userHasPremiumFromOrganization = await userService.HasPremiumFromOrganization(user); var organizationIdsClaimingActiveUser = await GetOrganizationIdsClaimingUserAsync(user.Id); + var accountKeys = await userAccountKeysQuery.Run(user); - var profile = new ProfileResponseModel(user, null, null, null, userTwoFactorEnabled, + var profile = new ProfileResponseModel(user, accountKeys, null, null, null, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationIdsClaimingActiveUser); return new PaymentResponseModel { @@ -80,16 +86,24 @@ public class AccountsController( throw new UnauthorizedAccessException(); } - if (!globalSettings.SelfHosted && user.Gateway != null) + // Only cloud-hosted users with payment gateways have subscription and discount information + if (!globalSettings.SelfHosted) { - var subscriptionInfo = await paymentService.GetSubscriptionAsync(user); - var license = await userService.GenerateLicenseAsync(user, subscriptionInfo); - return new SubscriptionResponseModel(user, subscriptionInfo, license); - } - else if (!globalSettings.SelfHosted) - { - var license = await userService.GenerateLicenseAsync(user); - return new SubscriptionResponseModel(user, license); + if (user.Gateway != null) + { + // Note: PM23341_Milestone_2 is the feature flag for the overall Milestone 2 initiative (PM-23341). + // This specific implementation (PM-26682) adds discount display functionality as part of that initiative. + // The feature flag controls the broader Milestone 2 feature set, not just this specific task. + var includeMilestone2Discount = featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2); + var subscriptionInfo = await paymentService.GetSubscriptionAsync(user); + var license = await userService.GenerateLicenseAsync(user, subscriptionInfo); + return new SubscriptionResponseModel(user, subscriptionInfo, license, includeMilestone2Discount); + } + else + { + var license = await userService.GenerateLicenseAsync(user); + return new SubscriptionResponseModel(user, license); + } } else { diff --git a/src/Api/Billing/Controllers/OrganizationBillingController.cs b/src/Api/Billing/Controllers/OrganizationBillingController.cs index 21b17bff67..6e4cacc155 100644 --- a/src/Api/Billing/Controllers/OrganizationBillingController.cs +++ b/src/Api/Billing/Controllers/OrganizationBillingController.cs @@ -1,16 +1,8 @@ -#nullable enable -using System.Diagnostics; -using Bit.Api.AdminConsole.Models.Request.Organizations; -using Bit.Api.Billing.Models.Requests; +using Bit.Api.Billing.Models.Requests; using Bit.Api.Billing.Models.Responses; -using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models; -using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Organizations.Services; -using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Providers.Services; using Bit.Core.Billing.Services; -using Bit.Core.Billing.Tax.Models; using Bit.Core.Context; using Bit.Core.Repositories; using Bit.Core.Services; @@ -28,10 +20,8 @@ public class OrganizationBillingController( IOrganizationBillingService organizationBillingService, IOrganizationRepository organizationRepository, IPaymentService paymentService, - IPricingClient pricingClient, ISubscriberService subscriberService, - IPaymentHistoryService paymentHistoryService, - IUserService userService) : BaseBillingController + IPaymentHistoryService paymentHistoryService) : BaseBillingController { [HttpGet("metadata")] public async Task GetMetadataAsync([FromRoute] Guid organizationId) @@ -48,9 +38,7 @@ public class OrganizationBillingController( return Error.NotFound(); } - var response = OrganizationMetadataResponse.From(metadata); - - return TypedResults.Ok(response); + return TypedResults.Ok(metadata); } [HttpGet("history")] @@ -264,71 +252,6 @@ public class OrganizationBillingController( return TypedResults.Ok(); } - [HttpPost("restart-subscription")] - public async Task RestartSubscriptionAsync([FromRoute] Guid organizationId, - [FromBody] OrganizationCreateRequestModel model) - { - var user = await userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - if (!await currentContext.EditPaymentMethods(organizationId)) - { - return Error.Unauthorized(); - } - - var organization = await organizationRepository.GetByIdAsync(organizationId); - if (organization == null) - { - return Error.NotFound(); - } - var existingPlan = organization.PlanType; - var organizationSignup = model.ToOrganizationSignup(user); - var sale = OrganizationSale.From(organization, organizationSignup); - var plan = await pricingClient.GetPlanOrThrow(model.PlanType); - sale.Organization.PlanType = plan.Type; - sale.Organization.Plan = plan.Name; - sale.SubscriptionSetup.SkipTrial = true; - if (existingPlan == PlanType.Free && organization.GatewaySubscriptionId is not null) - { - sale.Organization.UseTotp = plan.HasTotp; - sale.Organization.UseGroups = plan.HasGroups; - sale.Organization.UseDirectory = plan.HasDirectory; - sale.Organization.SelfHost = plan.HasSelfHost; - sale.Organization.UsersGetPremium = plan.UsersGetPremium; - sale.Organization.UseEvents = plan.HasEvents; - sale.Organization.Use2fa = plan.Has2fa; - sale.Organization.UseApi = plan.HasApi; - sale.Organization.UsePolicies = plan.HasPolicies; - sale.Organization.UseSso = plan.HasSso; - sale.Organization.UseResetPassword = plan.HasResetPassword; - sale.Organization.UseKeyConnector = plan.HasKeyConnector ? organization.UseKeyConnector : false; - sale.Organization.UseScim = plan.HasScim; - sale.Organization.UseCustomPermissions = plan.HasCustomPermissions; - sale.Organization.UseOrganizationDomains = plan.HasOrganizationDomains; - sale.Organization.MaxCollections = plan.PasswordManager.MaxCollections; - } - - if (organizationSignup.PaymentMethodType == null || string.IsNullOrEmpty(organizationSignup.PaymentToken)) - { - return Error.BadRequest("A payment method is required to restart the subscription."); - } - var org = await organizationRepository.GetByIdAsync(organizationId); - Debug.Assert(org is not null, "This organization has already been found via this same ID, this should be fine."); - var paymentSource = new TokenizedPaymentSource(organizationSignup.PaymentMethodType.Value, organizationSignup.PaymentToken); - var taxInformation = TaxInformation.From(organizationSignup.TaxInfo); - await organizationBillingService.Finalize(sale); - var updatedOrg = await organizationRepository.GetByIdAsync(organizationId); - if (updatedOrg != null) - { - await organizationBillingService.UpdatePaymentMethod(updatedOrg, paymentSource, taxInformation); - } - - return TypedResults.Ok(); - } - [HttpPost("setup-business-unit")] [SelfHosted(NotSelfHostedOnly = true)] public async Task SetupBusinessUnitAsync( diff --git a/src/Api/Billing/Controllers/OrganizationSponsorshipsController.cs b/src/Api/Billing/Controllers/OrganizationSponsorshipsController.cs index 8c202752de..7ca85d52a8 100644 --- a/src/Api/Billing/Controllers/OrganizationSponsorshipsController.cs +++ b/src/Api/Billing/Controllers/OrganizationSponsorshipsController.cs @@ -89,19 +89,6 @@ public class OrganizationSponsorshipsController : Controller throw new BadRequestException("Free Bitwarden Families sponsorship has been disabled by your organization administrator."); } - if (!_featureService.IsEnabled(Bit.Core.FeatureFlagKeys.PM17772_AdminInitiatedSponsorships)) - { - if (model.IsAdminInitiated.GetValueOrDefault()) - { - throw new BadRequestException(); - } - - if (!string.IsNullOrWhiteSpace(model.Notes)) - { - model.Notes = null; - } - } - var sponsorship = await _createSponsorshipCommand.CreateSponsorshipAsync( sponsoringOrg, await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default), diff --git a/src/Api/Controllers/PlansController.cs b/src/Api/Billing/Controllers/PlansController.cs similarity index 66% rename from src/Api/Controllers/PlansController.cs rename to src/Api/Billing/Controllers/PlansController.cs index 11b070fb66..f9b5274780 100644 --- a/src/Api/Controllers/PlansController.cs +++ b/src/Api/Billing/Controllers/PlansController.cs @@ -3,10 +3,10 @@ using Bit.Core.Billing.Pricing; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -namespace Bit.Api.Controllers; +namespace Bit.Api.Billing.Controllers; [Route("plans")] -[Authorize("Web")] +[Authorize("Application")] public class PlansController( IPricingClient pricingClient) : Controller { @@ -18,4 +18,11 @@ public class PlansController( var responses = plans.Select(plan => new PlanResponseModel(plan)); return new ListResponseModel(responses); } + + [HttpGet("premium")] + public async Task GetPremiumPlanAsync() + { + var premiumPlan = await pricingClient.GetAvailablePremiumPlan(); + return TypedResults.Ok(premiumPlan); + } } diff --git a/src/Api/Billing/Controllers/ProviderBillingController.cs b/src/Api/Billing/Controllers/ProviderBillingController.cs index f7d0593812..006a7ce068 100644 --- a/src/Api/Billing/Controllers/ProviderBillingController.cs +++ b/src/Api/Billing/Controllers/ProviderBillingController.cs @@ -132,7 +132,7 @@ public class ProviderBillingController( } var subscription = await stripeAdapter.SubscriptionGetAsync(provider.GatewaySubscriptionId, - new SubscriptionGetOptions { Expand = ["customer.tax_ids", "test_clock"] }); + new SubscriptionGetOptions { Expand = ["customer.tax_ids", "discounts", "test_clock"] }); var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); diff --git a/src/Api/Billing/Controllers/TaxController.cs b/src/Api/Billing/Controllers/TaxController.cs index d2c1c36726..4ead414589 100644 --- a/src/Api/Billing/Controllers/TaxController.cs +++ b/src/Api/Billing/Controllers/TaxController.cs @@ -1,33 +1,73 @@ -using Bit.Api.Billing.Models.Requests; -using Bit.Core.Billing.Tax.Commands; +using Bit.Api.Billing.Attributes; +using Bit.Api.Billing.Models.Requests.Tax; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Organizations.Commands; +using Bit.Core.Billing.Premium.Commands; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.ModelBinding; namespace Bit.Api.Billing.Controllers; [Authorize("Application")] -[Route("tax")] +[Route("billing/tax")] public class TaxController( - IPreviewTaxAmountCommand previewTaxAmountCommand) : BaseBillingController + IPreviewOrganizationTaxCommand previewOrganizationTaxCommand, + IPreviewPremiumTaxCommand previewPremiumTaxCommand) : BaseBillingController { - [HttpPost("preview-amount/organization-trial")] - public async Task PreviewTaxAmountForOrganizationTrialAsync( - [FromBody] PreviewTaxAmountForOrganizationTrialRequestBody requestBody) + [HttpPost("organizations/subscriptions/purchase")] + public async Task PreviewOrganizationSubscriptionPurchaseTaxAsync( + [FromBody] PreviewOrganizationSubscriptionPurchaseTaxRequest request) { - var parameters = new OrganizationTrialParameters + var (purchase, billingAddress) = request.ToDomain(); + var result = await previewOrganizationTaxCommand.Run(purchase, billingAddress); + return Handle(result.Map(pair => new { - PlanType = requestBody.PlanType, - ProductType = requestBody.ProductType, - TaxInformation = new OrganizationTrialParameters.TaxInformationDTO - { - Country = requestBody.TaxInformation.Country, - PostalCode = requestBody.TaxInformation.PostalCode, - TaxId = requestBody.TaxInformation.TaxId - } - }; + pair.Tax, + pair.Total + })); + } - var result = await previewTaxAmountCommand.Run(parameters); + [HttpPost("organizations/{organizationId:guid}/subscription/plan-change")] + [InjectOrganization] + public async Task PreviewOrganizationSubscriptionPlanChangeTaxAsync( + [BindNever] Organization organization, + [FromBody] PreviewOrganizationSubscriptionPlanChangeTaxRequest request) + { + var (planChange, billingAddress) = request.ToDomain(); + var result = await previewOrganizationTaxCommand.Run(organization, planChange, billingAddress); + return Handle(result.Map(pair => new + { + pair.Tax, + pair.Total + })); + } - return Handle(result); + [HttpPut("organizations/{organizationId:guid}/subscription/update")] + [InjectOrganization] + public async Task PreviewOrganizationSubscriptionUpdateTaxAsync( + [BindNever] Organization organization, + [FromBody] PreviewOrganizationSubscriptionUpdateTaxRequest request) + { + var update = request.ToDomain(); + var result = await previewOrganizationTaxCommand.Run(organization, update); + return Handle(result.Map(pair => new + { + pair.Tax, + pair.Total + })); + } + + [HttpPost("premium/subscriptions/purchase")] + public async Task PreviewPremiumSubscriptionPurchaseTaxAsync( + [FromBody] PreviewPremiumSubscriptionPurchaseTaxRequest request) + { + var (purchase, billingAddress) = request.ToDomain(); + var result = await previewPremiumTaxCommand.Run(purchase, billingAddress); + return Handle(result.Map(pair => new + { + pair.Tax, + pair.Total + })); } } diff --git a/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs b/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs index a996290507..b01b629e4f 100644 --- a/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs +++ b/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs @@ -1,5 +1,4 @@ -#nullable enable -using Bit.Api.Billing.Attributes; +using Bit.Api.Billing.Attributes; using Bit.Api.Billing.Models.Requests.Payment; using Bit.Api.Billing.Models.Requests.Premium; using Bit.Core; @@ -67,7 +66,7 @@ public class AccountBillingVNextController( } [HttpPost("subscription")] - [RequireFeature(FeatureFlagKeys.PM23385_UseNewPremiumFlow)] + [RequireFeature(FeatureFlagKeys.PM24996ImplementUpgradeFromFreeDialog)] [InjectUser] public async Task CreateSubscriptionAsync( [BindNever] User user, diff --git a/src/Api/Billing/Controllers/VNext/OrganizationBillingVNextController.cs b/src/Api/Billing/Controllers/VNext/OrganizationBillingVNextController.cs index ee98031dbc..64ec068a5e 100644 --- a/src/Api/Billing/Controllers/VNext/OrganizationBillingVNextController.cs +++ b/src/Api/Billing/Controllers/VNext/OrganizationBillingVNextController.cs @@ -2,11 +2,15 @@ using Bit.Api.AdminConsole.Authorization.Requirements; using Bit.Api.Billing.Attributes; using Bit.Api.Billing.Models.Requests.Payment; +using Bit.Api.Billing.Models.Requests.Subscriptions; using Bit.Api.Billing.Models.Requirements; +using Bit.Core; using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Commands; using Bit.Core.Billing.Organizations.Queries; using Bit.Core.Billing.Payment.Commands; using Bit.Core.Billing.Payment.Queries; +using Bit.Core.Billing.Subscriptions.Commands; using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; @@ -22,8 +26,10 @@ public class OrganizationBillingVNextController( ICreateBitPayInvoiceForCreditCommand createBitPayInvoiceForCreditCommand, IGetBillingAddressQuery getBillingAddressQuery, IGetCreditQuery getCreditQuery, + IGetOrganizationMetadataQuery getOrganizationMetadataQuery, IGetOrganizationWarningsQuery getOrganizationWarningsQuery, IGetPaymentMethodQuery getPaymentMethodQuery, + IRestartSubscriptionCommand restartSubscriptionCommand, IUpdateBillingAddressCommand updateBillingAddressCommand, IUpdatePaymentMethodCommand updatePaymentMethodCommand) : BaseBillingController { @@ -95,6 +101,37 @@ public class OrganizationBillingVNextController( return Handle(result); } + [Authorize] + [HttpPost("subscription/restart")] + [InjectOrganization] + public async Task RestartSubscriptionAsync( + [BindNever] Organization organization, + [FromBody] RestartSubscriptionRequest request) + { + var (paymentMethod, billingAddress) = request.ToDomain(); + var result = await updatePaymentMethodCommand.Run(organization, paymentMethod, null) + .AndThenAsync(_ => updateBillingAddressCommand.Run(organization, billingAddress)) + .AndThenAsync(_ => restartSubscriptionCommand.Run(organization)); + return Handle(result); + } + + [Authorize] + [HttpGet("metadata")] + [RequireFeature(FeatureFlagKeys.PM25379_UseNewOrganizationMetadataStructure)] + [InjectOrganization] + public async Task GetMetadataAsync( + [BindNever] Organization organization) + { + var metadata = await getOrganizationMetadataQuery.Run(organization); + + if (metadata == null) + { + return TypedResults.NotFound(); + } + + return TypedResults.Ok(metadata); + } + [Authorize] [HttpGet("warnings")] [InjectOrganization] diff --git a/src/Api/Billing/Controllers/VNext/SelfHostedAccountBillingController.cs b/src/Api/Billing/Controllers/VNext/SelfHostedAccountBillingController.cs index 544753ad0f..973a7d99a1 100644 --- a/src/Api/Billing/Controllers/VNext/SelfHostedAccountBillingController.cs +++ b/src/Api/Billing/Controllers/VNext/SelfHostedAccountBillingController.cs @@ -21,7 +21,7 @@ public class SelfHostedAccountBillingController( ICreatePremiumSelfHostedSubscriptionCommand createPremiumSelfHostedSubscriptionCommand) : BaseBillingController { [HttpPost("license")] - [RequireFeature(FeatureFlagKeys.PM23385_UseNewPremiumFlow)] + [RequireFeature(FeatureFlagKeys.PM24996ImplementUpgradeFromFreeDialog)] [InjectUser] public async Task UploadLicenseAsync( [BindNever] User user, diff --git a/src/Api/Billing/Models/Requests/Organizations/OrganizationSubscriptionPlanChangeRequest.cs b/src/Api/Billing/Models/Requests/Organizations/OrganizationSubscriptionPlanChangeRequest.cs new file mode 100644 index 0000000000..a3856bf173 --- /dev/null +++ b/src/Api/Billing/Models/Requests/Organizations/OrganizationSubscriptionPlanChangeRequest.cs @@ -0,0 +1,31 @@ +using System.ComponentModel.DataAnnotations; +using System.Text.Json.Serialization; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Organizations.Models; + +namespace Bit.Api.Billing.Models.Requests.Organizations; + +public record OrganizationSubscriptionPlanChangeRequest : IValidatableObject +{ + [Required] + [JsonConverter(typeof(JsonStringEnumConverter))] + public ProductTierType Tier { get; set; } + + [Required] + [JsonConverter(typeof(JsonStringEnumConverter))] + public PlanCadenceType Cadence { get; set; } + + public OrganizationSubscriptionPlanChange ToDomain() => new() + { + Tier = Tier, + Cadence = Cadence + }; + + public IEnumerable Validate(ValidationContext validationContext) + { + if (Tier == ProductTierType.Families && Cadence == PlanCadenceType.Monthly) + { + yield return new ValidationResult("Monthly billing cadence is not available for the Families plan."); + } + } +} diff --git a/src/Api/Billing/Models/Requests/Organizations/OrganizationSubscriptionPurchaseRequest.cs b/src/Api/Billing/Models/Requests/Organizations/OrganizationSubscriptionPurchaseRequest.cs new file mode 100644 index 0000000000..c678b1966c --- /dev/null +++ b/src/Api/Billing/Models/Requests/Organizations/OrganizationSubscriptionPurchaseRequest.cs @@ -0,0 +1,84 @@ +using System.ComponentModel.DataAnnotations; +using System.Text.Json.Serialization; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Organizations.Models; + +namespace Bit.Api.Billing.Models.Requests.Organizations; + +public record OrganizationSubscriptionPurchaseRequest : IValidatableObject +{ + [Required] + [JsonConverter(typeof(JsonStringEnumConverter))] + public ProductTierType Tier { get; set; } + + [Required] + [JsonConverter(typeof(JsonStringEnumConverter))] + public PlanCadenceType Cadence { get; set; } + + [Required] + public required PasswordManagerPurchaseSelections PasswordManager { get; set; } + + public SecretsManagerPurchaseSelections? SecretsManager { get; set; } + + public OrganizationSubscriptionPurchase ToDomain() => new() + { + Tier = Tier, + Cadence = Cadence, + PasswordManager = new OrganizationSubscriptionPurchase.PasswordManagerSelections + { + Seats = PasswordManager.Seats, + AdditionalStorage = PasswordManager.AdditionalStorage, + Sponsored = PasswordManager.Sponsored + }, + SecretsManager = SecretsManager != null ? new OrganizationSubscriptionPurchase.SecretsManagerSelections + { + Seats = SecretsManager.Seats, + AdditionalServiceAccounts = SecretsManager.AdditionalServiceAccounts, + Standalone = SecretsManager.Standalone + } : null + }; + + public IEnumerable Validate(ValidationContext validationContext) + { + if (Tier != ProductTierType.Families) + { + yield break; + } + + if (Cadence == PlanCadenceType.Monthly) + { + yield return new ValidationResult("Monthly cadence is not available on the Families plan."); + } + + if (SecretsManager != null) + { + yield return new ValidationResult("Secrets Manager is not available on the Families plan."); + } + } + + public record PasswordManagerPurchaseSelections + { + [Required] + [Range(1, 100000, ErrorMessage = "Password Manager seats must be between 1 and 100,000")] + public int Seats { get; set; } + + [Required] + [Range(0, 99, ErrorMessage = "Additional storage must be between 0 and 99 GB")] + public int AdditionalStorage { get; set; } + + public bool Sponsored { get; set; } = false; + } + + public record SecretsManagerPurchaseSelections + { + [Required] + [Range(1, 100000, ErrorMessage = "Secrets Manager seats must be between 1 and 100,000")] + public int Seats { get; set; } + + [Required] + [Range(0, 100000, ErrorMessage = "Additional service accounts must be between 0 and 100,000")] + public int AdditionalServiceAccounts { get; set; } + + public bool Standalone { get; set; } = false; + } +} diff --git a/src/Api/Billing/Models/Requests/Organizations/OrganizationSubscriptionUpdateRequest.cs b/src/Api/Billing/Models/Requests/Organizations/OrganizationSubscriptionUpdateRequest.cs new file mode 100644 index 0000000000..ad5c3bd609 --- /dev/null +++ b/src/Api/Billing/Models/Requests/Organizations/OrganizationSubscriptionUpdateRequest.cs @@ -0,0 +1,48 @@ +using System.ComponentModel.DataAnnotations; +using Bit.Core.Billing.Organizations.Models; + +namespace Bit.Api.Billing.Models.Requests.Organizations; + +public record OrganizationSubscriptionUpdateRequest +{ + public PasswordManagerUpdateSelections? PasswordManager { get; set; } + public SecretsManagerUpdateSelections? SecretsManager { get; set; } + + public OrganizationSubscriptionUpdate ToDomain() => new() + { + PasswordManager = + PasswordManager != null + ? new OrganizationSubscriptionUpdate.PasswordManagerSelections + { + Seats = PasswordManager.Seats, + AdditionalStorage = PasswordManager.AdditionalStorage + } + : null, + SecretsManager = + SecretsManager != null + ? new OrganizationSubscriptionUpdate.SecretsManagerSelections + { + Seats = SecretsManager.Seats, + AdditionalServiceAccounts = SecretsManager.AdditionalServiceAccounts + } + : null + }; + + public record PasswordManagerUpdateSelections + { + [Range(1, 100000, ErrorMessage = "Password Manager seats must be between 1 and 100,000")] + public int? Seats { get; set; } + + [Range(0, 99, ErrorMessage = "Additional storage must be between 0 and 99 GB")] + public int? AdditionalStorage { get; set; } + } + + public record SecretsManagerUpdateSelections + { + [Range(0, 100000, ErrorMessage = "Secrets Manager seats must be between 0 and 100,000")] + public int? Seats { get; set; } + + [Range(0, 100000, ErrorMessage = "Additional service accounts must be between 0 and 100,000")] + public int? AdditionalServiceAccounts { get; set; } + } +} diff --git a/src/Api/Billing/Models/Requests/Payment/BillingAddressRequest.cs b/src/Api/Billing/Models/Requests/Payment/BillingAddressRequest.cs index 5c3c47f585..0426a51f10 100644 --- a/src/Api/Billing/Models/Requests/Payment/BillingAddressRequest.cs +++ b/src/Api/Billing/Models/Requests/Payment/BillingAddressRequest.cs @@ -1,5 +1,4 @@ -#nullable enable -using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Payment.Models; namespace Bit.Api.Billing.Models.Requests.Payment; diff --git a/src/Api/Billing/Models/Requests/Payment/BitPayCreditRequest.cs b/src/Api/Billing/Models/Requests/Payment/BitPayCreditRequest.cs index bb6e7498d7..ec1405c566 100644 --- a/src/Api/Billing/Models/Requests/Payment/BitPayCreditRequest.cs +++ b/src/Api/Billing/Models/Requests/Payment/BitPayCreditRequest.cs @@ -1,5 +1,4 @@ -#nullable enable -using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations; namespace Bit.Api.Billing.Models.Requests.Payment; diff --git a/src/Api/Billing/Models/Requests/Payment/CheckoutBillingAddressRequest.cs b/src/Api/Billing/Models/Requests/Payment/CheckoutBillingAddressRequest.cs index 54116e897d..ccf2b30b50 100644 --- a/src/Api/Billing/Models/Requests/Payment/CheckoutBillingAddressRequest.cs +++ b/src/Api/Billing/Models/Requests/Payment/CheckoutBillingAddressRequest.cs @@ -1,5 +1,4 @@ -#nullable enable -using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations; using Bit.Core.Billing.Payment.Models; namespace Bit.Api.Billing.Models.Requests.Payment; diff --git a/src/Api/Billing/Models/Requests/Payment/MinimalBillingAddressRequest.cs b/src/Api/Billing/Models/Requests/Payment/MinimalBillingAddressRequest.cs index b4d28017d5..29c10e6631 100644 --- a/src/Api/Billing/Models/Requests/Payment/MinimalBillingAddressRequest.cs +++ b/src/Api/Billing/Models/Requests/Payment/MinimalBillingAddressRequest.cs @@ -1,5 +1,4 @@ -#nullable enable -using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations; using Bit.Core.Billing.Payment.Models; namespace Bit.Api.Billing.Models.Requests.Payment; diff --git a/src/Api/Billing/Models/Requests/Payment/MinimalTokenizedPaymentMethodRequest.cs b/src/Api/Billing/Models/Requests/Payment/MinimalTokenizedPaymentMethodRequest.cs index 3b50d2bf63..1311805ad4 100644 --- a/src/Api/Billing/Models/Requests/Payment/MinimalTokenizedPaymentMethodRequest.cs +++ b/src/Api/Billing/Models/Requests/Payment/MinimalTokenizedPaymentMethodRequest.cs @@ -1,5 +1,4 @@ -#nullable enable -using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations; using Bit.Api.Billing.Attributes; using Bit.Core.Billing.Payment.Models; @@ -8,18 +7,15 @@ namespace Bit.Api.Billing.Models.Requests.Payment; public class MinimalTokenizedPaymentMethodRequest { [Required] - [PaymentMethodTypeValidation] + [TokenizedPaymentMethodTypeValidation] public required string Type { get; set; } [Required] public required string Token { get; set; } - public TokenizedPaymentMethod ToDomain() + public TokenizedPaymentMethod ToDomain() => new() { - return new TokenizedPaymentMethod - { - Type = TokenizablePaymentMethodTypeExtensions.From(Type), - Token = Token - }; - } + Type = TokenizablePaymentMethodTypeExtensions.From(Type), + Token = Token + }; } diff --git a/src/Api/Billing/Models/Requests/Payment/NonTokenizedPaymentMethodRequest.cs b/src/Api/Billing/Models/Requests/Payment/NonTokenizedPaymentMethodRequest.cs new file mode 100644 index 0000000000..d15bc73778 --- /dev/null +++ b/src/Api/Billing/Models/Requests/Payment/NonTokenizedPaymentMethodRequest.cs @@ -0,0 +1,21 @@ +using System.ComponentModel.DataAnnotations; +using Bit.Api.Billing.Attributes; +using Bit.Core.Billing.Payment.Models; + +namespace Bit.Api.Billing.Models.Requests.Payment; + +public class NonTokenizedPaymentMethodRequest +{ + [Required] + [NonTokenizedPaymentMethodTypeValidation] + public required string Type { get; set; } + + public NonTokenizedPaymentMethod ToDomain() + { + return Type switch + { + "accountCredit" => new NonTokenizedPaymentMethod { Type = NonTokenizablePaymentMethodType.AccountCredit }, + _ => throw new InvalidOperationException($"Invalid value for {nameof(NonTokenizedPaymentMethod)}.{nameof(NonTokenizedPaymentMethod.Type)}") + }; + } +} diff --git a/src/Api/Billing/Models/Requests/Payment/TokenizedPaymentMethodRequest.cs b/src/Api/Billing/Models/Requests/Payment/TokenizedPaymentMethodRequest.cs index f540957a1a..2a54313421 100644 --- a/src/Api/Billing/Models/Requests/Payment/TokenizedPaymentMethodRequest.cs +++ b/src/Api/Billing/Models/Requests/Payment/TokenizedPaymentMethodRequest.cs @@ -1,31 +1,15 @@ -#nullable enable -using System.ComponentModel.DataAnnotations; -using Bit.Api.Billing.Attributes; -using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Payment.Models; namespace Bit.Api.Billing.Models.Requests.Payment; -public class TokenizedPaymentMethodRequest +public class TokenizedPaymentMethodRequest : MinimalTokenizedPaymentMethodRequest { - [Required] - [PaymentMethodTypeValidation] - public required string Type { get; set; } - - [Required] - public required string Token { get; set; } - public MinimalBillingAddressRequest? BillingAddress { get; set; } - public (TokenizedPaymentMethod, BillingAddress?) ToDomain() + public new (TokenizedPaymentMethod, BillingAddress?) ToDomain() { - var paymentMethod = new TokenizedPaymentMethod - { - Type = TokenizablePaymentMethodTypeExtensions.From(Type), - Token = Token - }; - + var paymentMethod = base.ToDomain(); var billingAddress = BillingAddress?.ToDomain(); - return (paymentMethod, billingAddress); } } diff --git a/src/Api/Billing/Models/Requests/Premium/PremiumCloudHostedSubscriptionRequest.cs b/src/Api/Billing/Models/Requests/Premium/PremiumCloudHostedSubscriptionRequest.cs index b958057f5b..0f9198fdad 100644 --- a/src/Api/Billing/Models/Requests/Premium/PremiumCloudHostedSubscriptionRequest.cs +++ b/src/Api/Billing/Models/Requests/Premium/PremiumCloudHostedSubscriptionRequest.cs @@ -1,14 +1,13 @@ -#nullable enable -using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations; using Bit.Api.Billing.Models.Requests.Payment; using Bit.Core.Billing.Payment.Models; namespace Bit.Api.Billing.Models.Requests.Premium; -public class PremiumCloudHostedSubscriptionRequest +public class PremiumCloudHostedSubscriptionRequest : IValidatableObject { - [Required] - public required MinimalTokenizedPaymentMethodRequest TokenizedPaymentMethod { get; set; } + public MinimalTokenizedPaymentMethodRequest? TokenizedPaymentMethod { get; set; } + public NonTokenizedPaymentMethodRequest? NonTokenizedPaymentMethod { get; set; } [Required] public required MinimalBillingAddressRequest BillingAddress { get; set; } @@ -16,11 +15,38 @@ public class PremiumCloudHostedSubscriptionRequest [Range(0, 99)] public short AdditionalStorageGb { get; set; } = 0; - public (TokenizedPaymentMethod, BillingAddress, short) ToDomain() + + public (PaymentMethod, BillingAddress, short) ToDomain() { - var paymentMethod = TokenizedPaymentMethod.ToDomain(); + // Check if TokenizedPaymentMethod or NonTokenizedPaymentMethod is provided. + var tokenizedPaymentMethod = TokenizedPaymentMethod?.ToDomain(); + var nonTokenizedPaymentMethod = NonTokenizedPaymentMethod?.ToDomain(); + + PaymentMethod paymentMethod = tokenizedPaymentMethod != null + ? tokenizedPaymentMethod + : nonTokenizedPaymentMethod!; + var billingAddress = BillingAddress.ToDomain(); return (paymentMethod, billingAddress, AdditionalStorageGb); } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (TokenizedPaymentMethod == null && NonTokenizedPaymentMethod == null) + { + yield return new ValidationResult( + "Either TokenizedPaymentMethod or NonTokenizedPaymentMethod must be provided.", + new[] { nameof(TokenizedPaymentMethod), nameof(NonTokenizedPaymentMethod) } + ); + } + + if (TokenizedPaymentMethod != null && NonTokenizedPaymentMethod != null) + { + yield return new ValidationResult( + "Only one of TokenizedPaymentMethod or NonTokenizedPaymentMethod can be provided.", + new[] { nameof(TokenizedPaymentMethod), nameof(NonTokenizedPaymentMethod) } + ); + } + } } diff --git a/src/Api/Billing/Models/Requests/PreviewTaxAmountForOrganizationTrialRequestBody.cs b/src/Api/Billing/Models/Requests/PreviewTaxAmountForOrganizationTrialRequestBody.cs deleted file mode 100644 index a3fda0fd6c..0000000000 --- a/src/Api/Billing/Models/Requests/PreviewTaxAmountForOrganizationTrialRequestBody.cs +++ /dev/null @@ -1,27 +0,0 @@ -#nullable enable -using System.ComponentModel.DataAnnotations; -using Bit.Core.Billing.Enums; - -namespace Bit.Api.Billing.Models.Requests; - -public class PreviewTaxAmountForOrganizationTrialRequestBody -{ - [Required] - public PlanType PlanType { get; set; } - - [Required] - public ProductType ProductType { get; set; } - - [Required] public TaxInformationDTO TaxInformation { get; set; } = null!; - - public class TaxInformationDTO - { - [Required] - public string Country { get; set; } = null!; - - [Required] - public string PostalCode { get; set; } = null!; - - public string? TaxId { get; set; } - } -} diff --git a/src/Api/Billing/Models/Requests/Subscriptions/RestartSubscriptionRequest.cs b/src/Api/Billing/Models/Requests/Subscriptions/RestartSubscriptionRequest.cs new file mode 100644 index 0000000000..ac66270427 --- /dev/null +++ b/src/Api/Billing/Models/Requests/Subscriptions/RestartSubscriptionRequest.cs @@ -0,0 +1,16 @@ +using System.ComponentModel.DataAnnotations; +using Bit.Api.Billing.Models.Requests.Payment; +using Bit.Core.Billing.Payment.Models; + +namespace Bit.Api.Billing.Models.Requests.Subscriptions; + +public class RestartSubscriptionRequest +{ + [Required] + public required MinimalTokenizedPaymentMethodRequest PaymentMethod { get; set; } + [Required] + public required CheckoutBillingAddressRequest BillingAddress { get; set; } + + public (TokenizedPaymentMethod, BillingAddress) ToDomain() + => (PaymentMethod.ToDomain(), BillingAddress.ToDomain()); +} diff --git a/src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionPlanChangeTaxRequest.cs b/src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionPlanChangeTaxRequest.cs new file mode 100644 index 0000000000..9233a53c85 --- /dev/null +++ b/src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionPlanChangeTaxRequest.cs @@ -0,0 +1,19 @@ +using System.ComponentModel.DataAnnotations; +using Bit.Api.Billing.Models.Requests.Organizations; +using Bit.Api.Billing.Models.Requests.Payment; +using Bit.Core.Billing.Organizations.Models; +using Bit.Core.Billing.Payment.Models; + +namespace Bit.Api.Billing.Models.Requests.Tax; + +public record PreviewOrganizationSubscriptionPlanChangeTaxRequest +{ + [Required] + public required OrganizationSubscriptionPlanChangeRequest Plan { get; set; } + + [Required] + public required CheckoutBillingAddressRequest BillingAddress { get; set; } + + public (OrganizationSubscriptionPlanChange, BillingAddress) ToDomain() => + (Plan.ToDomain(), BillingAddress.ToDomain()); +} diff --git a/src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionPurchaseTaxRequest.cs b/src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionPurchaseTaxRequest.cs new file mode 100644 index 0000000000..dcc5911f3d --- /dev/null +++ b/src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionPurchaseTaxRequest.cs @@ -0,0 +1,19 @@ +using System.ComponentModel.DataAnnotations; +using Bit.Api.Billing.Models.Requests.Organizations; +using Bit.Api.Billing.Models.Requests.Payment; +using Bit.Core.Billing.Organizations.Models; +using Bit.Core.Billing.Payment.Models; + +namespace Bit.Api.Billing.Models.Requests.Tax; + +public record PreviewOrganizationSubscriptionPurchaseTaxRequest +{ + [Required] + public required OrganizationSubscriptionPurchaseRequest Purchase { get; set; } + + [Required] + public required CheckoutBillingAddressRequest BillingAddress { get; set; } + + public (OrganizationSubscriptionPurchase, BillingAddress) ToDomain() => + (Purchase.ToDomain(), BillingAddress.ToDomain()); +} diff --git a/src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionUpdateTaxRequest.cs b/src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionUpdateTaxRequest.cs new file mode 100644 index 0000000000..ae96214ae3 --- /dev/null +++ b/src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionUpdateTaxRequest.cs @@ -0,0 +1,11 @@ +using Bit.Api.Billing.Models.Requests.Organizations; +using Bit.Core.Billing.Organizations.Models; + +namespace Bit.Api.Billing.Models.Requests.Tax; + +public class PreviewOrganizationSubscriptionUpdateTaxRequest +{ + public required OrganizationSubscriptionUpdateRequest Update { get; set; } + + public OrganizationSubscriptionUpdate ToDomain() => Update.ToDomain(); +} diff --git a/src/Api/Billing/Models/Requests/Tax/PreviewPremiumSubscriptionPurchaseTaxRequest.cs b/src/Api/Billing/Models/Requests/Tax/PreviewPremiumSubscriptionPurchaseTaxRequest.cs new file mode 100644 index 0000000000..76b8a5a444 --- /dev/null +++ b/src/Api/Billing/Models/Requests/Tax/PreviewPremiumSubscriptionPurchaseTaxRequest.cs @@ -0,0 +1,17 @@ +using System.ComponentModel.DataAnnotations; +using Bit.Api.Billing.Models.Requests.Payment; +using Bit.Core.Billing.Payment.Models; + +namespace Bit.Api.Billing.Models.Requests.Tax; + +public record PreviewPremiumSubscriptionPurchaseTaxRequest +{ + [Required] + [Range(0, 99, ErrorMessage = "Additional storage must be between 0 and 99 GB.")] + public short AdditionalStorage { get; set; } + + [Required] + public required MinimalBillingAddressRequest BillingAddress { get; set; } + + public (short, BillingAddress) ToDomain() => (AdditionalStorage, BillingAddress.ToDomain()); +} diff --git a/src/Api/Billing/Models/Responses/OrganizationMetadataResponse.cs b/src/Api/Billing/Models/Responses/OrganizationMetadataResponse.cs deleted file mode 100644 index a13f267c3b..0000000000 --- a/src/Api/Billing/Models/Responses/OrganizationMetadataResponse.cs +++ /dev/null @@ -1,31 +0,0 @@ -using Bit.Core.Billing.Organizations.Models; - -namespace Bit.Api.Billing.Models.Responses; - -public record OrganizationMetadataResponse( - bool IsEligibleForSelfHost, - bool IsManaged, - bool IsOnSecretsManagerStandalone, - bool IsSubscriptionUnpaid, - bool HasSubscription, - bool HasOpenInvoice, - bool IsSubscriptionCanceled, - DateTime? InvoiceDueDate, - DateTime? InvoiceCreatedDate, - DateTime? SubPeriodEndDate, - int OrganizationOccupiedSeats) -{ - public static OrganizationMetadataResponse From(OrganizationMetadata metadata) - => new( - metadata.IsEligibleForSelfHost, - metadata.IsManaged, - metadata.IsOnSecretsManagerStandalone, - metadata.IsSubscriptionUnpaid, - metadata.HasSubscription, - metadata.HasOpenInvoice, - metadata.IsSubscriptionCanceled, - metadata.InvoiceDueDate, - metadata.InvoiceCreatedDate, - metadata.SubPeriodEndDate, - metadata.OrganizationOccupiedSeats); -} diff --git a/src/Api/Billing/Models/Responses/ProviderSubscriptionResponse.cs b/src/Api/Billing/Models/Responses/ProviderSubscriptionResponse.cs index e5b868af9a..4b78127240 100644 --- a/src/Api/Billing/Models/Responses/ProviderSubscriptionResponse.cs +++ b/src/Api/Billing/Models/Responses/ProviderSubscriptionResponse.cs @@ -1,6 +1,7 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; using Bit.Core.Billing.Providers.Models; using Bit.Core.Billing.Tax.Models; @@ -10,7 +11,7 @@ namespace Bit.Api.Billing.Models.Responses; public record ProviderSubscriptionResponse( string Status, - DateTime CurrentPeriodEndDate, + DateTime? CurrentPeriodEndDate, decimal? DiscountPercentage, string CollectionMethod, IEnumerable Plans, @@ -51,10 +52,12 @@ public record ProviderSubscriptionResponse( var accountCredit = Convert.ToDecimal(subscription.Customer?.Balance) * -1 / 100; + var discount = subscription.Customer?.Discount ?? subscription.Discounts?.FirstOrDefault(); + return new ProviderSubscriptionResponse( subscription.Status, - subscription.CurrentPeriodEnd, - subscription.Customer?.Discount?.Coupon?.PercentOff, + subscription.GetCurrentPeriodEnd(), + discount?.Coupon?.PercentOff, subscription.CollectionMethod, providerPlanResponses, accountCredit, diff --git a/src/Api/Controllers/MiscController.cs b/src/Api/Controllers/MiscController.cs deleted file mode 100644 index 6f23a27fbf..0000000000 --- a/src/Api/Controllers/MiscController.cs +++ /dev/null @@ -1,45 +0,0 @@ -using Bit.Api.Models.Request; -using Bit.Core.Settings; -using Bit.Core.Utilities; -using Microsoft.AspNetCore.Authorization; -using Microsoft.AspNetCore.Mvc; -using Stripe; - -namespace Bit.Api.Controllers; - -public class MiscController : Controller -{ - private readonly BitPayClient _bitPayClient; - private readonly GlobalSettings _globalSettings; - - public MiscController( - BitPayClient bitPayClient, - GlobalSettings globalSettings) - { - _bitPayClient = bitPayClient; - _globalSettings = globalSettings; - } - - [Authorize("Application")] - [HttpPost("~/bitpay-invoice")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostBitPayInvoice([FromBody] BitPayInvoiceRequestModel model) - { - var invoice = await _bitPayClient.CreateInvoiceAsync(model.ToBitpayInvoice(_globalSettings)); - return invoice.Url; - } - - [Authorize("Application")] - [HttpPost("~/setup-payment")] - [SelfHosted(NotSelfHostedOnly = true)] - public async Task PostSetupPayment() - { - var options = new SetupIntentCreateOptions - { - Usage = "off_session" - }; - var service = new SetupIntentService(); - var setupIntent = await service.CreateAsync(options); - return setupIntent.ClientSecret; - } -} diff --git a/src/Api/Controllers/SelfHosted/SelfHostedOrganizationSponsorshipsController.cs b/src/Api/Controllers/SelfHosted/SelfHostedOrganizationSponsorshipsController.cs index 198438201c..6865bc06da 100644 --- a/src/Api/Controllers/SelfHosted/SelfHostedOrganizationSponsorshipsController.cs +++ b/src/Api/Controllers/SelfHosted/SelfHostedOrganizationSponsorshipsController.cs @@ -55,19 +55,6 @@ public class SelfHostedOrganizationSponsorshipsController : Controller [HttpPost("{sponsoringOrgId}/families-for-enterprise")] public async Task CreateSponsorship(Guid sponsoringOrgId, [FromBody] OrganizationSponsorshipCreateRequestModel model) { - if (!_featureService.IsEnabled(Bit.Core.FeatureFlagKeys.PM17772_AdminInitiatedSponsorships)) - { - if (model.IsAdminInitiated.GetValueOrDefault()) - { - throw new BadRequestException(); - } - - if (!string.IsNullOrWhiteSpace(model.Notes)) - { - model.Notes = null; - } - } - await _offerSponsorshipCommand.CreateSponsorshipAsync( await _organizationRepository.GetByIdAsync(sponsoringOrgId), await _organizationUserRepository.GetByOrganizationAsync(sponsoringOrgId, _currentContext.UserId ?? default), diff --git a/src/Api/Controllers/UsersController.cs b/src/Api/Controllers/UsersController.cs deleted file mode 100644 index 4dfd047d37..0000000000 --- a/src/Api/Controllers/UsersController.cs +++ /dev/null @@ -1,33 +0,0 @@ -using Bit.Api.Models.Response; -using Bit.Core.Exceptions; -using Bit.Core.Repositories; -using Microsoft.AspNetCore.Authorization; -using Microsoft.AspNetCore.Mvc; - -namespace Bit.Api.Controllers; - -[Route("users")] -[Authorize("Application")] -public class UsersController : Controller -{ - private readonly IUserRepository _userRepository; - - public UsersController( - IUserRepository userRepository) - { - _userRepository = userRepository; - } - - [HttpGet("{id}/public-key")] - public async Task Get(string id) - { - var guidId = new Guid(id); - var key = await _userRepository.GetPublicKeyAsync(guidId); - if (key == null) - { - throw new NotFoundException(); - } - - return new UserKeyResponseModel(guidId, key); - } -} diff --git a/src/Api/Dirt/Controllers/OrganizationReportsController.cs b/src/Api/Dirt/Controllers/OrganizationReportsController.cs index bcd64b0bdf..fc9a1b2d84 100644 --- a/src/Api/Dirt/Controllers/OrganizationReportsController.cs +++ b/src/Api/Dirt/Controllers/OrganizationReportsController.cs @@ -1,4 +1,5 @@ -using Bit.Core.Context; +using Bit.Api.Dirt.Models.Response; +using Bit.Core.Context; using Bit.Core.Dirt.Reports.ReportFeatures.Interfaces; using Bit.Core.Dirt.Reports.ReportFeatures.Requests; using Bit.Core.Exceptions; @@ -61,8 +62,9 @@ public class OrganizationReportsController : Controller } var latestReport = await _getOrganizationReportQuery.GetLatestOrganizationReportAsync(organizationId); + var response = latestReport == null ? null : new OrganizationReportResponseModel(latestReport); - return Ok(latestReport); + return Ok(response); } [HttpGet("{organizationId}/{reportId}")] @@ -102,7 +104,8 @@ public class OrganizationReportsController : Controller } var report = await _addOrganizationReportCommand.AddOrganizationReportAsync(request); - return Ok(report); + var response = report == null ? null : new OrganizationReportResponseModel(report); + return Ok(response); } [HttpPatch("{organizationId}/{reportId}")] @@ -119,7 +122,8 @@ public class OrganizationReportsController : Controller } var updatedReport = await _updateOrganizationReportCommand.UpdateOrganizationReportAsync(request); - return Ok(updatedReport); + var response = new OrganizationReportResponseModel(updatedReport); + return Ok(response); } #endregion @@ -182,10 +186,10 @@ public class OrganizationReportsController : Controller { throw new BadRequestException("Report ID in the request body must match the route parameter"); } - var updatedReport = await _updateOrganizationReportSummaryCommand.UpdateOrganizationReportSummaryAsync(request); + var response = new OrganizationReportResponseModel(updatedReport); - return Ok(updatedReport); + return Ok(response); } #endregion @@ -228,7 +232,9 @@ public class OrganizationReportsController : Controller } var updatedReport = await _updateOrganizationReportDataCommand.UpdateOrganizationReportDataAsync(request); - return Ok(updatedReport); + var response = new OrganizationReportResponseModel(updatedReport); + + return Ok(response); } #endregion @@ -265,7 +271,6 @@ public class OrganizationReportsController : Controller { try { - if (!await _currentContext.AccessReports(organizationId)) { throw new NotFoundException(); @@ -282,10 +287,9 @@ public class OrganizationReportsController : Controller } var updatedReport = await _updateOrganizationReportApplicationDataCommand.UpdateOrganizationReportApplicationDataAsync(request); + var response = new OrganizationReportResponseModel(updatedReport); - - - return Ok(updatedReport); + return Ok(response); } catch (Exception ex) when (!(ex is BadRequestException || ex is NotFoundException)) { diff --git a/src/Api/Dirt/Models/Response/OrganizationReportResponseModel.cs b/src/Api/Dirt/Models/Response/OrganizationReportResponseModel.cs new file mode 100644 index 0000000000..e477e5b806 --- /dev/null +++ b/src/Api/Dirt/Models/Response/OrganizationReportResponseModel.cs @@ -0,0 +1,38 @@ +using Bit.Core.Dirt.Entities; + +namespace Bit.Api.Dirt.Models.Response; + +public class OrganizationReportResponseModel +{ + public Guid Id { get; set; } + public Guid OrganizationId { get; set; } + public string? ReportData { get; set; } + public string? ContentEncryptionKey { get; set; } + public string? SummaryData { get; set; } + public string? ApplicationData { get; set; } + public int? PasswordCount { get; set; } + public int? PasswordAtRiskCount { get; set; } + public int? MemberCount { get; set; } + public DateTime? CreationDate { get; set; } = null; + public DateTime? RevisionDate { get; set; } = null; + + public OrganizationReportResponseModel(OrganizationReport organizationReport) + { + if (organizationReport == null) + { + return; + } + + Id = organizationReport.Id; + OrganizationId = organizationReport.OrganizationId; + ReportData = organizationReport.ReportData; + ContentEncryptionKey = organizationReport.ContentEncryptionKey; + SummaryData = organizationReport.SummaryData; + ApplicationData = organizationReport.ApplicationData; + PasswordCount = organizationReport.PasswordCount; + PasswordAtRiskCount = organizationReport.PasswordAtRiskCount; + MemberCount = organizationReport.MemberCount; + CreationDate = organizationReport.CreationDate; + RevisionDate = organizationReport.RevisionDate; + } +} diff --git a/src/Api/KeyManagement/Controllers/AccountsKeyManagementController.cs b/src/Api/KeyManagement/Controllers/AccountsKeyManagementController.cs index 9fc0e9a75a..7968970048 100644 --- a/src/Api/KeyManagement/Controllers/AccountsKeyManagementController.cs +++ b/src/Api/KeyManagement/Controllers/AccountsKeyManagementController.cs @@ -106,8 +106,7 @@ public class AccountsKeyManagementController : Controller { OldMasterKeyAuthenticationHash = model.OldMasterKeyAuthenticationHash, - UserKeyEncryptedAccountPrivateKey = model.AccountKeys.UserKeyEncryptedAccountPrivateKey, - AccountPublicKey = model.AccountKeys.AccountPublicKey, + AccountKeys = model.AccountKeys.ToAccountKeysData(), MasterPasswordUnlockData = model.AccountUnlockData.MasterPasswordUnlockData.ToUnlockData(), EmergencyAccesses = await _emergencyAccessValidator.ValidateAsync(user, model.AccountUnlockData.EmergencyAccessUnlockData), diff --git a/src/Api/KeyManagement/Controllers/UsersController.cs b/src/Api/KeyManagement/Controllers/UsersController.cs new file mode 100644 index 0000000000..cfd2f8ee29 --- /dev/null +++ b/src/Api/KeyManagement/Controllers/UsersController.cs @@ -0,0 +1,39 @@ +using Bit.Core.Exceptions; +using Bit.Core.KeyManagement.Models.Api.Response; +using Bit.Core.KeyManagement.Queries.Interfaces; +using Bit.Core.Repositories; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; +using UserKeyResponseModel = Bit.Api.Models.Response.UserKeyResponseModel; + + +namespace Bit.Api.KeyManagement.Controllers; + +[Route("users")] +[Authorize("Application")] +public class UsersController : Controller +{ + private readonly IUserRepository _userRepository; + private readonly IUserAccountKeysQuery _userAccountKeysQuery; + + public UsersController(IUserRepository userRepository, IUserAccountKeysQuery userAccountKeysQuery) + { + _userRepository = userRepository; + _userAccountKeysQuery = userAccountKeysQuery; + } + + [HttpGet("{id}/public-key")] + public async Task GetPublicKeyAsync([FromRoute] Guid id) + { + var key = await _userRepository.GetPublicKeyAsync(id) ?? throw new NotFoundException(); + return new UserKeyResponseModel(id, key); + } + + [HttpGet("{id}/keys")] + public async Task GetAccountKeysAsync([FromRoute] Guid id) + { + var user = await _userRepository.GetByIdAsync(id) ?? throw new NotFoundException(); + var accountKeys = await _userAccountKeysQuery.Run(user) ?? throw new NotFoundException("User account keys not found."); + return new PublicKeysResponseModel(accountKeys); + } +} diff --git a/src/Api/KeyManagement/Models/Requests/AccountKeysRequestModel.cs b/src/Api/KeyManagement/Models/Requests/AccountKeysRequestModel.cs index 7c7de4d210..b64e826911 100644 --- a/src/Api/KeyManagement/Models/Requests/AccountKeysRequestModel.cs +++ b/src/Api/KeyManagement/Models/Requests/AccountKeysRequestModel.cs @@ -1,4 +1,5 @@ -#nullable enable +using Bit.Core.KeyManagement.Models.Api.Request; +using Bit.Core.KeyManagement.Models.Data; using Bit.Core.Utilities; namespace Bit.Api.KeyManagement.Models.Requests; @@ -7,4 +8,44 @@ public class AccountKeysRequestModel { [EncryptedString] public required string UserKeyEncryptedAccountPrivateKey { get; set; } public required string AccountPublicKey { get; set; } + + public PublicKeyEncryptionKeyPairRequestModel? PublicKeyEncryptionKeyPair { get; set; } + public SignatureKeyPairRequestModel? SignatureKeyPair { get; set; } + public SecurityStateModel? SecurityState { get; set; } + + public UserAccountKeysData ToAccountKeysData() + { + // This will be cleaned up, after a compatibility period, at which point PublicKeyEncryptionKeyPair and SignatureKeyPair will be required. + // TODO: https://bitwarden.atlassian.net/browse/PM-23751 + if (PublicKeyEncryptionKeyPair == null) + { + return new UserAccountKeysData + { + PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData + ( + UserKeyEncryptedAccountPrivateKey, + AccountPublicKey + ), + }; + } + else + { + if (SignatureKeyPair == null || SecurityState == null) + { + return new UserAccountKeysData + { + PublicKeyEncryptionKeyPairData = PublicKeyEncryptionKeyPair.ToPublicKeyEncryptionKeyPairData(), + }; + } + else + { + return new UserAccountKeysData + { + PublicKeyEncryptionKeyPairData = PublicKeyEncryptionKeyPair.ToPublicKeyEncryptionKeyPairData(), + SignatureKeyPairData = SignatureKeyPair.ToSignatureKeyPairData(), + SecurityStateData = SecurityState.ToSecurityState() + }; + } + } + } } diff --git a/src/Api/KeyManagement/Models/Requests/KeyRegenerationRequestModel.cs b/src/Api/KeyManagement/Models/Requests/KeyRegenerationRequestModel.cs index 495d13cccd..767cfd3f9b 100644 --- a/src/Api/KeyManagement/Models/Requests/KeyRegenerationRequestModel.cs +++ b/src/Api/KeyManagement/Models/Requests/KeyRegenerationRequestModel.cs @@ -1,5 +1,4 @@ -#nullable enable -using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Models.Data; using Bit.Core.Utilities; namespace Bit.Api.KeyManagement.Models.Requests; diff --git a/src/Api/KeyManagement/Models/Requests/PublicKeyEncryptionKeyPairRequestModel.cs b/src/Api/KeyManagement/Models/Requests/PublicKeyEncryptionKeyPairRequestModel.cs new file mode 100644 index 0000000000..24c1e6a946 --- /dev/null +++ b/src/Api/KeyManagement/Models/Requests/PublicKeyEncryptionKeyPairRequestModel.cs @@ -0,0 +1,20 @@ +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.Utilities; + +namespace Bit.Api.KeyManagement.Models.Requests; + +public class PublicKeyEncryptionKeyPairRequestModel +{ + [EncryptedString] public required string WrappedPrivateKey { get; set; } + public required string PublicKey { get; set; } + public string? SignedPublicKey { get; set; } + + public PublicKeyEncryptionKeyPairData ToPublicKeyEncryptionKeyPairData() + { + return new PublicKeyEncryptionKeyPairData( + WrappedPrivateKey, + PublicKey, + SignedPublicKey + ); + } +} diff --git a/src/Api/KeyManagement/Models/Requests/RotateAccountKeysAndDataRequestModel.cs b/src/Api/KeyManagement/Models/Requests/RotateAccountKeysAndDataRequestModel.cs index b0b19e2bd3..02780b015a 100644 --- a/src/Api/KeyManagement/Models/Requests/RotateAccountKeysAndDataRequestModel.cs +++ b/src/Api/KeyManagement/Models/Requests/RotateAccountKeysAndDataRequestModel.cs @@ -1,5 +1,4 @@ -#nullable enable -using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations; namespace Bit.Api.KeyManagement.Models.Requests; diff --git a/src/Api/KeyManagement/Models/Requests/SignatureKeyPairRequestModel.cs b/src/Api/KeyManagement/Models/Requests/SignatureKeyPairRequestModel.cs new file mode 100644 index 0000000000..3cdb4f53f1 --- /dev/null +++ b/src/Api/KeyManagement/Models/Requests/SignatureKeyPairRequestModel.cs @@ -0,0 +1,28 @@ +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.Utilities; + +namespace Bit.Api.KeyManagement.Models.Requests; + +public class SignatureKeyPairRequestModel +{ + public required string SignatureAlgorithm { get; set; } + [EncryptedString] public required string WrappedSigningKey { get; set; } + public required string VerifyingKey { get; set; } + + public SignatureKeyPairData ToSignatureKeyPairData() + { + if (SignatureAlgorithm != "ed25519") + { + throw new ArgumentException( + $"Unsupported signature algorithm: {SignatureAlgorithm}" + ); + } + var algorithm = Core.KeyManagement.Enums.SignatureAlgorithm.Ed25519; + + return new SignatureKeyPairData( + algorithm, + WrappedSigningKey, + VerifyingKey + ); + } +} diff --git a/src/Api/KeyManagement/Models/Requests/UnlockDataRequestModel.cs b/src/Api/KeyManagement/Models/Requests/UnlockDataRequestModel.cs index 3af944110c..01e5dd7017 100644 --- a/src/Api/KeyManagement/Models/Requests/UnlockDataRequestModel.cs +++ b/src/Api/KeyManagement/Models/Requests/UnlockDataRequestModel.cs @@ -1,5 +1,4 @@ -#nullable enable -using Bit.Api.AdminConsole.Models.Request.Organizations; +using Bit.Api.AdminConsole.Models.Request.Organizations; using Bit.Api.Auth.Models.Request; using Bit.Api.Auth.Models.Request.Accounts; using Bit.Api.Auth.Models.Request.WebAuthn; diff --git a/src/Api/KeyManagement/Models/Requests/UserDataRequestModel.cs b/src/Api/KeyManagement/Models/Requests/UserDataRequestModel.cs index f854d82bcc..df922fcda0 100644 --- a/src/Api/KeyManagement/Models/Requests/UserDataRequestModel.cs +++ b/src/Api/KeyManagement/Models/Requests/UserDataRequestModel.cs @@ -1,5 +1,4 @@ -#nullable enable -using Bit.Api.Tools.Models.Request; +using Bit.Api.Tools.Models.Request; using Bit.Api.Vault.Models.Request; namespace Bit.Api.KeyManagement.Models.Requests; diff --git a/src/Api/KeyManagement/Validators/WebAuthnLoginKeyRotationValidator.cs b/src/Api/KeyManagement/Validators/WebAuthnLoginKeyRotationValidator.cs index 9c7efe0fbe..e92be11cd2 100644 --- a/src/Api/KeyManagement/Validators/WebAuthnLoginKeyRotationValidator.cs +++ b/src/Api/KeyManagement/Validators/WebAuthnLoginKeyRotationValidator.cs @@ -1,4 +1,5 @@ using Bit.Api.Auth.Models.Request.WebAuthn; +using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Repositories; using Bit.Core.Entities; @@ -6,7 +7,13 @@ using Bit.Core.Exceptions; namespace Bit.Api.KeyManagement.Validators; -public class WebAuthnLoginKeyRotationValidator : IRotationValidator, IEnumerable> +/// +/// Validates WebAuthn credentials during key rotation. Only processes credentials that have PRF enabled +/// and have encrypted user, public, and private keys. Ensures all such credentials are included +/// in the rotation request with the required encrypted keys. +/// +public class WebAuthnLoginKeyRotationValidator : IRotationValidator, + IEnumerable> { private readonly IWebAuthnCredentialRepository _webAuthnCredentialRepository; @@ -15,24 +22,20 @@ public class WebAuthnLoginKeyRotationValidator : IRotationValidator> ValidateAsync(User user, IEnumerable keysToRotate) + public async Task> ValidateAsync(User user, + IEnumerable keysToRotate) { var result = new List(); - var existing = await _webAuthnCredentialRepository.GetManyByUserIdAsync(user.Id); - if (existing == null) + var validCredentials = (await _webAuthnCredentialRepository.GetManyByUserIdAsync(user.Id)) + .Where(credential => credential.GetPrfStatus() == WebAuthnPrfStatus.Enabled).ToList(); + if (validCredentials.Count == 0) { return result; } - var validCredentials = existing.Where(credential => credential.SupportsPrf); - if (!validCredentials.Any()) + foreach (var webAuthnCredential in validCredentials) { - return result; - } - - foreach (var ea in validCredentials) - { - var keyToRotate = keysToRotate.FirstOrDefault(c => c.Id == ea.Id); + var keyToRotate = keysToRotate.FirstOrDefault(c => c.Id == webAuthnCredential.Id); if (keyToRotate == null) { throw new BadRequestException("All existing webauthn prf keys must be included in the rotation."); @@ -42,6 +45,7 @@ public class WebAuthnLoginKeyRotationValidator : IRotationValidator Validate(ValidationContext validationContext) - { - if (!UserId.HasValue && !OrganizationId.HasValue && !ProviderId.HasValue) - { - yield return new ValidationResult("User, Organization or Provider is required."); - } - } -} diff --git a/src/Api/Models/Request/Organizations/OrganizationVerifyBankRequestModel.cs b/src/Api/Models/Request/Organizations/OrganizationVerifyBankRequestModel.cs deleted file mode 100644 index 71f6873800..0000000000 --- a/src/Api/Models/Request/Organizations/OrganizationVerifyBankRequestModel.cs +++ /dev/null @@ -1,13 +0,0 @@ -using System.ComponentModel.DataAnnotations; - -namespace Bit.Api.Models.Request.Organizations; - -public class OrganizationVerifyBankRequestModel -{ - [Required] - [Range(1, 99)] - public int? Amount1 { get; set; } - [Required] - [Range(1, 99)] - public int? Amount2 { get; set; } -} diff --git a/src/Api/Models/Response/ProfileResponseModel.cs b/src/Api/Models/Response/ProfileResponseModel.cs index cbdfaf0f16..30ba05b6a6 100644 --- a/src/Api/Models/Response/ProfileResponseModel.cs +++ b/src/Api/Models/Response/ProfileResponseModel.cs @@ -5,6 +5,8 @@ using Bit.Api.AdminConsole.Models.Response; using Bit.Api.AdminConsole.Models.Response.Providers; using Bit.Core.AdminConsole.Models.Data.Provider; using Bit.Core.Entities; +using Bit.Core.KeyManagement.Models.Api.Response; +using Bit.Core.KeyManagement.Models.Data; using Bit.Core.Models.Api; using Bit.Core.Models.Data.Organizations.OrganizationUsers; @@ -13,6 +15,7 @@ namespace Bit.Api.Models.Response; public class ProfileResponseModel : ResponseModel { public ProfileResponseModel(User user, + UserAccountKeysData userAccountKeysData, IEnumerable organizationsUserDetails, IEnumerable providerUserDetails, IEnumerable providerUserOrganizationDetails, @@ -35,6 +38,7 @@ public class ProfileResponseModel : ResponseModel TwoFactorEnabled = twoFactorEnabled; Key = user.Key; PrivateKey = user.PrivateKey; + AccountKeys = userAccountKeysData != null ? new PrivateKeysResponseModel(userAccountKeysData) : null; SecurityStamp = user.SecurityStamp; ForcePasswordReset = user.ForcePasswordReset; UsesKeyConnector = user.UsesKeyConnector; @@ -60,7 +64,9 @@ public class ProfileResponseModel : ResponseModel public string Culture { get; set; } public bool TwoFactorEnabled { get; set; } public string Key { get; set; } + [Obsolete("Use AccountKeys instead.")] public string PrivateKey { get; set; } + public PrivateKeysResponseModel AccountKeys { get; set; } public string SecurityStamp { get; set; } public bool ForcePasswordReset { get; set; } public bool UsesKeyConnector { get; set; } diff --git a/src/Api/Models/Response/SubscriptionResponseModel.cs b/src/Api/Models/Response/SubscriptionResponseModel.cs index 7038bee2a7..29a47e160c 100644 --- a/src/Api/Models/Response/SubscriptionResponseModel.cs +++ b/src/Api/Models/Response/SubscriptionResponseModel.cs @@ -1,6 +1,4 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - +using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models.Business; using Bit.Core.Entities; using Bit.Core.Models.Api; @@ -11,7 +9,17 @@ namespace Bit.Api.Models.Response; public class SubscriptionResponseModel : ResponseModel { - public SubscriptionResponseModel(User user, SubscriptionInfo subscription, UserLicense license) + + /// The user entity containing storage and premium subscription information + /// Subscription information retrieved from the payment provider (Stripe/Braintree) + /// The user's license containing expiration and feature entitlements + /// + /// Whether to include discount information in the response. + /// Set to true when the PM23341_Milestone_2 feature flag is enabled AND + /// you want to expose Milestone 2 discount information to the client. + /// The discount will only be included if it matches the specific Milestone 2 coupon ID. + /// + public SubscriptionResponseModel(User user, SubscriptionInfo subscription, UserLicense license, bool includeMilestone2Discount = false) : base("subscription") { Subscription = subscription.Subscription != null ? new BillingSubscription(subscription.Subscription) : null; @@ -22,9 +30,14 @@ public class SubscriptionResponseModel : ResponseModel MaxStorageGb = user.MaxStorageGb; License = license; Expiration = License.Expires; + + // Only display the Milestone 2 subscription discount on the subscription page. + CustomerDiscount = ShouldIncludeMilestone2Discount(includeMilestone2Discount, subscription.CustomerDiscount) + ? new BillingCustomerDiscount(subscription.CustomerDiscount!) + : null; } - public SubscriptionResponseModel(User user, UserLicense license = null) + public SubscriptionResponseModel(User user, UserLicense? license = null) : base("subscription") { StorageName = user.Storage.HasValue ? CoreHelpers.ReadableBytesSize(user.Storage.Value) : null; @@ -38,21 +51,109 @@ public class SubscriptionResponseModel : ResponseModel } } - public string StorageName { get; set; } + public string? StorageName { get; set; } public double? StorageGb { get; set; } public short? MaxStorageGb { get; set; } - public BillingSubscriptionUpcomingInvoice UpcomingInvoice { get; set; } - public BillingSubscription Subscription { get; set; } - public UserLicense License { get; set; } + public BillingSubscriptionUpcomingInvoice? UpcomingInvoice { get; set; } + public BillingSubscription? Subscription { get; set; } + /// + /// Customer discount information from Stripe for the Milestone 2 subscription discount. + /// Only includes the specific Milestone 2 coupon (cm3nHfO1) when it's a perpetual discount (no expiration). + /// This is for display purposes only and does not affect Stripe's automatic discount application. + /// Other discounts may still apply in Stripe billing but are not included in this response. + /// + /// Null when: + /// - The PM23341_Milestone_2 feature flag is disabled + /// - There is no active discount + /// - The discount coupon ID doesn't match the Milestone 2 coupon (cm3nHfO1) + /// - The instance is self-hosted + /// + /// + public BillingCustomerDiscount? CustomerDiscount { get; set; } + public UserLicense? License { get; set; } public DateTime? Expiration { get; set; } + + /// + /// Determines whether the Milestone 2 discount should be included in the response. + /// + /// Whether the feature flag is enabled and discount should be considered. + /// The customer discount from subscription info, if any. + /// True if the discount should be included; false otherwise. + private static bool ShouldIncludeMilestone2Discount( + bool includeMilestone2Discount, + SubscriptionInfo.BillingCustomerDiscount? customerDiscount) + { + return includeMilestone2Discount && + customerDiscount != null && + customerDiscount.Id == StripeConstants.CouponIDs.Milestone2SubscriptionDiscount && + customerDiscount.Active; + } } -public class BillingCustomerDiscount(SubscriptionInfo.BillingCustomerDiscount discount) +/// +/// Customer discount information from Stripe billing. +/// +public class BillingCustomerDiscount { - public string Id { get; } = discount.Id; - public bool Active { get; } = discount.Active; - public decimal? PercentOff { get; } = discount.PercentOff; - public List AppliesTo { get; } = discount.AppliesTo; + /// + /// The Stripe coupon ID (e.g., "cm3nHfO1"). + /// + public string? Id { get; } + + /// + /// Whether the discount is a recurring/perpetual discount with no expiration date. + /// + /// This property is true only when the discount has no end date, meaning it applies + /// indefinitely to all future renewals. This is a product decision for Milestone 2 + /// to only display perpetual discounts in the UI. + /// + /// + /// Note: This does NOT indicate whether the discount is "currently active" in the billing sense. + /// A discount with a future end date is functionally active and will be applied by Stripe, + /// but this property will be false because it has an expiration date. + /// + /// + public bool Active { get; } + + /// + /// Percentage discount applied to the subscription (e.g., 20.0 for 20% off). + /// Null if this is an amount-based discount. + /// + public decimal? PercentOff { get; } + + /// + /// Fixed amount discount in USD (e.g., 14.00 for $14 off). + /// Converted from Stripe's cent-based values (1400 cents → $14.00). + /// Null if this is a percentage-based discount. + /// Note: Stripe stores amounts in the smallest currency unit. This value is always in USD. + /// + public decimal? AmountOff { get; } + + /// + /// List of Stripe product IDs that this discount applies to (e.g., ["prod_premium", "prod_families"]). + /// + /// Null: discount applies to all products with no restrictions (AppliesTo not specified in Stripe). + /// Empty list: discount restricted to zero products (edge case - AppliesTo.Products = [] in Stripe). + /// Non-empty list: discount applies only to the specified product IDs. + /// + /// + public IReadOnlyList? AppliesTo { get; } + + /// + /// Creates a BillingCustomerDiscount from a SubscriptionInfo.BillingCustomerDiscount. + /// + /// The discount to convert. Must not be null. + /// Thrown when discount is null. + public BillingCustomerDiscount(SubscriptionInfo.BillingCustomerDiscount discount) + { + ArgumentNullException.ThrowIfNull(discount); + + Id = discount.Id; + Active = discount.Active; + PercentOff = discount.PercentOff; + AmountOff = discount.AmountOff; + AppliesTo = discount.AppliesTo; + } } public class BillingSubscription @@ -83,10 +184,10 @@ public class BillingSubscription public DateTime? PeriodEndDate { get; set; } public DateTime? CancelledDate { get; set; } public bool CancelAtEndDate { get; set; } - public string Status { get; set; } + public string? Status { get; set; } public bool Cancelled { get; set; } public IEnumerable Items { get; set; } = new List(); - public string CollectionMethod { get; set; } + public string? CollectionMethod { get; set; } public DateTime? SuspensionDate { get; set; } public DateTime? UnpaidPeriodEndDate { get; set; } public int? GracePeriod { get; set; } @@ -104,11 +205,11 @@ public class BillingSubscription AddonSubscriptionItem = item.AddonSubscriptionItem; } - public string ProductId { get; set; } - public string Name { get; set; } + public string? ProductId { get; set; } + public string? Name { get; set; } public decimal Amount { get; set; } public int Quantity { get; set; } - public string Interval { get; set; } + public string? Interval { get; set; } public bool SponsoredSubscriptionItem { get; set; } public bool AddonSubscriptionItem { get; set; } } diff --git a/src/Api/SecretsManager/Controllers/AccessPoliciesController.cs b/src/Api/SecretsManager/Controllers/AccessPoliciesController.cs index cd65a7cdf8..ad5d5e092b 100644 --- a/src/Api/SecretsManager/Controllers/AccessPoliciesController.cs +++ b/src/Api/SecretsManager/Controllers/AccessPoliciesController.cs @@ -29,6 +29,7 @@ public class AccessPoliciesController : Controller private readonly IServiceAccountRepository _serviceAccountRepository; private readonly IUpdateServiceAccountGrantedPoliciesCommand _updateServiceAccountGrantedPoliciesCommand; private readonly IUserService _userService; + private readonly IEventService _eventService; private readonly IProjectServiceAccountsAccessPoliciesUpdatesQuery _projectServiceAccountsAccessPoliciesUpdatesQuery; private readonly IUpdateProjectServiceAccountsAccessPoliciesCommand @@ -47,7 +48,8 @@ public class AccessPoliciesController : Controller IServiceAccountGrantedPolicyUpdatesQuery serviceAccountGrantedPolicyUpdatesQuery, IProjectServiceAccountsAccessPoliciesUpdatesQuery projectServiceAccountsAccessPoliciesUpdatesQuery, IUpdateServiceAccountGrantedPoliciesCommand updateServiceAccountGrantedPoliciesCommand, - IUpdateProjectServiceAccountsAccessPoliciesCommand updateProjectServiceAccountsAccessPoliciesCommand) + IUpdateProjectServiceAccountsAccessPoliciesCommand updateProjectServiceAccountsAccessPoliciesCommand, + IEventService eventService) { _authorizationService = authorizationService; _userService = userService; @@ -61,6 +63,7 @@ public class AccessPoliciesController : Controller _serviceAccountGrantedPolicyUpdatesQuery = serviceAccountGrantedPolicyUpdatesQuery; _projectServiceAccountsAccessPoliciesUpdatesQuery = projectServiceAccountsAccessPoliciesUpdatesQuery; _updateProjectServiceAccountsAccessPoliciesCommand = updateProjectServiceAccountsAccessPoliciesCommand; + _eventService = eventService; } [HttpGet("/organizations/{id}/access-policies/people/potential-grantees")] @@ -186,7 +189,9 @@ public class AccessPoliciesController : Controller } var userId = _userService.GetProperUserId(User)!.Value; + var currentPolicies = await _accessPolicyRepository.GetPeoplePoliciesByGrantedServiceAccountIdAsync(peopleAccessPolicies.Id, userId); var results = await _accessPolicyRepository.ReplaceServiceAccountPeopleAsync(peopleAccessPolicies, userId); + await LogAccessPolicyServiceAccountChanges(currentPolicies, results, userId); return new ServiceAccountPeopleAccessPoliciesResponseModel(results, userId); } @@ -336,4 +341,39 @@ public class AccessPoliciesController : Controller userId, accessClient); return new ServiceAccountGrantedPoliciesPermissionDetailsResponseModel(results); } + + public async Task LogAccessPolicyServiceAccountChanges(IEnumerable currentPolicies, IEnumerable updatedPolicies, Guid userId) + { + foreach (var current in currentPolicies.OfType()) + { + if (!updatedPolicies.Any(r => r.Id == current.Id)) + { + await _eventService.LogServiceAccountGroupEventAsync(userId, current, EventType.ServiceAccount_GroupRemoved, _currentContext.IdentityClientType); + } + } + + foreach (var policy in updatedPolicies.OfType()) + { + if (!currentPolicies.Any(e => e.Id == policy.Id)) + { + await _eventService.LogServiceAccountGroupEventAsync(userId, policy, EventType.ServiceAccount_GroupAdded, _currentContext.IdentityClientType); + } + } + + foreach (var current in currentPolicies.OfType()) + { + if (!updatedPolicies.Any(r => r.Id == current.Id)) + { + await _eventService.LogServiceAccountPeopleEventAsync(userId, current, EventType.ServiceAccount_UserRemoved, _currentContext.IdentityClientType); + } + } + + foreach (var policy in updatedPolicies.OfType()) + { + if (!currentPolicies.Any(e => e.Id == policy.Id)) + { + await _eventService.LogServiceAccountPeopleEventAsync(userId, policy, EventType.ServiceAccount_UserAdded, _currentContext.IdentityClientType); + } + } + } } diff --git a/src/Api/SecretsManager/Controllers/ServiceAccountsController.cs b/src/Api/SecretsManager/Controllers/ServiceAccountsController.cs index 499c496cc9..0afdc3a1bf 100644 --- a/src/Api/SecretsManager/Controllers/ServiceAccountsController.cs +++ b/src/Api/SecretsManager/Controllers/ServiceAccountsController.cs @@ -42,6 +42,8 @@ public class ServiceAccountsController : Controller private readonly IDeleteServiceAccountsCommand _deleteServiceAccountsCommand; private readonly IRevokeAccessTokensCommand _revokeAccessTokensCommand; private readonly IPricingClient _pricingClient; + private readonly IEventService _eventService; + private readonly IOrganizationUserRepository _organizationUserRepository; public ServiceAccountsController( ICurrentContext currentContext, @@ -58,7 +60,9 @@ public class ServiceAccountsController : Controller IUpdateServiceAccountCommand updateServiceAccountCommand, IDeleteServiceAccountsCommand deleteServiceAccountsCommand, IRevokeAccessTokensCommand revokeAccessTokensCommand, - IPricingClient pricingClient) + IPricingClient pricingClient, + IEventService eventService, + IOrganizationUserRepository organizationUserRepository) { _currentContext = currentContext; _userService = userService; @@ -75,6 +79,8 @@ public class ServiceAccountsController : Controller _pricingClient = pricingClient; _createAccessTokenCommand = createAccessTokenCommand; _updateSecretsManagerSubscriptionCommand = updateSecretsManagerSubscriptionCommand; + _eventService = eventService; + _organizationUserRepository = organizationUserRepository; } [HttpGet("/organizations/{organizationId}/service-accounts")] @@ -139,8 +145,15 @@ public class ServiceAccountsController : Controller } var userId = _userService.GetProperUserId(User).Value; + var result = - await _createServiceAccountCommand.CreateAsync(createRequest.ToServiceAccount(organizationId), userId); + await _createServiceAccountCommand.CreateAsync(serviceAccount, userId); + + if (result != null) + { + await _eventService.LogServiceAccountEventAsync(userId, [serviceAccount], EventType.ServiceAccount_Created, _currentContext.IdentityClientType); + } + return new ServiceAccountResponseModel(result); } @@ -197,6 +210,9 @@ public class ServiceAccountsController : Controller } await _deleteServiceAccountsCommand.DeleteServiceAccounts(serviceAccountsToDelete); + var userId = _userService.GetProperUserId(User)!.Value; + await _eventService.LogServiceAccountEventAsync(userId, serviceAccountsToDelete, EventType.ServiceAccount_Deleted, _currentContext.IdentityClientType); + var responses = results.Select(r => new BulkDeleteResponseModel(r.ServiceAccount.Id, r.Error)); return new ListResponseModel(responses); } diff --git a/src/Api/Startup.cs b/src/Api/Startup.cs index cc50a1b362..0967b4f662 100644 --- a/src/Api/Startup.cs +++ b/src/Api/Startup.cs @@ -94,9 +94,6 @@ public class Startup services.AddMemoryCache(); services.AddDistributedCache(globalSettings); - // BitPay - services.AddSingleton(); - if (!globalSettings.SelfHosted) { services.AddIpRateLimiting(globalSettings); @@ -229,8 +226,9 @@ public class Startup services.AddHostedService(); } - // Add SlackService for OAuth API requests - if configured + // Add Slack / Teams Services for OAuth API requests - if configured services.AddSlackService(globalSettings); + services.AddTeamsService(globalSettings); } public void Configure( @@ -325,6 +323,6 @@ public class Startup } // Log startup - logger.LogInformation(Constants.BypassFiltersEventId, globalSettings.ProjectName + " started."); + logger.LogInformation(Constants.BypassFiltersEventId, "{Project} started.", globalSettings.ProjectName); } } diff --git a/src/Api/Tools/Controllers/ImportCiphersController.cs b/src/Api/Tools/Controllers/ImportCiphersController.cs index 88028420b7..8b3ec5e26c 100644 --- a/src/Api/Tools/Controllers/ImportCiphersController.cs +++ b/src/Api/Tools/Controllers/ImportCiphersController.cs @@ -74,10 +74,14 @@ public class ImportCiphersController : Controller throw new BadRequestException("You cannot import this much data at once."); } + if (model.Ciphers.Any(c => c.ArchivedDate.HasValue)) + { + throw new BadRequestException("You cannot import archived items into an organization."); + } + var orgId = new Guid(organizationId); var collections = model.Collections.Select(c => c.ToCollection(orgId)).ToList(); - //An User is allowed to import if CanCreate Collections or has AccessToImportExport var authorized = await CheckOrgImportPermission(collections, orgId); if (!authorized) @@ -156,7 +160,7 @@ public class ImportCiphersController : Controller if (existingCollections.Any() && (await _authorizationService.AuthorizeAsync(User, existingCollections, BulkCollectionOperations.ImportCiphers)).Succeeded) { return true; - }; + } return false; } diff --git a/src/Api/Tools/Controllers/OrganizationExportController.cs b/src/Api/Tools/Controllers/OrganizationExportController.cs index b1925dd3cf..dd039bc4a5 100644 --- a/src/Api/Tools/Controllers/OrganizationExportController.cs +++ b/src/Api/Tools/Controllers/OrganizationExportController.cs @@ -1,5 +1,6 @@ using Bit.Api.Tools.Authorization; using Bit.Api.Tools.Models.Response; +using Bit.Core; using Bit.Core.AdminConsole.OrganizationFeatures.Shared.Authorization; using Bit.Core.Exceptions; using Bit.Core.Repositories; @@ -20,19 +21,22 @@ public class OrganizationExportController : Controller private readonly IAuthorizationService _authorizationService; private readonly IOrganizationCiphersQuery _organizationCiphersQuery; private readonly ICollectionRepository _collectionRepository; + private readonly IFeatureService _featureService; public OrganizationExportController( IUserService userService, GlobalSettings globalSettings, IAuthorizationService authorizationService, IOrganizationCiphersQuery organizationCiphersQuery, - ICollectionRepository collectionRepository) + ICollectionRepository collectionRepository, + IFeatureService featureService) { _userService = userService; _globalSettings = globalSettings; _authorizationService = authorizationService; _organizationCiphersQuery = organizationCiphersQuery; _collectionRepository = collectionRepository; + _featureService = featureService; } [HttpGet("export")] @@ -40,23 +44,47 @@ public class OrganizationExportController : Controller { var canExportAll = await _authorizationService.AuthorizeAsync(User, new OrganizationScope(organizationId), VaultExportOperations.ExportWholeVault); - if (canExportAll.Succeeded) - { - var allOrganizationCiphers = await _organizationCiphersQuery.GetAllOrganizationCiphers(organizationId); - var allCollections = await _collectionRepository.GetManyByOrganizationIdAsync(organizationId); - return Ok(new OrganizationExportResponseModel(allOrganizationCiphers, allCollections, _globalSettings)); - } - var canExportManaged = await _authorizationService.AuthorizeAsync(User, new OrganizationScope(organizationId), VaultExportOperations.ExportManagedCollections); + var createDefaultLocationEnabled = _featureService.IsEnabled(FeatureFlagKeys.CreateDefaultLocation); + + if (canExportAll.Succeeded) + { + if (createDefaultLocationEnabled) + { + var allOrganizationCiphers = + await _organizationCiphersQuery.GetAllOrganizationCiphersExcludingDefaultUserCollections( + organizationId); + + var allCollections = await _collectionRepository + .GetManySharedCollectionsByOrganizationIdAsync( + organizationId); + + + return Ok(new OrganizationExportResponseModel(allOrganizationCiphers, allCollections, + _globalSettings)); + } + else + { + var allOrganizationCiphers = await _organizationCiphersQuery.GetAllOrganizationCiphers(organizationId); + + var allCollections = await _collectionRepository.GetManyByOrganizationIdAsync(organizationId); + + return Ok(new OrganizationExportResponseModel(allOrganizationCiphers, allCollections, + _globalSettings)); + } + } + if (canExportManaged.Succeeded) { var userId = _userService.GetProperUserId(User)!.Value; var allUserCollections = await _collectionRepository.GetManyByUserIdAsync(userId); - var managedOrgCollections = allUserCollections.Where(c => c.OrganizationId == organizationId && c.Manage).ToList(); - var managedCiphers = - await _organizationCiphersQuery.GetOrganizationCiphersByCollectionIds(organizationId, managedOrgCollections.Select(c => c.Id)); + var managedOrgCollections = + allUserCollections.Where(c => c.OrganizationId == organizationId && c.Manage).ToList(); + + var managedCiphers = await _organizationCiphersQuery.GetOrganizationCiphersByCollectionIds(organizationId, + managedOrgCollections.Select(c => c.Id)); return Ok(new OrganizationExportResponseModel(managedCiphers, managedOrgCollections, _globalSettings)); } diff --git a/src/Api/Tools/Controllers/SendsController.cs b/src/Api/Tools/Controllers/SendsController.cs index b3e16bc0be..c4bf4595f3 100644 --- a/src/Api/Tools/Controllers/SendsController.cs +++ b/src/Api/Tools/Controllers/SendsController.cs @@ -172,7 +172,7 @@ public class SendsController : Controller } catch (Exception e) { - _logger.LogError(e, $"Uncaught exception occurred while handling event grid event: {JsonSerializer.Serialize(eventGridEvent)}"); + _logger.LogError(e, "Uncaught exception occurred while handling event grid event: {Event}", JsonSerializer.Serialize(eventGridEvent)); return; } } diff --git a/src/Api/Utilities/DiagnosticTools/EventDiagnosticLogger.cs b/src/Api/Utilities/DiagnosticTools/EventDiagnosticLogger.cs new file mode 100644 index 0000000000..9f6a8d2639 --- /dev/null +++ b/src/Api/Utilities/DiagnosticTools/EventDiagnosticLogger.cs @@ -0,0 +1,87 @@ +using Bit.Api.Models.Public.Request; +using Bit.Api.Models.Public.Response; +using Bit.Core; +using Bit.Core.Services; + +namespace Bit.Api.Utilities.DiagnosticTools; + +public static class EventDiagnosticLogger +{ + public static void LogAggregateData( + this ILogger logger, + IFeatureService featureService, + Guid organizationId, + PagedListResponseModel data, EventFilterRequestModel request) + { + try + { + if (!featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging)) + { + return; + } + + var orderedRecords = data.Data.OrderBy(e => e.Date).ToList(); + var recordCount = orderedRecords.Count; + var newestRecordDate = orderedRecords.LastOrDefault()?.Date.ToString("o"); + var oldestRecordDate = orderedRecords.FirstOrDefault()?.Date.ToString("o"); ; + var hasMore = !string.IsNullOrEmpty(data.ContinuationToken); + + logger.LogInformation( + "Events query for Organization:{OrgId}. Event count:{Count} newest record:{newestRecord} oldest record:{oldestRecord} HasMore:{HasMore} " + + "Request Filters Start:{QueryStart} End:{QueryEnd} ActingUserId:{ActingUserId} ItemId:{ItemId},", + organizationId, + recordCount, + newestRecordDate, + oldestRecordDate, + hasMore, + request.Start?.ToString("o"), + request.End?.ToString("o"), + request.ActingUserId, + request.ItemId); + } + catch (Exception exception) + { + logger.LogWarning(exception, "Unexpected exception from EventDiagnosticLogger.LogAggregateData"); + } + } + + public static void LogAggregateData( + this ILogger logger, + IFeatureService featureService, + Guid organizationId, + IEnumerable data, + string? continuationToken, + DateTime? queryStart = null, + DateTime? queryEnd = null) + { + + try + { + if (!featureService.IsEnabled(FeatureFlagKeys.EventDiagnosticLogging)) + { + return; + } + + var orderedRecords = data.OrderBy(e => e.Date).ToList(); + var recordCount = orderedRecords.Count; + var newestRecordDate = orderedRecords.LastOrDefault()?.Date.ToString("o"); + var oldestRecordDate = orderedRecords.FirstOrDefault()?.Date.ToString("o"); ; + var hasMore = !string.IsNullOrEmpty(continuationToken); + + logger.LogInformation( + "Events query for Organization:{OrgId}. Event count:{Count} newest record:{newestRecord} oldest record:{oldestRecord} HasMore:{HasMore} " + + "Request Filters Start:{QueryStart} End:{QueryEnd}", + organizationId, + recordCount, + newestRecordDate, + oldestRecordDate, + hasMore, + queryStart?.ToString("o"), + queryEnd?.ToString("o")); + } + catch (Exception exception) + { + logger.LogWarning(exception, "Unexpected exception from EventDiagnosticLogger.LogAggregateData"); + } + } +} diff --git a/src/Api/Utilities/ExceptionHandlerFilterAttribute.cs b/src/Api/Utilities/ExceptionHandlerFilterAttribute.cs index 91079d5040..1caa7cf841 100644 --- a/src/Api/Utilities/ExceptionHandlerFilterAttribute.cs +++ b/src/Api/Utilities/ExceptionHandlerFilterAttribute.cs @@ -152,7 +152,7 @@ public class ExceptionHandlerFilterAttribute : ExceptionFilterAttribute else { var logger = context.HttpContext.RequestServices.GetRequiredService>(); - logger.LogError(0, exception, exception.Message); + logger.LogError(0, exception, "Unhandled exception"); errorMessage = "An unhandled server error has occurred."; context.HttpContext.Response.StatusCode = 500; } diff --git a/src/Api/Vault/Controllers/CiphersController.cs b/src/Api/Vault/Controllers/CiphersController.cs index c0a974bce2..0983225f84 100644 --- a/src/Api/Vault/Controllers/CiphersController.cs +++ b/src/Api/Vault/Controllers/CiphersController.cs @@ -1,6 +1,7 @@ // FIXME: Update this file to be null safe and then delete the line below #nullable disable +using System.Globalization; using System.Text.Json; using Azure.Messaging.EventGrid; using Bit.Api.Auth.Models.Request.Accounts; @@ -401,8 +402,9 @@ public class CiphersController : Controller { var org = _currentContext.GetOrganization(organizationId); - // If we're not an "admin" or if we're not a provider user we don't need to check the ciphers - if (org is not ({ Type: OrganizationUserType.Owner or OrganizationUserType.Admin } or { Permissions.EditAnyCollection: true }) || await _currentContext.ProviderUserForOrgAsync(organizationId)) + // If we're not an "admin" we don't need to check the ciphers + if (org is not ({ Type: OrganizationUserType.Owner or OrganizationUserType.Admin } or + { Permissions.EditAnyCollection: true })) { return false; } @@ -415,8 +417,9 @@ public class CiphersController : Controller { var org = _currentContext.GetOrganization(organizationId); - // If we're not an "admin" or if we're a provider user we don't need to check the ciphers - if (org is not ({ Type: OrganizationUserType.Owner or OrganizationUserType.Admin } or { Permissions.EditAnyCollection: true }) || await _currentContext.ProviderUserForOrgAsync(organizationId)) + // If we're not an "admin" we don't need to check the ciphers + if (org is not ({ Type: OrganizationUserType.Owner or OrganizationUserType.Admin } or + { Permissions.EditAnyCollection: true })) { return false; } @@ -754,6 +757,11 @@ public class CiphersController : Controller } } + if (cipher.ArchivedDate.HasValue) + { + throw new BadRequestException("Cannot move an archived item to an organization."); + } + ValidateClientVersionForFido2CredentialSupport(cipher); var original = cipher.Clone(); @@ -887,6 +895,9 @@ public class CiphersController : Controller [HttpPost("bulk-collections")] public async Task PostBulkCollections([FromBody] CipherBulkUpdateCollectionsRequestModel model) { + var userId = _userService.GetProperUserId(User).Value; + await _cipherService.ValidateBulkCollectionAssignmentAsync(model.CollectionIds, model.CipherIds, userId); + if (!await CanModifyCipherCollectionsAsync(model.OrganizationId, model.CipherIds) || !await CanEditItemsInCollections(model.OrganizationId, model.CollectionIds)) { @@ -1260,6 +1271,11 @@ public class CiphersController : Controller _logger.LogError("Cipher was not encrypted for the current user. CipherId: {CipherId}, CurrentUser: {CurrentUserId}, EncryptedFor: {EncryptedFor}", cipher.Id, userId, cipher.EncryptedFor); throw new BadRequestException("Cipher was not encrypted for the current user. Please try again."); } + + if (cipher.ArchivedDate.HasValue) + { + throw new BadRequestException("Cannot move archived items to an organization."); + } } var shareCiphers = new List<(CipherDetails, DateTime?)>(); @@ -1272,6 +1288,11 @@ public class CiphersController : Controller ValidateClientVersionForFido2CredentialSupport(existingCipher); + if (existingCipher.ArchivedDate.HasValue) + { + throw new BadRequestException("Cannot move archived items to an organization."); + } + shareCiphers.Add((cipher.ToCipherDetails(existingCipher), cipher.LastKnownRevisionDate)); } @@ -1348,7 +1369,7 @@ public class CiphersController : Controller } var (attachmentId, uploadUrl) = await _cipherService.CreateAttachmentForDelayedUploadAsync(cipher, - request.Key, request.FileName, request.FileSize, request.AdminRequest, user.Id); + request.Key, request.FileName, request.FileSize, request.AdminRequest, user.Id, request.LastKnownRevisionDate); return new AttachmentUploadDataResponseModel { AttachmentId = attachmentId, @@ -1401,9 +1422,11 @@ public class CiphersController : Controller throw new NotFoundException(); } + // Extract lastKnownRevisionDate from form data if present + DateTime? lastKnownRevisionDate = GetLastKnownRevisionDateFromForm(); await Request.GetFileAsync(async (stream) => { - await _cipherService.UploadFileForExistingAttachmentAsync(stream, cipher, attachmentData); + await _cipherService.UploadFileForExistingAttachmentAsync(stream, cipher, attachmentData, lastKnownRevisionDate); }); } @@ -1422,10 +1445,12 @@ public class CiphersController : Controller throw new NotFoundException(); } + // Extract lastKnownRevisionDate from form data if present + DateTime? lastKnownRevisionDate = GetLastKnownRevisionDateFromForm(); await Request.GetFileAsync(async (stream, fileName, key) => { await _cipherService.CreateAttachmentAsync(cipher, stream, fileName, key, - Request.ContentLength.GetValueOrDefault(0), user.Id); + Request.ContentLength.GetValueOrDefault(0), user.Id, false, lastKnownRevisionDate); }); return new CipherResponseModel( @@ -1451,10 +1476,13 @@ public class CiphersController : Controller throw new NotFoundException(); } + // Extract lastKnownRevisionDate from form data if present + DateTime? lastKnownRevisionDate = GetLastKnownRevisionDateFromForm(); + await Request.GetFileAsync(async (stream, fileName, key) => { await _cipherService.CreateAttachmentAsync(cipher, stream, fileName, key, - Request.ContentLength.GetValueOrDefault(0), userId, true); + Request.ContentLength.GetValueOrDefault(0), userId, true, lastKnownRevisionDate); }); return new CipherMiniResponseModel(cipher, _globalSettings, cipher.OrganizationUseTotp); @@ -1497,10 +1525,13 @@ public class CiphersController : Controller throw new NotFoundException(); } + // Extract lastKnownRevisionDate from form data if present + DateTime? lastKnownRevisionDate = GetLastKnownRevisionDateFromForm(); + await Request.GetFileAsync(async (stream, fileName, key) => { await _cipherService.CreateAttachmentShareAsync(cipher, stream, fileName, key, - Request.ContentLength.GetValueOrDefault(0), attachmentId, organizationId); + Request.ContentLength.GetValueOrDefault(0), attachmentId, organizationId, lastKnownRevisionDate); }); } @@ -1575,7 +1606,7 @@ public class CiphersController : Controller } catch (Exception e) { - _logger.LogError(e, $"Uncaught exception occurred while handling event grid event: {JsonSerializer.Serialize(eventGridEvent)}"); + _logger.LogError(e, "Uncaught exception occurred while handling event grid event: {Event}", JsonSerializer.Serialize(eventGridEvent)); return; } } @@ -1612,4 +1643,19 @@ public class CiphersController : Controller { return await _cipherRepository.GetByIdAsync(cipherId, userId); } + + private DateTime? GetLastKnownRevisionDateFromForm() + { + DateTime? lastKnownRevisionDate = null; + if (Request.Form.TryGetValue("lastKnownRevisionDate", out var dateValue)) + { + if (!DateTime.TryParse(dateValue, CultureInfo.InvariantCulture, DateTimeStyles.RoundtripKind, out var parsedDate)) + { + throw new BadRequestException("Invalid lastKnownRevisionDate format."); + } + lastKnownRevisionDate = parsedDate; + } + + return lastKnownRevisionDate; + } } diff --git a/src/Api/Vault/Controllers/SyncController.cs b/src/Api/Vault/Controllers/SyncController.cs index 54f1b9e70b..6ac8d06ba0 100644 --- a/src/Api/Vault/Controllers/SyncController.cs +++ b/src/Api/Vault/Controllers/SyncController.cs @@ -11,6 +11,8 @@ using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Queries.Interfaces; using Bit.Core.Models.Data; using Bit.Core.Repositories; using Bit.Core.Services; @@ -42,6 +44,7 @@ public class SyncController : Controller private readonly IFeatureService _featureService; private readonly IApplicationCacheService _applicationCacheService; private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; + private readonly IUserAccountKeysQuery _userAccountKeysQuery; public SyncController( IUserService userService, @@ -57,7 +60,8 @@ public class SyncController : Controller ICurrentContext currentContext, IFeatureService featureService, IApplicationCacheService applicationCacheService, - ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery) + ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, + IUserAccountKeysQuery userAccountKeysQuery) { _userService = userService; _folderRepository = folderRepository; @@ -73,6 +77,7 @@ public class SyncController : Controller _featureService = featureService; _applicationCacheService = applicationCacheService; _twoFactorIsEnabledQuery = twoFactorIsEnabledQuery; + _userAccountKeysQuery = userAccountKeysQuery; } [HttpGet("")] @@ -116,7 +121,14 @@ public class SyncController : Controller var organizationAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); - var response = new SyncResponseModel(_globalSettings, user, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationAbilities, + UserAccountKeysData userAccountKeys = null; + // JIT TDE users and some broken/old users may not have a private key. + if (!string.IsNullOrWhiteSpace(user.PrivateKey)) + { + userAccountKeys = await _userAccountKeysQuery.Run(user); + } + + var response = new SyncResponseModel(_globalSettings, user, userAccountKeys, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationAbilities, organizationIdsClaimingActiveUser, organizationUserDetails, providerUserDetails, providerUserOrganizationDetails, folders, collections, ciphers, collectionCiphersGroupDict, excludeDomains, policies, sends); return response; diff --git a/src/Api/Vault/Models/Request/AttachmentRequestModel.cs b/src/Api/Vault/Models/Request/AttachmentRequestModel.cs index 96c66c6044..eef70bf4e4 100644 --- a/src/Api/Vault/Models/Request/AttachmentRequestModel.cs +++ b/src/Api/Vault/Models/Request/AttachmentRequestModel.cs @@ -9,4 +9,9 @@ public class AttachmentRequestModel public string FileName { get; set; } public long FileSize { get; set; } public bool AdminRequest { get; set; } = false; + + /// + /// The last known revision date of the Cipher that this attachment belongs to. + /// + public DateTime? LastKnownRevisionDate { get; set; } } diff --git a/src/Api/Vault/Models/Response/SyncResponseModel.cs b/src/Api/Vault/Models/Response/SyncResponseModel.cs index e5b2ab55e3..c965320b94 100644 --- a/src/Api/Vault/Models/Response/SyncResponseModel.cs +++ b/src/Api/Vault/Models/Response/SyncResponseModel.cs @@ -7,7 +7,8 @@ using Bit.Api.Tools.Models.Response; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Models.Data.Provider; using Bit.Core.Entities; -using Bit.Core.KeyManagement.Models.Response; +using Bit.Core.KeyManagement.Models.Api.Response; +using Bit.Core.KeyManagement.Models.Data; using Bit.Core.Models.Api; using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations; @@ -24,6 +25,7 @@ public class SyncResponseModel() : ResponseModel("sync") public SyncResponseModel( GlobalSettings globalSettings, User user, + UserAccountKeysData userAccountKeysData, bool userTwoFactorEnabled, bool userHasPremiumFromOrganization, IDictionary organizationAbilities, @@ -40,7 +42,7 @@ public class SyncResponseModel() : ResponseModel("sync") IEnumerable sends) : this() { - Profile = new ProfileResponseModel(user, organizationUserDetails, providerUserDetails, + Profile = new ProfileResponseModel(user, userAccountKeysData, organizationUserDetails, providerUserDetails, providerUserOrganizationDetails, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationIdsClaimingingUser); Folders = folders.Select(f => new FolderResponseModel(f)); Ciphers = ciphers.Select(cipher => diff --git a/src/Api/appsettings.json b/src/Api/appsettings.json index f8a69dcfac..98bb4df8ac 100644 --- a/src/Api/appsettings.json +++ b/src/Api/appsettings.json @@ -64,7 +64,8 @@ "bitPay": { "production": false, "token": "SECRET", - "notificationUrl": "https://bitwarden.com/SECRET" + "notificationUrl": "https://bitwarden.com/SECRET", + "webhookKey": "SECRET" }, "amazon": { "accessKeyId": "SECRET", diff --git a/src/Billing/BillingSettings.cs b/src/Billing/BillingSettings.cs index 32630e4a4a..64a52ed290 100644 --- a/src/Billing/BillingSettings.cs +++ b/src/Billing/BillingSettings.cs @@ -7,10 +7,7 @@ public class BillingSettings { public virtual string JobsKey { get; set; } public virtual string StripeWebhookKey { get; set; } - public virtual string StripeWebhookSecret { get; set; } - public virtual string StripeWebhookSecret20231016 { get; set; } - public virtual string StripeWebhookSecret20240620 { get; set; } - public virtual string BitPayWebhookKey { get; set; } + public virtual string StripeWebhookSecret20250827Basil { get; set; } public virtual string AppleWebhookKey { get; set; } public virtual FreshDeskSettings FreshDesk { get; set; } = new FreshDeskSettings(); public virtual string FreshsalesApiKey { get; set; } @@ -44,6 +41,15 @@ public class BillingSettings { public virtual string ApiKey { get; set; } public virtual string BaseUrl { get; set; } + public virtual string Path { get; set; } public virtual int PersonaId { get; set; } + public virtual bool UseAnswerWithCitationModels { get; set; } = true; + + public virtual SearchSettings SearchSettings { get; set; } = new SearchSettings(); + } + public class SearchSettings + { + public virtual string RunSearch { get; set; } = "auto"; // "always", "never", "auto" + public virtual bool RealTime { get; set; } = true; } } diff --git a/src/Billing/Constants/BitPayInvoiceStatus.cs b/src/Billing/Constants/BitPayInvoiceStatus.cs deleted file mode 100644 index b9c1e5834d..0000000000 --- a/src/Billing/Constants/BitPayInvoiceStatus.cs +++ /dev/null @@ -1,7 +0,0 @@ -namespace Bit.Billing.Constants; - -public static class BitPayInvoiceStatus -{ - public const string Confirmed = "confirmed"; - public const string Complete = "complete"; -} diff --git a/src/Billing/Controllers/BitPayController.cs b/src/Billing/Controllers/BitPayController.cs index 111ffabc2b..b24a8d8c36 100644 --- a/src/Billing/Controllers/BitPayController.cs +++ b/src/Billing/Controllers/BitPayController.cs @@ -1,125 +1,79 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Globalization; -using Bit.Billing.Constants; +using System.Globalization; using Bit.Billing.Models; using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Payment.Clients; using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Repositories; using Bit.Core.Services; +using Bit.Core.Settings; using Bit.Core.Utilities; +using BitPayLight.Models.Invoice; using Microsoft.AspNetCore.Mvc; using Microsoft.Data.SqlClient; -using Microsoft.Extensions.Options; namespace Bit.Billing.Controllers; +using static BitPayConstants; +using static StripeConstants; + [Route("bitpay")] [ApiExplorerSettings(IgnoreApi = true)] -public class BitPayController : Controller +public class BitPayController( + GlobalSettings globalSettings, + IBitPayClient bitPayClient, + ITransactionRepository transactionRepository, + IOrganizationRepository organizationRepository, + IUserRepository userRepository, + IProviderRepository providerRepository, + IMailService mailService, + IPaymentService paymentService, + ILogger logger, + IPremiumUserBillingService premiumUserBillingService) + : Controller { - private readonly BillingSettings _billingSettings; - private readonly BitPayClient _bitPayClient; - private readonly ITransactionRepository _transactionRepository; - private readonly IOrganizationRepository _organizationRepository; - private readonly IUserRepository _userRepository; - private readonly IProviderRepository _providerRepository; - private readonly IMailService _mailService; - private readonly IPaymentService _paymentService; - private readonly ILogger _logger; - private readonly IPremiumUserBillingService _premiumUserBillingService; - - public BitPayController( - IOptions billingSettings, - BitPayClient bitPayClient, - ITransactionRepository transactionRepository, - IOrganizationRepository organizationRepository, - IUserRepository userRepository, - IProviderRepository providerRepository, - IMailService mailService, - IPaymentService paymentService, - ILogger logger, - IPremiumUserBillingService premiumUserBillingService) - { - _billingSettings = billingSettings?.Value; - _bitPayClient = bitPayClient; - _transactionRepository = transactionRepository; - _organizationRepository = organizationRepository; - _userRepository = userRepository; - _providerRepository = providerRepository; - _mailService = mailService; - _paymentService = paymentService; - _logger = logger; - _premiumUserBillingService = premiumUserBillingService; - } - [HttpPost("ipn")] public async Task PostIpn([FromBody] BitPayEventModel model, [FromQuery] string key) { - if (!CoreHelpers.FixedTimeEquals(key, _billingSettings.BitPayWebhookKey)) + if (!CoreHelpers.FixedTimeEquals(key, globalSettings.BitPay.WebhookKey)) { - return new BadRequestResult(); - } - if (model == null || string.IsNullOrWhiteSpace(model.Data?.Id) || - string.IsNullOrWhiteSpace(model.Event?.Name)) - { - return new BadRequestResult(); + return new BadRequestObjectResult("Invalid key"); } - if (model.Event.Name != BitPayNotificationCode.InvoiceConfirmed) - { - // Only processing confirmed invoice events for now. - return new OkResult(); - } - - var invoice = await _bitPayClient.GetInvoiceAsync(model.Data.Id); - if (invoice == null) - { - // Request forged...? - _logger.LogWarning("Invoice not found. #{InvoiceId}", model.Data.Id); - return new BadRequestResult(); - } - - if (invoice.Status != BitPayInvoiceStatus.Confirmed && invoice.Status != BitPayInvoiceStatus.Complete) - { - _logger.LogWarning("Invoice status of '{InvoiceStatus}' is not acceptable. #{InvoiceId}", invoice.Status, invoice.Id); - return new BadRequestResult(); - } + var invoice = await bitPayClient.GetInvoice(model.Data.Id); if (invoice.Currency != "USD") { - // Only process USD payments - _logger.LogWarning("Non USD payment received. #{InvoiceId}", invoice.Id); - return new OkResult(); + logger.LogWarning("Received BitPay invoice webhook for invoice ({InvoiceID}) with non-USD currency: {Currency}", invoice.Id, invoice.Currency); + return new BadRequestObjectResult("Cannot process non-USD payments"); } var (organizationId, userId, providerId) = GetIdsFromPosData(invoice); - if (!organizationId.HasValue && !userId.HasValue && !providerId.HasValue) + if ((!organizationId.HasValue && !userId.HasValue && !providerId.HasValue) || !invoice.PosData.Contains(PosDataKeys.AccountCredit)) { - return new OkResult(); + logger.LogWarning("Received BitPay invoice webhook for invoice ({InvoiceID}) that had invalid POS data: {PosData}", invoice.Id, invoice.PosData); + return new BadRequestObjectResult("Invalid POS data"); } - var isAccountCredit = IsAccountCredit(invoice); - if (!isAccountCredit) + if (invoice.Status != InvoiceStatuses.Complete) { - // Only processing credits - _logger.LogWarning("Non-credit payment received. #{InvoiceId}", invoice.Id); - return new OkResult(); + logger.LogInformation("Received valid BitPay invoice webhook for invoice ({InvoiceID}) that is not yet complete: {Status}", + invoice.Id, invoice.Status); + return new OkObjectResult("Waiting for invoice to be completed"); } - var transaction = await _transactionRepository.GetByGatewayIdAsync(GatewayType.BitPay, invoice.Id); - if (transaction != null) + var existingTransaction = await transactionRepository.GetByGatewayIdAsync(GatewayType.BitPay, invoice.Id); + if (existingTransaction != null) { - _logger.LogWarning("Already processed this invoice. #{InvoiceId}", invoice.Id); - return new OkResult(); + logger.LogWarning("Already processed BitPay invoice webhook for invoice ({InvoiceID})", invoice.Id); + return new OkObjectResult("Invoice already processed"); } try { - var tx = new Transaction + var transaction = new Transaction { Amount = Convert.ToDecimal(invoice.Price), CreationDate = GetTransactionDate(invoice), @@ -132,50 +86,47 @@ public class BitPayController : Controller PaymentMethodType = PaymentMethodType.BitPay, Details = $"{invoice.Currency}, BitPay {invoice.Id}" }; - await _transactionRepository.CreateAsync(tx); - string billingEmail = null; - if (tx.OrganizationId.HasValue) + await transactionRepository.CreateAsync(transaction); + + var billingEmail = ""; + if (transaction.OrganizationId.HasValue) { - var org = await _organizationRepository.GetByIdAsync(tx.OrganizationId.Value); - if (org != null) + var organization = await organizationRepository.GetByIdAsync(transaction.OrganizationId.Value); + if (organization != null) { - billingEmail = org.BillingEmailAddress(); - if (await _paymentService.CreditAccountAsync(org, tx.Amount)) + billingEmail = organization.BillingEmailAddress(); + if (await paymentService.CreditAccountAsync(organization, transaction.Amount)) { - await _organizationRepository.ReplaceAsync(org); + await organizationRepository.ReplaceAsync(organization); } } } - else if (tx.UserId.HasValue) + else if (transaction.UserId.HasValue) { - var user = await _userRepository.GetByIdAsync(tx.UserId.Value); + var user = await userRepository.GetByIdAsync(transaction.UserId.Value); if (user != null) { billingEmail = user.BillingEmailAddress(); - await _premiumUserBillingService.Credit(user, tx.Amount); + await premiumUserBillingService.Credit(user, transaction.Amount); } } - else if (tx.ProviderId.HasValue) + else if (transaction.ProviderId.HasValue) { - var provider = await _providerRepository.GetByIdAsync(tx.ProviderId.Value); + var provider = await providerRepository.GetByIdAsync(transaction.ProviderId.Value); if (provider != null) { billingEmail = provider.BillingEmailAddress(); - if (await _paymentService.CreditAccountAsync(provider, tx.Amount)) + if (await paymentService.CreditAccountAsync(provider, transaction.Amount)) { - await _providerRepository.ReplaceAsync(provider); + await providerRepository.ReplaceAsync(provider); } } } - else - { - _logger.LogError("Received BitPay account credit transaction that didn't have a user, org, or provider. Invoice#{InvoiceId}", invoice.Id); - } if (!string.IsNullOrWhiteSpace(billingEmail)) { - await _mailService.SendAddedCreditAsync(billingEmail, tx.Amount); + await mailService.SendAddedCreditAsync(billingEmail, transaction.Amount); } } // Catch foreign key violations because user/org could have been deleted. @@ -186,58 +137,34 @@ public class BitPayController : Controller return new OkResult(); } - private bool IsAccountCredit(BitPayLight.Models.Invoice.Invoice invoice) + private static DateTime GetTransactionDate(Invoice invoice) { - return invoice != null && invoice.PosData != null && invoice.PosData.Contains("accountCredit:1"); + var transactions = invoice.Transactions?.Where(transaction => + transaction.Type == null && !string.IsNullOrWhiteSpace(transaction.Confirmations) && + transaction.Confirmations != "0").ToList(); + + return transactions?.Count == 1 + ? DateTime.Parse(transactions.First().ReceivedTime, CultureInfo.InvariantCulture, DateTimeStyles.RoundtripKind) + : CoreHelpers.FromEpocMilliseconds(invoice.CurrentTime); } - private DateTime GetTransactionDate(BitPayLight.Models.Invoice.Invoice invoice) + public (Guid? OrganizationId, Guid? UserId, Guid? ProviderId) GetIdsFromPosData(Invoice invoice) { - var transactions = invoice.Transactions?.Where(t => t.Type == null && - !string.IsNullOrWhiteSpace(t.Confirmations) && t.Confirmations != "0"); - if (transactions != null && transactions.Count() == 1) + if (invoice.PosData is null or { Length: 0 } || !invoice.PosData.Contains(':')) { - return DateTime.Parse(transactions.First().ReceivedTime, CultureInfo.InvariantCulture, - DateTimeStyles.RoundtripKind); - } - return CoreHelpers.FromEpocMilliseconds(invoice.CurrentTime); - } - - public Tuple GetIdsFromPosData(BitPayLight.Models.Invoice.Invoice invoice) - { - Guid? orgId = null; - Guid? userId = null; - Guid? providerId = null; - - if (invoice == null || string.IsNullOrWhiteSpace(invoice.PosData) || !invoice.PosData.Contains(':')) - { - return new Tuple(null, null, null); + return new ValueTuple(null, null, null); } - var mainParts = invoice.PosData.Split(','); - foreach (var mainPart in mainParts) - { - var parts = mainPart.Split(':'); + var ids = invoice.PosData + .Split(',') + .Select(part => part.Split(':')) + .Where(parts => parts.Length == 2 && Guid.TryParse(parts[1], out _)) + .ToDictionary(parts => parts[0], parts => Guid.Parse(parts[1])); - if (parts.Length <= 1 || !Guid.TryParse(parts[1], out var id)) - { - continue; - } - - switch (parts[0]) - { - case "userId": - userId = id; - break; - case "organizationId": - orgId = id; - break; - case "providerId": - providerId = id; - break; - } - } - - return new Tuple(orgId, userId, providerId); + return new ValueTuple( + ids.TryGetValue(MetadataKeys.OrganizationId, out var id) ? id : null, + ids.TryGetValue(MetadataKeys.UserId, out id) ? id : null, + ids.TryGetValue(MetadataKeys.ProviderId, out id) ? id : null + ); } } diff --git a/src/Billing/Controllers/FreshdeskController.cs b/src/Billing/Controllers/FreshdeskController.cs index 66d4f47d92..38ed05cfdf 100644 --- a/src/Billing/Controllers/FreshdeskController.cs +++ b/src/Billing/Controllers/FreshdeskController.cs @@ -1,7 +1,4 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations; using System.Net.Http.Headers; using System.Reflection; using System.Text; @@ -35,7 +32,7 @@ public class FreshdeskController : Controller GlobalSettings globalSettings, IHttpClientFactory httpClientFactory) { - _billingSettings = billingSettings?.Value; + _billingSettings = billingSettings?.Value ?? throw new ArgumentNullException(nameof(billingSettings)); _userRepository = userRepository; _organizationRepository = organizationRepository; _logger = logger; @@ -101,7 +98,8 @@ public class FreshdeskController : Controller customFields[_billingSettings.FreshDesk.OrgFieldName] += $"\n{orgNote}"; } - var planName = GetAttribute(org.PlanType).Name.Split(" ").FirstOrDefault(); + var displayAttribute = GetAttribute(org.PlanType); + var planName = displayAttribute?.Name?.Split(" ").FirstOrDefault(); if (!string.IsNullOrWhiteSpace(planName)) { tags.Add(string.Format("Org: {0}", planName)); @@ -159,28 +157,22 @@ public class FreshdeskController : Controller return Ok(); } - // create the onyx `answer-with-citation` request - var onyxRequestModel = new OnyxAnswerWithCitationRequestModel(model.TicketDescriptionText, _billingSettings.Onyx.PersonaId); - var onyxRequest = new HttpRequestMessage(HttpMethod.Post, - string.Format("{0}/query/answer-with-citation", _billingSettings.Onyx.BaseUrl)) - { - Content = JsonContent.Create(onyxRequestModel, mediaType: new MediaTypeHeaderValue("application/json")), - }; - var (_, onyxJsonResponse) = await CallOnyxApi(onyxRequest); + // Get response from Onyx AI + var (onyxRequest, onyxResponse) = await GetAnswerFromOnyx(model); // the CallOnyxApi will return a null if we have an error response - if (onyxJsonResponse?.Answer == null || !string.IsNullOrEmpty(onyxJsonResponse?.ErrorMsg)) + if (onyxResponse?.Answer == null || !string.IsNullOrEmpty(onyxResponse?.ErrorMsg)) { _logger.LogWarning("Error getting answer from Onyx AI. Freshdesk model: {model}\r\n Onyx query {query}\r\nresponse: {response}. ", JsonSerializer.Serialize(model), - JsonSerializer.Serialize(onyxRequestModel), - JsonSerializer.Serialize(onyxJsonResponse)); + JsonSerializer.Serialize(onyxRequest), + JsonSerializer.Serialize(onyxResponse)); return Ok(); // return ok so we don't retry } // add the answer as a note to the ticket - await AddAnswerNoteToTicketAsync(onyxJsonResponse.Answer, model.TicketId); + await AddAnswerNoteToTicketAsync(onyxResponse?.Answer ?? string.Empty, model.TicketId); return Ok(); } @@ -206,27 +198,21 @@ public class FreshdeskController : Controller } // create the onyx `answer-with-citation` request - var onyxRequestModel = new OnyxAnswerWithCitationRequestModel(model.TicketDescriptionText, _billingSettings.Onyx.PersonaId); - var onyxRequest = new HttpRequestMessage(HttpMethod.Post, - string.Format("{0}/query/answer-with-citation", _billingSettings.Onyx.BaseUrl)) - { - Content = JsonContent.Create(onyxRequestModel, mediaType: new MediaTypeHeaderValue("application/json")), - }; - var (_, onyxJsonResponse) = await CallOnyxApi(onyxRequest); + var (onyxRequest, onyxResponse) = await GetAnswerFromOnyx(model); // the CallOnyxApi will return a null if we have an error response - if (onyxJsonResponse?.Answer == null || !string.IsNullOrEmpty(onyxJsonResponse?.ErrorMsg)) + if (onyxResponse?.Answer == null || !string.IsNullOrEmpty(onyxResponse?.ErrorMsg)) { _logger.LogWarning("Error getting answer from Onyx AI. Freshdesk model: {model}\r\n Onyx query {query}\r\nresponse: {response}. ", JsonSerializer.Serialize(model), - JsonSerializer.Serialize(onyxRequestModel), - JsonSerializer.Serialize(onyxJsonResponse)); + JsonSerializer.Serialize(onyxRequest), + JsonSerializer.Serialize(onyxResponse)); return Ok(); // return ok so we don't retry } // add the reply to the ticket - await AddReplyToTicketAsync(onyxJsonResponse.Answer, model.TicketId); + await AddReplyToTicketAsync(onyxResponse?.Answer ?? string.Empty, model.TicketId); return Ok(); } @@ -356,7 +342,32 @@ public class FreshdeskController : Controller return await CallFreshdeskApiAsync(request, retriedCount++); } - private async Task<(HttpResponseMessage, T)> CallOnyxApi(HttpRequestMessage request) + async Task<(OnyxRequestModel onyxRequest, OnyxResponseModel onyxResponse)> GetAnswerFromOnyx(FreshdeskOnyxAiWebhookModel model) + { + // TODO: remove the use of the deprecated answer-with-citation models after we are sure + if (_billingSettings.Onyx.UseAnswerWithCitationModels) + { + var onyxRequest = new OnyxAnswerWithCitationRequestModel(model.TicketDescriptionText, _billingSettings.Onyx); + var onyxAnswerWithCitationRequest = new HttpRequestMessage(HttpMethod.Post, + string.Format("{0}/query/answer-with-citation", _billingSettings.Onyx.BaseUrl)) + { + Content = JsonContent.Create(onyxRequest, mediaType: new MediaTypeHeaderValue("application/json")), + }; + var onyxResponse = await CallOnyxApi(onyxAnswerWithCitationRequest); + return (onyxRequest, onyxResponse); + } + + var request = new OnyxSendMessageSimpleApiRequestModel(model.TicketDescriptionText, _billingSettings.Onyx); + var onyxSimpleRequest = new HttpRequestMessage(HttpMethod.Post, + string.Format("{0}{1}", _billingSettings.Onyx.BaseUrl, _billingSettings.Onyx.Path)) + { + Content = JsonContent.Create(request, mediaType: new MediaTypeHeaderValue("application/json")), + }; + var onyxSimpleResponse = await CallOnyxApi(onyxSimpleRequest); + return (request, onyxSimpleResponse); + } + + private async Task CallOnyxApi(HttpRequestMessage request) where T : class, new() { var httpClient = _httpClientFactory.CreateClient("OnyxApi"); var response = await httpClient.SendAsync(request); @@ -365,7 +376,7 @@ public class FreshdeskController : Controller { _logger.LogError("Error calling Onyx AI API. Status code: {0}. Response {1}", response.StatusCode, JsonSerializer.Serialize(response)); - return (null, default); + return new T(); } var responseStr = await response.Content.ReadAsStringAsync(); var responseJson = JsonSerializer.Deserialize(responseStr, options: new JsonSerializerOptions @@ -373,11 +384,12 @@ public class FreshdeskController : Controller PropertyNameCaseInsensitive = true, }); - return (response, responseJson); + return responseJson ?? new T(); } - private TAttribute GetAttribute(Enum enumValue) where TAttribute : Attribute + private TAttribute? GetAttribute(Enum enumValue) where TAttribute : Attribute { - return enumValue.GetType().GetMember(enumValue.ToString()).First().GetCustomAttribute(); + var memberInfo = enumValue.GetType().GetMember(enumValue.ToString()).FirstOrDefault(); + return memberInfo != null ? memberInfo.GetCustomAttribute() : null; } } diff --git a/src/Billing/Controllers/FreshsalesController.cs b/src/Billing/Controllers/FreshsalesController.cs index be5a9ddb16..68382fbd5d 100644 --- a/src/Billing/Controllers/FreshsalesController.cs +++ b/src/Billing/Controllers/FreshsalesController.cs @@ -158,6 +158,7 @@ public class FreshsalesController : Controller planName = "Free"; return true; case PlanType.FamiliesAnnually: + case PlanType.FamiliesAnnually2025: case PlanType.FamiliesAnnually2019: planName = "Families"; return true; diff --git a/src/Billing/Controllers/StripeController.cs b/src/Billing/Controllers/StripeController.cs index b60e0c56e4..18f2198119 100644 --- a/src/Billing/Controllers/StripeController.cs +++ b/src/Billing/Controllers/StripeController.cs @@ -120,9 +120,7 @@ public class StripeController : Controller return deliveryContainer.ApiVersion switch { - "2024-06-20" => HandleVersionWith(_billingSettings.StripeWebhookSecret20240620), - "2023-10-16" => HandleVersionWith(_billingSettings.StripeWebhookSecret20231016), - "2022-08-01" => HandleVersionWith(_billingSettings.StripeWebhookSecret), + "2025-08-27.basil" => HandleVersionWith(_billingSettings.StripeWebhookSecret20250827Basil), _ => HandleDefault(deliveryContainer.ApiVersion) }; diff --git a/src/Billing/Jobs/ProviderOrganizationDisableJob.cs b/src/Billing/Jobs/ProviderOrganizationDisableJob.cs new file mode 100644 index 0000000000..5a48dd609f --- /dev/null +++ b/src/Billing/Jobs/ProviderOrganizationDisableJob.cs @@ -0,0 +1,88 @@ +// FIXME: Update this file to be null safe and then delete the line below +#nullable disable + +using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; +using Bit.Core.AdminConsole.Repositories; +using Quartz; + +namespace Bit.Billing.Jobs; + +public class ProviderOrganizationDisableJob( + IProviderOrganizationRepository providerOrganizationRepository, + IOrganizationDisableCommand organizationDisableCommand, + ILogger logger) + : IJob +{ + private const int MaxConcurrency = 5; + private const int MaxTimeoutMinutes = 10; + + public async Task Execute(IJobExecutionContext context) + { + var providerId = new Guid(context.MergedJobDataMap.GetString("providerId") ?? string.Empty); + var expirationDateString = context.MergedJobDataMap.GetString("expirationDate"); + DateTime? expirationDate = string.IsNullOrEmpty(expirationDateString) + ? null + : DateTime.Parse(expirationDateString); + + logger.LogInformation("Starting to disable organizations for provider {ProviderId}", providerId); + + var startTime = DateTime.UtcNow; + var totalProcessed = 0; + var totalErrors = 0; + + try + { + var providerOrganizations = await providerOrganizationRepository + .GetManyDetailsByProviderAsync(providerId); + + if (providerOrganizations == null || !providerOrganizations.Any()) + { + logger.LogInformation("No organizations found for provider {ProviderId}", providerId); + return; + } + + logger.LogInformation("Disabling {OrganizationCount} organizations for provider {ProviderId}", + providerOrganizations.Count, providerId); + + var semaphore = new SemaphoreSlim(MaxConcurrency, MaxConcurrency); + var tasks = providerOrganizations.Select(async po => + { + if (DateTime.UtcNow.Subtract(startTime).TotalMinutes > MaxTimeoutMinutes) + { + logger.LogWarning("Timeout reached while disabling organizations for provider {ProviderId}", providerId); + return false; + } + + await semaphore.WaitAsync(); + try + { + await organizationDisableCommand.DisableAsync(po.OrganizationId, expirationDate); + Interlocked.Increment(ref totalProcessed); + return true; + } + catch (Exception ex) + { + logger.LogError(ex, "Failed to disable organization {OrganizationId} for provider {ProviderId}", + po.OrganizationId, providerId); + Interlocked.Increment(ref totalErrors); + return false; + } + finally + { + semaphore.Release(); + } + }); + + await Task.WhenAll(tasks); + + logger.LogInformation("Completed disabling organizations for provider {ProviderId}. Processed: {TotalProcessed}, Errors: {TotalErrors}", + providerId, totalProcessed, totalErrors); + } + catch (Exception ex) + { + logger.LogError(ex, "Error disabling organizations for provider {ProviderId}. Processed: {TotalProcessed}, Errors: {TotalErrors}", + providerId, totalProcessed, totalErrors); + throw; + } + } +} diff --git a/src/Billing/Models/OnyxAnswerWithCitationRequestModel.cs b/src/Billing/Models/OnyxAnswerWithCitationRequestModel.cs index ba3b89e297..9a753be4bc 100644 --- a/src/Billing/Models/OnyxAnswerWithCitationRequestModel.cs +++ b/src/Billing/Models/OnyxAnswerWithCitationRequestModel.cs @@ -1,35 +1,58 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - - -using System.Text.Json.Serialization; +using System.Text.Json.Serialization; +using static Bit.Billing.BillingSettings; namespace Bit.Billing.Models; -public class OnyxAnswerWithCitationRequestModel +public class OnyxRequestModel { - [JsonPropertyName("messages")] - public List Messages { get; set; } - [JsonPropertyName("persona_id")] public int PersonaId { get; set; } = 1; [JsonPropertyName("retrieval_options")] - public RetrievalOptions RetrievalOptions { get; set; } + public RetrievalOptions RetrievalOptions { get; set; } = new RetrievalOptions(); - public OnyxAnswerWithCitationRequestModel(string message, int personaId = 1) + public OnyxRequestModel(OnyxSettings onyxSettings) + { + PersonaId = onyxSettings.PersonaId; + RetrievalOptions.RunSearch = onyxSettings.SearchSettings.RunSearch; + RetrievalOptions.RealTime = onyxSettings.SearchSettings.RealTime; + } +} + +/// +/// This is used with the onyx endpoint /query/answer-with-citation +/// which has been deprecated. This can be removed once later +/// +public class OnyxAnswerWithCitationRequestModel : OnyxRequestModel +{ + [JsonPropertyName("messages")] + public List Messages { get; set; } = new List(); + + public OnyxAnswerWithCitationRequestModel(string message, OnyxSettings onyxSettings) : base(onyxSettings) { message = message.Replace(Environment.NewLine, " ").Replace('\r', ' ').Replace('\n', ' '); Messages = new List() { new Message() { MessageText = message } }; - RetrievalOptions = new RetrievalOptions(); - PersonaId = personaId; + } +} + +/// +/// This is used with the onyx endpoint /chat/send-message-simple-api +/// +public class OnyxSendMessageSimpleApiRequestModel : OnyxRequestModel +{ + [JsonPropertyName("message")] + public string Message { get; set; } = string.Empty; + + public OnyxSendMessageSimpleApiRequestModel(string message, OnyxSettings onyxSettings) : base(onyxSettings) + { + Message = message.Replace(Environment.NewLine, " ").Replace('\r', ' ').Replace('\n', ' '); } } public class Message { [JsonPropertyName("message")] - public string MessageText { get; set; } + public string MessageText { get; set; } = string.Empty; [JsonPropertyName("sender")] public string Sender { get; set; } = "user"; diff --git a/src/Billing/Models/OnyxAnswerWithCitationResponseModel.cs b/src/Billing/Models/OnyxAnswerWithCitationResponseModel.cs deleted file mode 100644 index 5f67cd51d2..0000000000 --- a/src/Billing/Models/OnyxAnswerWithCitationResponseModel.cs +++ /dev/null @@ -1,33 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Text.Json.Serialization; - -namespace Bit.Billing.Models; - -public class OnyxAnswerWithCitationResponseModel -{ - [JsonPropertyName("answer")] - public string Answer { get; set; } - - [JsonPropertyName("rephrase")] - public string Rephrase { get; set; } - - [JsonPropertyName("citations")] - public List Citations { get; set; } - - [JsonPropertyName("llm_selected_doc_indices")] - public List LlmSelectedDocIndices { get; set; } - - [JsonPropertyName("error_msg")] - public string ErrorMsg { get; set; } -} - -public class Citation -{ - [JsonPropertyName("citation_num")] - public int CitationNum { get; set; } - - [JsonPropertyName("document_id")] - public string DocumentId { get; set; } -} diff --git a/src/Billing/Models/OnyxResponseModel.cs b/src/Billing/Models/OnyxResponseModel.cs new file mode 100644 index 0000000000..96fa134c40 --- /dev/null +++ b/src/Billing/Models/OnyxResponseModel.cs @@ -0,0 +1,15 @@ +using System.Text.Json.Serialization; + +namespace Bit.Billing.Models; + +public class OnyxResponseModel +{ + [JsonPropertyName("answer")] + public string Answer { get; set; } = string.Empty; + + [JsonPropertyName("answer_citationless")] + public string AnswerCitationless { get; set; } = string.Empty; + + [JsonPropertyName("error_msg")] + public string ErrorMsg { get; set; } = string.Empty; +} diff --git a/src/Billing/Services/Implementations/InvoiceCreatedHandler.cs b/src/Billing/Services/Implementations/InvoiceCreatedHandler.cs index 5bb098bec5..101b0e26b9 100644 --- a/src/Billing/Services/Implementations/InvoiceCreatedHandler.cs +++ b/src/Billing/Services/Implementations/InvoiceCreatedHandler.cs @@ -1,4 +1,5 @@ -using Event = Stripe.Event; +using Bit.Core.Billing.Constants; +using Event = Stripe.Event; namespace Bit.Billing.Services.Implementations; @@ -35,13 +36,13 @@ public class InvoiceCreatedHandler( if (usingPayPal && invoice is { AmountDue: > 0, - Paid: false, + Status: not StripeConstants.InvoiceStatus.Paid, CollectionMethod: "charge_automatically", BillingReason: "subscription_create" or "subscription_cycle" or "automatic_pending_invoice_item_invoice", - SubscriptionId: not null and not "" + Parent.SubscriptionDetails: not null }) { await stripeEventUtilityService.AttemptToPayInvoiceAsync(invoice); diff --git a/src/Billing/Services/Implementations/PaymentFailedHandler.cs b/src/Billing/Services/Implementations/PaymentFailedHandler.cs index acf6ca70c7..0da6d03e94 100644 --- a/src/Billing/Services/Implementations/PaymentFailedHandler.cs +++ b/src/Billing/Services/Implementations/PaymentFailedHandler.cs @@ -1,4 +1,5 @@ -using Stripe; +using Bit.Core.Billing.Constants; +using Stripe; using Event = Stripe.Event; namespace Bit.Billing.Services.Implementations; @@ -26,17 +27,20 @@ public class PaymentFailedHandler : IPaymentFailedHandler public async Task HandleAsync(Event parsedEvent) { var invoice = await _stripeEventService.GetInvoice(parsedEvent, true); - if (invoice.Paid || invoice.AttemptCount <= 1 || !ShouldAttemptToPayInvoice(invoice)) + if (invoice.Status == StripeConstants.InvoiceStatus.Paid || invoice.AttemptCount <= 1 || !ShouldAttemptToPayInvoice(invoice)) { return; } - var subscription = await _stripeFacade.GetSubscription(invoice.SubscriptionId); - // attempt count 4 = 11 days after initial failure - if (invoice.AttemptCount <= 3 || - !subscription.Items.Any(i => i.Price.Id is IStripeEventUtilityService.PremiumPlanId or IStripeEventUtilityService.PremiumPlanIdAppStore)) + if (invoice.Parent?.SubscriptionDetails != null) { - await _stripeEventUtilityService.AttemptToPayInvoiceAsync(invoice); + var subscription = await _stripeFacade.GetSubscription(invoice.Parent.SubscriptionDetails.SubscriptionId); + // attempt count 4 = 11 days after initial failure + if (invoice.AttemptCount <= 3 || + !subscription.Items.Any(i => i.Price.Id is IStripeEventUtilityService.PremiumPlanId or IStripeEventUtilityService.PremiumPlanIdAppStore)) + { + await _stripeEventUtilityService.AttemptToPayInvoiceAsync(invoice); + } } } @@ -44,9 +48,9 @@ public class PaymentFailedHandler : IPaymentFailedHandler invoice is { AmountDue: > 0, - Paid: false, + Status: not StripeConstants.InvoiceStatus.Paid, CollectionMethod: "charge_automatically", BillingReason: "subscription_cycle" or "automatic_pending_invoice_item_invoice", - SubscriptionId: not null + Parent.SubscriptionDetails: not null }; } diff --git a/src/Billing/Services/Implementations/PaymentSucceededHandler.cs b/src/Billing/Services/Implementations/PaymentSucceededHandler.cs index a10fa4b3d6..443227f7bf 100644 --- a/src/Billing/Services/Implementations/PaymentSucceededHandler.cs +++ b/src/Billing/Services/Implementations/PaymentSucceededHandler.cs @@ -1,7 +1,9 @@ using Bit.Billing.Constants; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Pricing; using Bit.Core.Repositories; using Bit.Core.Services; @@ -29,12 +31,17 @@ public class PaymentSucceededHandler( public async Task HandleAsync(Event parsedEvent) { var invoice = await stripeEventService.GetInvoice(parsedEvent, true); - if (!invoice.Paid || invoice.BillingReason != "subscription_create") + if (invoice.Status != StripeConstants.InvoiceStatus.Paid || invoice.BillingReason != "subscription_create") { return; } - var subscription = await stripeFacade.GetSubscription(invoice.SubscriptionId); + if (invoice.Parent?.SubscriptionDetails == null) + { + return; + } + + var subscription = await stripeFacade.GetSubscription(invoice.Parent.SubscriptionDetails.SubscriptionId); if (subscription?.Status != StripeSubscriptionStatus.Active) { return; @@ -96,7 +103,7 @@ public class PaymentSucceededHandler( return; } - await organizationEnableCommand.EnableAsync(organizationId.Value, subscription.CurrentPeriodEnd); + await organizationEnableCommand.EnableAsync(organizationId.Value, subscription.GetCurrentPeriodEnd()); organization = await organizationRepository.GetByIdAsync(organization.Id); await pushNotificationAdapter.NotifyEnabledChangedAsync(organization!); } @@ -107,7 +114,7 @@ public class PaymentSucceededHandler( return; } - await userService.EnablePremiumAsync(userId.Value, subscription.CurrentPeriodEnd); + await userService.EnablePremiumAsync(userId.Value, subscription.GetCurrentPeriodEnd()); } } } diff --git a/src/Billing/Services/Implementations/ProviderEventService.cs b/src/Billing/Services/Implementations/ProviderEventService.cs index 12716c5aa2..79c85cb48f 100644 --- a/src/Billing/Services/Implementations/ProviderEventService.cs +++ b/src/Billing/Services/Implementations/ProviderEventService.cs @@ -28,9 +28,14 @@ public class ProviderEventService( return; } - var invoice = await stripeEventService.GetInvoice(parsedEvent); + var invoice = await stripeEventService.GetInvoice(parsedEvent, true, ["discounts"]); - var metadata = (await stripeFacade.GetSubscription(invoice.SubscriptionId)).Metadata ?? new Dictionary(); + if (invoice.Parent is not { Type: "subscription_details" }) + { + return; + } + + var metadata = (await stripeFacade.GetSubscription(invoice.Parent.SubscriptionDetails.SubscriptionId)).Metadata ?? new Dictionary(); var hasProviderId = metadata.TryGetValue("providerId", out var providerId); @@ -68,7 +73,9 @@ public class ProviderEventService( var plan = await pricingClient.GetPlanOrThrow(organization.PlanType); - var discountedPercentage = (100 - (invoice.Discount?.Coupon?.PercentOff ?? 0)) / 100; + var totalPercentOff = invoice.Discounts?.Sum(discount => discount?.Coupon?.PercentOff ?? 0) ?? 0; + + var discountedPercentage = (100 - totalPercentOff) / 100; var discountedSeatPrice = plan.PasswordManager.ProviderPortalSeatPrice * discountedPercentage; @@ -96,7 +103,9 @@ public class ProviderEventService( var unassignedSeats = providerPlan.SeatMinimum - clientSeats ?? 0; - var discountedPercentage = (100 - (invoice.Discount?.Coupon?.PercentOff ?? 0)) / 100; + var totalPercentOff = invoice.Discounts?.Sum(discount => discount?.Coupon?.PercentOff ?? 0) ?? 0; + + var discountedPercentage = (100 - totalPercentOff) / 100; var discountedSeatPrice = plan.PasswordManager.ProviderPortalSeatPrice * discountedPercentage; diff --git a/src/Billing/Services/Implementations/StripeEventUtilityService.cs b/src/Billing/Services/Implementations/StripeEventUtilityService.cs index 4c96bf977d..49e562de56 100644 --- a/src/Billing/Services/Implementations/StripeEventUtilityService.cs +++ b/src/Billing/Services/Implementations/StripeEventUtilityService.cs @@ -2,6 +2,7 @@ #nullable disable using Bit.Billing.Constants; +using Bit.Core.Billing.Constants; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; @@ -87,25 +88,6 @@ public class StripeEventUtilityService : IStripeEventUtilityService /// public async Task<(Guid?, Guid?, Guid?)> GetEntityIdsFromChargeAsync(Charge charge) { - Guid? organizationId = null; - Guid? userId = null; - Guid? providerId = null; - - if (charge.InvoiceId != null) - { - var invoice = await _stripeFacade.GetInvoice(charge.InvoiceId); - if (invoice?.SubscriptionId != null) - { - var subscription = await _stripeFacade.GetSubscription(invoice.SubscriptionId); - (organizationId, userId, providerId) = GetIdsFromMetadata(subscription?.Metadata); - } - } - - if (organizationId.HasValue || userId.HasValue || providerId.HasValue) - { - return (organizationId, userId, providerId); - } - var subscriptions = await _stripeFacade.ListSubscriptions(new SubscriptionListOptions { Customer = charge.CustomerId @@ -118,7 +100,7 @@ public class StripeEventUtilityService : IStripeEventUtilityService continue; } - (organizationId, userId, providerId) = GetIdsFromMetadata(subscription.Metadata); + var (organizationId, userId, providerId) = GetIdsFromMetadata(subscription.Metadata); if (organizationId.HasValue || userId.HasValue || providerId.HasValue) { @@ -256,10 +238,10 @@ public class StripeEventUtilityService : IStripeEventUtilityService invoice is { AmountDue: > 0, - Paid: false, + Status: not StripeConstants.InvoiceStatus.Paid, CollectionMethod: "charge_automatically", BillingReason: "subscription_cycle" or "automatic_pending_invoice_item_invoice", - SubscriptionId: not null + Parent.SubscriptionDetails: not null }; private async Task AttemptToPayInvoiceWithBraintreeAsync(Invoice invoice, Customer customer) @@ -272,7 +254,13 @@ public class StripeEventUtilityService : IStripeEventUtilityService return false; } - var subscription = await _stripeFacade.GetSubscription(invoice.SubscriptionId); + if (invoice.Parent?.SubscriptionDetails == null) + { + _logger.LogWarning("Invoice parent was not a subscription."); + return false; + } + + var subscription = await _stripeFacade.GetSubscription(invoice.Parent.SubscriptionDetails.SubscriptionId); var (organizationId, userId, providerId) = GetIdsFromMetadata(subscription?.Metadata); if (!organizationId.HasValue && !userId.HasValue && !providerId.HasValue) { diff --git a/src/Billing/Services/Implementations/SubscriptionDeletedHandler.cs b/src/Billing/Services/Implementations/SubscriptionDeletedHandler.cs index 465da86c3f..c204cc5026 100644 --- a/src/Billing/Services/Implementations/SubscriptionDeletedHandler.cs +++ b/src/Billing/Services/Implementations/SubscriptionDeletedHandler.cs @@ -1,6 +1,11 @@ using Bit.Billing.Constants; +using Bit.Billing.Jobs; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.Services; +using Bit.Core.Billing.Extensions; using Bit.Core.Services; +using Quartz; using Event = Stripe.Event; namespace Bit.Billing.Services.Implementations; @@ -10,17 +15,26 @@ public class SubscriptionDeletedHandler : ISubscriptionDeletedHandler private readonly IUserService _userService; private readonly IStripeEventUtilityService _stripeEventUtilityService; private readonly IOrganizationDisableCommand _organizationDisableCommand; + private readonly IProviderRepository _providerRepository; + private readonly IProviderService _providerService; + private readonly ISchedulerFactory _schedulerFactory; public SubscriptionDeletedHandler( IStripeEventService stripeEventService, IUserService userService, IStripeEventUtilityService stripeEventUtilityService, - IOrganizationDisableCommand organizationDisableCommand) + IOrganizationDisableCommand organizationDisableCommand, + IProviderRepository providerRepository, + IProviderService providerService, + ISchedulerFactory schedulerFactory) { _stripeEventService = stripeEventService; _userService = userService; _stripeEventUtilityService = stripeEventUtilityService; _organizationDisableCommand = organizationDisableCommand; + _providerRepository = providerRepository; + _providerService = providerService; + _schedulerFactory = schedulerFactory; } /// @@ -50,11 +64,40 @@ public class SubscriptionDeletedHandler : ISubscriptionDeletedHandler return; } - await _organizationDisableCommand.DisableAsync(organizationId.Value, subscription.CurrentPeriodEnd); + await _organizationDisableCommand.DisableAsync(organizationId.Value, subscription.GetCurrentPeriodEnd()); + } + else if (providerId.HasValue) + { + var provider = await _providerRepository.GetByIdAsync(providerId.Value); + if (provider != null) + { + provider.Enabled = false; + await _providerService.UpdateAsync(provider); + + await QueueProviderOrganizationDisableJobAsync(providerId.Value, subscription.GetCurrentPeriodEnd()); + } } else if (userId.HasValue) { - await _userService.DisablePremiumAsync(userId.Value, subscription.CurrentPeriodEnd); + await _userService.DisablePremiumAsync(userId.Value, subscription.GetCurrentPeriodEnd()); } } + + private async Task QueueProviderOrganizationDisableJobAsync(Guid providerId, DateTime? expirationDate) + { + var scheduler = await _schedulerFactory.GetScheduler(); + + var job = JobBuilder.Create() + .WithIdentity($"disable-provider-orgs-{providerId}", "provider-management") + .UsingJobData("providerId", providerId.ToString()) + .UsingJobData("expirationDate", expirationDate?.ToString("O")) + .Build(); + + var trigger = TriggerBuilder.Create() + .WithIdentity($"disable-trigger-{providerId}", "provider-management") + .StartNow() + .Build(); + + await scheduler.ScheduleJob(job, trigger); + } } diff --git a/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs b/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs index 10630f78f4..81aeb460c2 100644 --- a/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs +++ b/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs @@ -5,6 +5,8 @@ using Bit.Core; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Pricing; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; @@ -82,12 +84,14 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler var subscription = await _stripeEventService.GetSubscription(parsedEvent, true, ["customer", "discounts", "latest_invoice", "test_clock"]); var (organizationId, userId, providerId) = _stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata); + var currentPeriodEnd = subscription.GetCurrentPeriodEnd(); + switch (subscription.Status) { case StripeSubscriptionStatus.Unpaid or StripeSubscriptionStatus.IncompleteExpired when organizationId.HasValue: { - await _organizationDisableCommand.DisableAsync(organizationId.Value, subscription.CurrentPeriodEnd); + await _organizationDisableCommand.DisableAsync(organizationId.Value, currentPeriodEnd); if (subscription.Status == StripeSubscriptionStatus.Unpaid && subscription.LatestInvoice is { BillingReason: "subscription_cycle" or "subscription_create" }) { @@ -114,7 +118,7 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler await VoidOpenInvoices(subscription.Id); } - await _userService.DisablePremiumAsync(userId.Value, subscription.CurrentPeriodEnd); + await _userService.DisablePremiumAsync(userId.Value, currentPeriodEnd); break; } @@ -154,7 +158,7 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler { if (userId.HasValue) { - await _userService.EnablePremiumAsync(userId.Value, subscription.CurrentPeriodEnd); + await _userService.EnablePremiumAsync(userId.Value, currentPeriodEnd); } break; } @@ -162,17 +166,17 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler if (organizationId.HasValue) { - await _organizationService.UpdateExpirationDateAsync(organizationId.Value, subscription.CurrentPeriodEnd); - if (_stripeEventUtilityService.IsSponsoredSubscription(subscription)) + await _organizationService.UpdateExpirationDateAsync(organizationId.Value, currentPeriodEnd); + if (_stripeEventUtilityService.IsSponsoredSubscription(subscription) && currentPeriodEnd.HasValue) { - await _organizationSponsorshipRenewCommand.UpdateExpirationDateAsync(organizationId.Value, subscription.CurrentPeriodEnd); + await _organizationSponsorshipRenewCommand.UpdateExpirationDateAsync(organizationId.Value, currentPeriodEnd.Value); } await RemovePasswordManagerCouponIfRemovingSecretsManagerTrialAsync(parsedEvent, subscription); } else if (userId.HasValue) { - await _userService.UpdatePremiumExpirationAsync(userId.Value, subscription.CurrentPeriodEnd); + await _userService.UpdatePremiumExpirationAsync(userId.Value, currentPeriodEnd); } } @@ -280,9 +284,8 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler ?.Coupon ?.Id == "sm-standalone"; - var subscriptionHasSecretsManagerTrial = subscription.Discount - ?.Coupon - ?.Id == "sm-standalone"; + var subscriptionHasSecretsManagerTrial = subscription.Discounts.Select(discount => discount.Coupon.Id) + .Contains(StripeConstants.CouponIDs.SecretsManagerStandalone); if (customerHasSecretsManagerTrial) { diff --git a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs index e5675f7c0a..1db469a4e2 100644 --- a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs +++ b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs @@ -1,7 +1,4 @@ -// FIXME: Update this file to be null safe and then delete the line below - -#nullable disable - +using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Repositories; @@ -10,14 +7,20 @@ using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Pricing; +using Bit.Core.Entities; +using Bit.Core.Models.Mail.UpdatedInvoiceIncoming; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; +using Bit.Core.Platform.Mail.Mailer; using Bit.Core.Repositories; using Bit.Core.Services; using Stripe; using Event = Stripe.Event; +using Plan = Bit.Core.Models.StaticStore.Plan; namespace Bit.Billing.Services.Implementations; +using static StripeConstants; + public class UpcomingInvoiceHandler( IGetPaymentMethodQuery getPaymentMethodQuery, ILogger logger, @@ -29,138 +32,412 @@ public class UpcomingInvoiceHandler( IStripeEventService stripeEventService, IStripeEventUtilityService stripeEventUtilityService, IUserRepository userRepository, - IValidateSponsorshipCommand validateSponsorshipCommand) + IValidateSponsorshipCommand validateSponsorshipCommand, + IMailer mailer, + IFeatureService featureService) : IUpcomingInvoiceHandler { public async Task HandleAsync(Event parsedEvent) { var invoice = await stripeEventService.GetInvoice(parsedEvent); - if (string.IsNullOrEmpty(invoice.SubscriptionId)) + var customer = + await stripeFacade.GetCustomer(invoice.CustomerId, + new CustomerGetOptions { Expand = ["subscriptions", "tax", "tax_ids"] }); + + var subscription = customer.Subscriptions.FirstOrDefault(); + + if (subscription == null) { - logger.LogInformation("Received 'invoice.upcoming' Event with ID '{eventId}' that did not include a Subscription ID", parsedEvent.Id); return; } - var subscription = await stripeFacade.GetSubscription(invoice.SubscriptionId, new SubscriptionGetOptions - { - Expand = ["customer.tax", "customer.tax_ids"] - }); - var (organizationId, userId, providerId) = stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata); if (organizationId.HasValue) { - var organization = await organizationRepository.GetByIdAsync(organizationId.Value); - - if (organization == null) - { - return; - } - - await AlignOrganizationTaxConcernsAsync(organization, subscription, parsedEvent.Id); - - var plan = await pricingClient.GetPlanOrThrow(organization.PlanType); - - if (!plan.IsAnnual) - { - return; - } - - if (stripeEventUtilityService.IsSponsoredSubscription(subscription)) - { - var sponsorshipIsValid = await validateSponsorshipCommand.ValidateSponsorshipAsync(organizationId.Value); - - if (!sponsorshipIsValid) - { - /* - * If the sponsorship is invalid, then the subscription was updated to use the regular families plan - * price. Given that this is the case, we need the new invoice amount - */ - invoice = await stripeFacade.GetInvoice(subscription.LatestInvoiceId); - } - } - - await SendUpcomingInvoiceEmailsAsync(new List { organization.BillingEmail }, invoice); - - /* - * TODO: https://bitwarden.atlassian.net/browse/PM-4862 - * Disabling this as part of a hot fix. It needs to check whether the organization - * belongs to a Reseller provider and only send an email to the organization owners if it does. - * It also requires a new email template as the current one contains too much billing information. - */ - - // var ownerEmails = await _organizationRepository.GetOwnerEmailAddressesById(organization.Id); - - // await SendEmails(ownerEmails); + await HandleOrganizationUpcomingInvoiceAsync( + organizationId.Value, + parsedEvent, + invoice, + customer, + subscription); } else if (userId.HasValue) { - var user = await userRepository.GetByIdAsync(userId.Value); - - if (user == null) - { - return; - } - - if (!subscription.AutomaticTax.Enabled && subscription.Customer.HasRecognizedTaxLocation()) - { - try - { - await stripeFacade.UpdateSubscription(subscription.Id, - new SubscriptionUpdateOptions - { - AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } - }); - } - catch (Exception exception) - { - logger.LogError( - exception, - "Failed to set user's ({UserID}) subscription to automatic tax while processing event with ID {EventID}", - user.Id, - parsedEvent.Id); - } - } - - if (user.Premium) - { - await SendUpcomingInvoiceEmailsAsync(new List { user.Email }, invoice); - } + await HandlePremiumUsersUpcomingInvoiceAsync( + userId.Value, + parsedEvent, + invoice, + customer, + subscription); } else if (providerId.HasValue) { - var provider = await providerRepository.GetByIdAsync(providerId.Value); + await HandleProviderUpcomingInvoiceAsync( + providerId.Value, + parsedEvent, + invoice, + customer, + subscription); + } + } - if (provider == null) + #region Organizations + + private async Task HandleOrganizationUpcomingInvoiceAsync( + Guid organizationId, + Event @event, + Invoice invoice, + Customer customer, + Subscription subscription) + { + var organization = await organizationRepository.GetByIdAsync(organizationId); + + if (organization == null) + { + logger.LogWarning("Could not find Organization ({OrganizationID}) for '{EventType}' event ({EventID})", + organizationId, @event.Type, @event.Id); + return; + } + + await AlignOrganizationTaxConcernsAsync(organization, subscription, customer, @event.Id); + + var plan = await pricingClient.GetPlanOrThrow(organization.PlanType); + + var milestone3 = featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3); + + await AlignOrganizationSubscriptionConcernsAsync( + organization, + @event, + subscription, + plan, + milestone3); + + // Don't send the upcoming invoice email unless the organization's on an annual plan. + if (!plan.IsAnnual) + { + return; + } + + if (stripeEventUtilityService.IsSponsoredSubscription(subscription)) + { + var sponsorshipIsValid = + await validateSponsorshipCommand.ValidateSponsorshipAsync(organizationId); + + if (!sponsorshipIsValid) { + /* + * If the sponsorship is invalid, then the subscription was updated to use the regular families plan + * price. Given that this is the case, we need the new invoice amount + */ + invoice = await stripeFacade.GetInvoice(subscription.LatestInvoiceId); + } + } + + await (milestone3 + ? SendUpdatedUpcomingInvoiceEmailsAsync([organization.BillingEmail]) + : SendUpcomingInvoiceEmailsAsync([organization.BillingEmail], invoice)); + } + + private async Task AlignOrganizationTaxConcernsAsync( + Organization organization, + Subscription subscription, + Customer customer, + string eventId) + { + var nonUSBusinessUse = + organization.PlanType.GetProductTier() != ProductTierType.Families && + customer.Address.Country != Core.Constants.CountryAbbreviations.UnitedStates; + + if (nonUSBusinessUse && customer.TaxExempt != TaxExempt.Reverse) + { + try + { + await stripeFacade.UpdateCustomer(subscription.CustomerId, + new CustomerUpdateOptions { TaxExempt = TaxExempt.Reverse }); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to set organization's ({OrganizationID}) to reverse tax exemption while processing event with ID {EventID}", + organization.Id, + eventId); + } + } + + if (!subscription.AutomaticTax.Enabled) + { + try + { + await stripeFacade.UpdateSubscription(subscription.Id, + new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } + }); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to set organization's ({OrganizationID}) subscription to automatic tax while processing event with ID {EventID}", + organization.Id, + eventId); + } + } + } + + private async Task AlignOrganizationSubscriptionConcernsAsync( + Organization organization, + Event @event, + Subscription subscription, + Plan plan, + bool milestone3) + { + if (milestone3 && plan.Type == PlanType.FamiliesAnnually2019) + { + var passwordManagerItem = + subscription.Items.FirstOrDefault(item => item.Price.Id == plan.PasswordManager.StripePlanId); + + if (passwordManagerItem == null) + { + logger.LogWarning("Could not find Organization's ({OrganizationId}) password manager item while processing '{EventType}' event ({EventID})", + organization.Id, @event.Type, @event.Id); return; } - await AlignProviderTaxConcernsAsync(provider, subscription, parsedEvent.Id); + var families = await pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually); - await SendProviderUpcomingInvoiceEmailsAsync(new List { provider.BillingEmail }, invoice, subscription, providerId.Value); + organization.PlanType = families.Type; + organization.Plan = families.Name; + organization.UsersGetPremium = families.UsersGetPremium; + organization.Seats = families.PasswordManager.BaseSeats; + + var options = new SubscriptionUpdateOptions + { + Items = + [ + new SubscriptionItemOptions + { + Id = passwordManagerItem.Id, Price = families.PasswordManager.StripePlanId + } + ], + Discounts = + [ + new SubscriptionDiscountOptions { Coupon = CouponIDs.Milestone3SubscriptionDiscount } + ], + ProrationBehavior = ProrationBehavior.None + }; + + var premiumAccessAddOnItem = subscription.Items.FirstOrDefault(item => + item.Price.Id == plan.PasswordManager.StripePremiumAccessPlanId); + + if (premiumAccessAddOnItem != null) + { + options.Items.Add(new SubscriptionItemOptions + { + Id = premiumAccessAddOnItem.Id, + Deleted = true + }); + } + + try + { + await organizationRepository.ReplaceAsync(organization); + await stripeFacade.UpdateSubscription(subscription.Id, options); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to align subscription concerns for Organization ({OrganizationID}) while processing '{EventType}' event ({EventID})", + organization.Id, + @event.Type, + @event.Id); + } } } - private async Task SendUpcomingInvoiceEmailsAsync(IEnumerable emails, Invoice invoice) + #endregion + + #region Premium Users + + private async Task HandlePremiumUsersUpcomingInvoiceAsync( + Guid userId, + Event @event, + Invoice invoice, + Customer customer, + Subscription subscription) { - var validEmails = emails.Where(e => !string.IsNullOrEmpty(e)); + var user = await userRepository.GetByIdAsync(userId); - var items = invoice.Lines.Select(i => i.Description).ToList(); - - if (invoice.NextPaymentAttempt.HasValue && invoice.AmountDue > 0) + if (user == null) { - await mailService.SendInvoiceUpcoming( - validEmails, - invoice.AmountDue / 100M, - invoice.NextPaymentAttempt.Value, - items, - true); + logger.LogWarning("Could not find User ({UserID}) for '{EventType}' event ({EventID})", + userId, @event.Type, @event.Id); + return; + } + + await AlignPremiumUsersTaxConcernsAsync(user, @event, customer, subscription); + + var milestone2Feature = featureService.IsEnabled(FeatureFlagKeys.PM23341_Milestone_2); + if (milestone2Feature) + { + await AlignPremiumUsersSubscriptionConcernsAsync(user, @event, subscription); + } + + if (user.Premium) + { + await (milestone2Feature + ? SendUpdatedUpcomingInvoiceEmailsAsync(new List { user.Email }) + : SendUpcomingInvoiceEmailsAsync(new List { user.Email }, invoice)); } } - private async Task SendProviderUpcomingInvoiceEmailsAsync(IEnumerable emails, Invoice invoice, Subscription subscription, Guid providerId) + private async Task AlignPremiumUsersTaxConcernsAsync( + User user, + Event @event, + Customer customer, + Subscription subscription) + { + if (!subscription.AutomaticTax.Enabled && customer.HasRecognizedTaxLocation()) + { + try + { + await stripeFacade.UpdateSubscription(subscription.Id, + new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } + }); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to set user's ({UserID}) subscription to automatic tax while processing event with ID {EventID}", + user.Id, + @event.Id); + } + } + } + + private async Task AlignPremiumUsersSubscriptionConcernsAsync( + User user, + Event @event, + Subscription subscription) + { + var premiumItem = subscription.Items.FirstOrDefault(i => i.Price.Id == Prices.PremiumAnnually); + + if (premiumItem == null) + { + logger.LogWarning("Could not find User's ({UserID}) premium subscription item while processing '{EventType}' event ({EventID})", + user.Id, @event.Type, @event.Id); + return; + } + + try + { + var plan = await pricingClient.GetAvailablePremiumPlan(); + await stripeFacade.UpdateSubscription(subscription.Id, + new SubscriptionUpdateOptions + { + Items = + [ + new SubscriptionItemOptions { Id = premiumItem.Id, Price = plan.Seat.StripePriceId } + ], + Discounts = + [ + new SubscriptionDiscountOptions { Coupon = CouponIDs.Milestone2SubscriptionDiscount } + ], + ProrationBehavior = ProrationBehavior.None + }); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to update user's ({UserID}) subscription price id while processing event with ID {EventID}", + user.Id, + @event.Id); + } + } + + #endregion + + #region Providers + + private async Task HandleProviderUpcomingInvoiceAsync( + Guid providerId, + Event @event, + Invoice invoice, + Customer customer, + Subscription subscription) + { + var provider = await providerRepository.GetByIdAsync(providerId); + + if (provider == null) + { + logger.LogWarning("Could not find Provider ({ProviderID}) for '{EventType}' event ({EventID})", + providerId, @event.Type, @event.Id); + return; + } + + await AlignProviderTaxConcernsAsync(provider, subscription, customer, @event.Id); + + if (!string.IsNullOrEmpty(provider.BillingEmail)) + { + await SendProviderUpcomingInvoiceEmailsAsync(new List { provider.BillingEmail }, invoice, subscription, providerId); + } + } + + private async Task AlignProviderTaxConcernsAsync( + Provider provider, + Subscription subscription, + Customer customer, + string eventId) + { + if (customer.Address.Country != Core.Constants.CountryAbbreviations.UnitedStates && + customer.TaxExempt != TaxExempt.Reverse) + { + try + { + await stripeFacade.UpdateCustomer(subscription.CustomerId, + new CustomerUpdateOptions { TaxExempt = TaxExempt.Reverse }); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to set provider's ({ProviderID}) to reverse tax exemption while processing event with ID {EventID}", + provider.Id, + eventId); + } + } + + if (!subscription.AutomaticTax.Enabled) + { + try + { + await stripeFacade.UpdateSubscription(subscription.Id, + new SubscriptionUpdateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } + }); + } + catch (Exception exception) + { + logger.LogError( + exception, + "Failed to set provider's ({ProviderID}) subscription to automatic tax while processing event with ID {EventID}", + provider.Id, + eventId); + } + } + } + + private async Task SendProviderUpcomingInvoiceEmailsAsync(IEnumerable emails, Invoice invoice, + Subscription subscription, Guid providerId) { var validEmails = emails.Where(e => !string.IsNullOrEmpty(e)); @@ -196,94 +473,37 @@ public class UpcomingInvoiceHandler( } } - private async Task AlignOrganizationTaxConcernsAsync( - Organization organization, - Subscription subscription, - string eventId) + #endregion + + #region Shared + + private async Task SendUpcomingInvoiceEmailsAsync(IEnumerable emails, Invoice invoice) { - var nonUSBusinessUse = - organization.PlanType.GetProductTier() != ProductTierType.Families && - subscription.Customer.Address.Country != Core.Constants.CountryAbbreviations.UnitedStates; + var validEmails = emails.Where(e => !string.IsNullOrEmpty(e)); - if (nonUSBusinessUse && subscription.Customer.TaxExempt != StripeConstants.TaxExempt.Reverse) - { - try - { - await stripeFacade.UpdateCustomer(subscription.CustomerId, - new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse }); - } - catch (Exception exception) - { - logger.LogError( - exception, - "Failed to set organization's ({OrganizationID}) to reverse tax exemption while processing event with ID {EventID}", - organization.Id, - eventId); - } - } + var items = invoice.Lines.Select(i => i.Description).ToList(); - if (!subscription.AutomaticTax.Enabled) + if (invoice is { NextPaymentAttempt: not null, AmountDue: > 0 }) { - try - { - await stripeFacade.UpdateSubscription(subscription.Id, - new SubscriptionUpdateOptions - { - AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } - }); - } - catch (Exception exception) - { - logger.LogError( - exception, - "Failed to set organization's ({OrganizationID}) subscription to automatic tax while processing event with ID {EventID}", - organization.Id, - eventId); - } + await mailService.SendInvoiceUpcoming( + validEmails, + invoice.AmountDue / 100M, + invoice.NextPaymentAttempt.Value, + items, + true); } } - private async Task AlignProviderTaxConcernsAsync( - Provider provider, - Subscription subscription, - string eventId) + private async Task SendUpdatedUpcomingInvoiceEmailsAsync(IEnumerable emails) { - if (subscription.Customer.Address.Country != Core.Constants.CountryAbbreviations.UnitedStates && - subscription.Customer.TaxExempt != StripeConstants.TaxExempt.Reverse) + var validEmails = emails.Where(e => !string.IsNullOrEmpty(e)); + var updatedUpcomingEmail = new UpdatedInvoiceUpcomingMail { - try - { - await stripeFacade.UpdateCustomer(subscription.CustomerId, - new CustomerUpdateOptions { TaxExempt = StripeConstants.TaxExempt.Reverse }); - } - catch (Exception exception) - { - logger.LogError( - exception, - "Failed to set provider's ({ProviderID}) to reverse tax exemption while processing event with ID {EventID}", - provider.Id, - eventId); - } - } - - if (!subscription.AutomaticTax.Enabled) - { - try - { - await stripeFacade.UpdateSubscription(subscription.Id, - new SubscriptionUpdateOptions - { - AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true } - }); - } - catch (Exception exception) - { - logger.LogError( - exception, - "Failed to set provider's ({ProviderID}) subscription to automatic tax while processing event with ID {EventID}", - provider.Id, - eventId); - } - } + ToEmails = validEmails, + View = new UpdatedInvoiceUpcomingView() + }; + await mailer.SendEmail(updatedUpcomingEmail); } + + #endregion } diff --git a/src/Billing/Startup.cs b/src/Billing/Startup.cs index 5b464d5ef6..cdb9700ad5 100644 --- a/src/Billing/Startup.cs +++ b/src/Billing/Startup.cs @@ -51,9 +51,6 @@ public class Startup // Repositories services.AddDatabaseRepositories(globalSettings); - // BitPay Client - services.AddSingleton(); - // PayPal IPN Client services.AddHttpClient(); diff --git a/src/Billing/appsettings.json b/src/Billing/appsettings.json index 0074b5aafe..a2d6acd0a1 100644 --- a/src/Billing/appsettings.json +++ b/src/Billing/appsettings.json @@ -57,9 +57,7 @@ "billingSettings": { "jobsKey": "SECRET", "stripeWebhookKey": "SECRET", - "stripeWebhookSecret": "SECRET", - "stripeWebhookSecret20231016": "SECRET", - "stripeWebhookSecret20240620": "SECRET", + "stripeWebhookSecret20250827Basil": "SECRET", "bitPayWebhookKey": "SECRET", "appleWebhookKey": "SECRET", "payPal": { @@ -80,7 +78,13 @@ "onyx": { "apiKey": "SECRET", "baseUrl": "https://cloud.onyx.app/api", - "personaId": 7 - } + "path": "/chat/send-message-simple-api", + "useAnswerWithCitationModels": true, + "personaId": 7, + "searchSettings": { + "runSearch": "always", + "realTime": true + } + } } } diff --git a/src/Core/AdminConsole/Entities/Event.cs b/src/Core/AdminConsole/Entities/Event.cs index 38d8f07b53..e2868c1915 100644 --- a/src/Core/AdminConsole/Entities/Event.cs +++ b/src/Core/AdminConsole/Entities/Event.cs @@ -34,6 +34,7 @@ public class Event : ITableObject, IEvent SecretId = e.SecretId; ProjectId = e.ProjectId; ServiceAccountId = e.ServiceAccountId; + GrantedServiceAccountId = e.GrantedServiceAccountId; } public Guid Id { get; set; } @@ -59,7 +60,7 @@ public class Event : ITableObject, IEvent public Guid? SecretId { get; set; } public Guid? ProjectId { get; set; } public Guid? ServiceAccountId { get; set; } - + public Guid? GrantedServiceAccountId { get; set; } public void SetNewId() { Id = CoreHelpers.GenerateComb(); diff --git a/src/Core/AdminConsole/Entities/Organization.cs b/src/Core/AdminConsole/Entities/Organization.cs index 7933990e74..73aa162f22 100644 --- a/src/Core/AdminConsole/Entities/Organization.cs +++ b/src/Core/AdminConsole/Entities/Organization.cs @@ -129,6 +129,11 @@ public class Organization : ITableObject, IStorableSubscriber, IRevisable /// public bool SyncSeats { get; set; } + /// + /// If set to true, user accounts created within the organization are automatically confirmed without requiring additional verification steps. + /// + public bool UseAutomaticUserConfirmation { get; set; } + public void SetNewId() { if (Id == default(Guid)) @@ -328,5 +333,6 @@ public class Organization : ITableObject, IStorableSubscriber, IRevisable UseRiskInsights = license.UseRiskInsights; UseOrganizationDomains = license.UseOrganizationDomains; UseAdminSponsoredFamilies = license.UseAdminSponsoredFamilies; + UseAutomaticUserConfirmation = license.UseAutomaticUserConfirmation; } } diff --git a/src/Core/AdminConsole/Enums/EventType.cs b/src/Core/AdminConsole/Enums/EventType.cs index 81501fd6ec..8073938fc5 100644 --- a/src/Core/AdminConsole/Enums/EventType.cs +++ b/src/Core/AdminConsole/Enums/EventType.cs @@ -70,8 +70,8 @@ public enum EventType : int Organization_EnabledKeyConnector = 1606, Organization_DisabledKeyConnector = 1607, Organization_SponsorshipsSynced = 1608, - [Obsolete("Use other specific Organization_CollectionManagement events instead")] - Organization_CollectionManagement_Updated = 1609, // TODO: Will be removed in PM-25315 + [Obsolete("Kept for historical data. Use specific Organization_CollectionManagement events instead.")] + Organization_CollectionManagement_Updated = 1609, Organization_CollectionManagement_LimitCollectionCreationEnabled = 1610, Organization_CollectionManagement_LimitCollectionCreationDisabled = 1611, Organization_CollectionManagement_LimitCollectionDeletionEnabled = 1612, @@ -109,4 +109,11 @@ public enum EventType : int Project_Created = 2201, Project_Edited = 2202, Project_Deleted = 2203, + + ServiceAccount_UserAdded = 2300, + ServiceAccount_UserRemoved = 2301, + ServiceAccount_GroupAdded = 2302, + ServiceAccount_GroupRemoved = 2303, + ServiceAccount_Created = 2304, + ServiceAccount_Deleted = 2305, } diff --git a/src/Core/AdminConsole/Enums/IntegrationType.cs b/src/Core/AdminConsole/Enums/IntegrationType.cs index 34edc71fbe..84e4de94e9 100644 --- a/src/Core/AdminConsole/Enums/IntegrationType.cs +++ b/src/Core/AdminConsole/Enums/IntegrationType.cs @@ -7,7 +7,8 @@ public enum IntegrationType : int Slack = 3, Webhook = 4, Hec = 5, - Datadog = 6 + Datadog = 6, + Teams = 7 } public static class IntegrationTypeExtensions @@ -24,6 +25,8 @@ public static class IntegrationTypeExtensions return "hec"; case IntegrationType.Datadog: return "datadog"; + case IntegrationType.Teams: + return "teams"; default: throw new ArgumentOutOfRangeException(nameof(type), $"Unsupported integration type: {type}"); } diff --git a/src/Core/AdminConsole/Enums/OrganizationIntegrationStatus.cs b/src/Core/AdminConsole/Enums/OrganizationIntegrationStatus.cs new file mode 100644 index 0000000000..78a7bc6d63 --- /dev/null +++ b/src/Core/AdminConsole/Enums/OrganizationIntegrationStatus.cs @@ -0,0 +1,10 @@ +namespace Bit.Api.AdminConsole.Models.Response.Organizations; + +public enum OrganizationIntegrationStatus : int +{ + NotApplicable, + Invalid, + Initiated, + InProgress, + Completed +} diff --git a/src/Core/AdminConsole/Enums/PolicyType.cs b/src/Core/AdminConsole/Enums/PolicyType.cs index 452fbcce01..09fa4ec955 100644 --- a/src/Core/AdminConsole/Enums/PolicyType.cs +++ b/src/Core/AdminConsole/Enums/PolicyType.cs @@ -20,6 +20,7 @@ public enum PolicyType : byte RestrictedItemTypesPolicy = 15, UriMatchDefaults = 16, AutotypeDefaultSetting = 17, + AutomaticUserConfirmation = 18, } public static class PolicyTypeExtensions @@ -44,12 +45,13 @@ public static class PolicyTypeExtensions PolicyType.MaximumVaultTimeout => "Vault timeout", PolicyType.DisablePersonalVaultExport => "Remove individual vault export", PolicyType.ActivateAutofill => "Active auto-fill", - PolicyType.AutomaticAppLogIn => "Automatically log in users for allowed applications", + PolicyType.AutomaticAppLogIn => "Automatic login with SSO", PolicyType.FreeFamiliesSponsorshipPolicy => "Remove Free Bitwarden Families sponsorship", PolicyType.RemoveUnlockWithPin => "Remove unlock with PIN", PolicyType.RestrictedItemTypesPolicy => "Restricted item types", PolicyType.UriMatchDefaults => "URI match defaults", PolicyType.AutotypeDefaultSetting => "Autotype default setting", + PolicyType.AutomaticUserConfirmation => "Automatically confirm invited users", }; } } diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/IEventListenerConfiguration.cs b/src/Core/AdminConsole/Models/Data/EventIntegrations/IEventListenerConfiguration.cs index 7b2dd1343e..7df1459941 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/IEventListenerConfiguration.cs +++ b/src/Core/AdminConsole/Models/Data/EventIntegrations/IEventListenerConfiguration.cs @@ -5,4 +5,6 @@ public interface IEventListenerConfiguration public string EventQueueName { get; } public string EventSubscriptionName { get; } public string EventTopicName { get; } + public int EventPrefetchCount { get; } + public int EventMaxConcurrentCalls { get; } } diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/IIntegrationListenerConfiguration.cs b/src/Core/AdminConsole/Models/Data/EventIntegrations/IIntegrationListenerConfiguration.cs index 322a1cd952..30401bb072 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/IIntegrationListenerConfiguration.cs +++ b/src/Core/AdminConsole/Models/Data/EventIntegrations/IIntegrationListenerConfiguration.cs @@ -10,6 +10,8 @@ public interface IIntegrationListenerConfiguration : IEventListenerConfiguration public string IntegrationSubscriptionName { get; } public string IntegrationTopicName { get; } public int MaxRetries { get; } + public int IntegrationPrefetchCount { get; } + public int IntegrationMaxConcurrentCalls { get; } public string RoutingKey { diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/IIntegrationMessage.cs b/src/Core/AdminConsole/Models/Data/EventIntegrations/IIntegrationMessage.cs index 7a0962d89a..5b6bfe2e53 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/IIntegrationMessage.cs +++ b/src/Core/AdminConsole/Models/Data/EventIntegrations/IIntegrationMessage.cs @@ -6,6 +6,7 @@ public interface IIntegrationMessage { IntegrationType IntegrationType { get; } string MessageId { get; set; } + string? OrganizationId { get; set; } int RetryCount { get; } DateTime? DelayUntilDate { get; } void ApplyRetry(DateTime? handlerDelayUntilDate); diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationMessage.cs b/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationMessage.cs index 11a5229f8c..b0fc2161ba 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationMessage.cs +++ b/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationMessage.cs @@ -7,6 +7,7 @@ public class IntegrationMessage : IIntegrationMessage { public IntegrationType IntegrationType { get; set; } public required string MessageId { get; set; } + public string? OrganizationId { get; set; } public required string RenderedTemplate { get; set; } public int RetryCount { get; set; } = 0; public DateTime? DelayUntilDate { get; set; } diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationOAuthState.cs b/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationOAuthState.cs new file mode 100644 index 0000000000..3b29bbebb4 --- /dev/null +++ b/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationOAuthState.cs @@ -0,0 +1,71 @@ +using System.Security.Cryptography; +using System.Text; +using Bit.Core.AdminConsole.Entities; + +namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; + +public class IntegrationOAuthState +{ + private const int _orgHashLength = 12; + private static readonly TimeSpan _maxAge = TimeSpan.FromMinutes(20); + + public Guid IntegrationId { get; } + private DateTimeOffset Issued { get; } + private string OrganizationIdHash { get; } + + private IntegrationOAuthState(Guid integrationId, string organizationIdHash, DateTimeOffset issued) + { + IntegrationId = integrationId; + OrganizationIdHash = organizationIdHash; + Issued = issued; + } + + public static IntegrationOAuthState FromIntegration(OrganizationIntegration integration, TimeProvider timeProvider) + { + var integrationId = integration.Id; + var issuedUtc = timeProvider.GetUtcNow(); + var organizationIdHash = ComputeOrgHash(integration.OrganizationId, issuedUtc.ToUnixTimeSeconds()); + + return new IntegrationOAuthState(integrationId, organizationIdHash, issuedUtc); + } + + public static IntegrationOAuthState? FromString(string state, TimeProvider timeProvider) + { + if (string.IsNullOrWhiteSpace(state)) return null; + + var parts = state.Split('.'); + if (parts.Length != 3) return null; + + // Verify timestamp + if (!long.TryParse(parts[2], out var unixSeconds)) return null; + + var issuedUtc = DateTimeOffset.FromUnixTimeSeconds(unixSeconds); + var now = timeProvider.GetUtcNow(); + var age = now - issuedUtc; + + if (age > _maxAge) return null; + + // Parse integration id and store org + if (!Guid.TryParse(parts[0], out var integrationId)) return null; + var organizationIdHash = parts[1]; + + return new IntegrationOAuthState(integrationId, organizationIdHash, issuedUtc); + } + + public bool ValidateOrg(Guid orgId) + { + var expected = ComputeOrgHash(orgId, Issued.ToUnixTimeSeconds()); + return expected == OrganizationIdHash; + } + + public override string ToString() + { + return $"{IntegrationId}.{OrganizationIdHash}.{Issued.ToUnixTimeSeconds()}"; + } + + private static string ComputeOrgHash(Guid orgId, long timestamp) + { + var bytes = SHA256.HashData(Encoding.UTF8.GetBytes($"{orgId:N}:{timestamp}")); + return Convert.ToHexString(bytes)[.._orgHashLength]; + } +} diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationTemplateContext.cs b/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationTemplateContext.cs index 79a30c3a02..fe33c45156 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationTemplateContext.cs +++ b/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationTemplateContext.cs @@ -23,7 +23,17 @@ public class IntegrationTemplateContext(EventMessage eventMessage) public Guid? CollectionId => Event.CollectionId; public Guid? GroupId => Event.GroupId; public Guid? PolicyId => Event.PolicyId; + public Guid? IdempotencyId => Event.IdempotencyId; + public Guid? ProviderId => Event.ProviderId; + public Guid? ProviderUserId => Event.ProviderUserId; + public Guid? ProviderOrganizationId => Event.ProviderOrganizationId; + public Guid? InstallationId => Event.InstallationId; + public Guid? SecretId => Event.SecretId; + public Guid? ProjectId => Event.ProjectId; + public Guid? ServiceAccountId => Event.ServiceAccountId; + public Guid? GrantedServiceAccountId => Event.GrantedServiceAccountId; + public string DateIso8601 => Date.ToString("o"); public string EventMessage => JsonSerializer.Serialize(Event); public User? User { get; set; } diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/ListenerConfiguration.cs b/src/Core/AdminConsole/Models/Data/EventIntegrations/ListenerConfiguration.cs index 662bb8241e..40eb2b3e77 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/ListenerConfiguration.cs +++ b/src/Core/AdminConsole/Models/Data/EventIntegrations/ListenerConfiguration.cs @@ -25,4 +25,24 @@ public abstract class ListenerConfiguration { get => _globalSettings.EventLogging.AzureServiceBus.IntegrationTopicName; } + + public int EventPrefetchCount + { + get => _globalSettings.EventLogging.AzureServiceBus.DefaultPrefetchCount; + } + + public int EventMaxConcurrentCalls + { + get => _globalSettings.EventLogging.AzureServiceBus.DefaultMaxConcurrentCalls; + } + + public int IntegrationPrefetchCount + { + get => _globalSettings.EventLogging.AzureServiceBus.DefaultPrefetchCount; + } + + public int IntegrationMaxConcurrentCalls + { + get => _globalSettings.EventLogging.AzureServiceBus.DefaultMaxConcurrentCalls; + } } diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/TeamsIntegration.cs b/src/Core/AdminConsole/Models/Data/EventIntegrations/TeamsIntegration.cs new file mode 100644 index 0000000000..8390022839 --- /dev/null +++ b/src/Core/AdminConsole/Models/Data/EventIntegrations/TeamsIntegration.cs @@ -0,0 +1,12 @@ +using Bit.Core.Models.Teams; + +namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; + +public record TeamsIntegration( + string TenantId, + IReadOnlyList Teams, + string? ChannelId = null, + Uri? ServiceUrl = null) +{ + public bool IsCompleted => !string.IsNullOrEmpty(ChannelId) && ServiceUrl is not null; +} diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/TeamsIntegrationConfigurationDetails.cs b/src/Core/AdminConsole/Models/Data/EventIntegrations/TeamsIntegrationConfigurationDetails.cs new file mode 100644 index 0000000000..66fe558dff --- /dev/null +++ b/src/Core/AdminConsole/Models/Data/EventIntegrations/TeamsIntegrationConfigurationDetails.cs @@ -0,0 +1,3 @@ +namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; + +public record TeamsIntegrationConfigurationDetails(string ChannelId, Uri ServiceUrl); diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/TeamsListenerConfiguration.cs b/src/Core/AdminConsole/Models/Data/EventIntegrations/TeamsListenerConfiguration.cs new file mode 100644 index 0000000000..24cf674648 --- /dev/null +++ b/src/Core/AdminConsole/Models/Data/EventIntegrations/TeamsListenerConfiguration.cs @@ -0,0 +1,38 @@ +using Bit.Core.Enums; +using Bit.Core.Settings; + +namespace Bit.Core.AdminConsole.Models.Data.EventIntegrations; + +public class TeamsListenerConfiguration(GlobalSettings globalSettings) : + ListenerConfiguration(globalSettings), IIntegrationListenerConfiguration +{ + public IntegrationType IntegrationType + { + get => IntegrationType.Teams; + } + + public string EventQueueName + { + get => _globalSettings.EventLogging.RabbitMq.TeamsEventsQueueName; + } + + public string IntegrationQueueName + { + get => _globalSettings.EventLogging.RabbitMq.TeamsIntegrationQueueName; + } + + public string IntegrationRetryQueueName + { + get => _globalSettings.EventLogging.RabbitMq.TeamsIntegrationRetryQueueName; + } + + public string EventSubscriptionName + { + get => _globalSettings.EventLogging.AzureServiceBus.TeamsEventSubscriptionName; + } + + public string IntegrationSubscriptionName + { + get => _globalSettings.EventLogging.AzureServiceBus.TeamsIntegrationSubscriptionName; + } +} diff --git a/src/Core/AdminConsole/Models/Data/EventMessage.cs b/src/Core/AdminConsole/Models/Data/EventMessage.cs index b708c5bd56..a29d70c203 100644 --- a/src/Core/AdminConsole/Models/Data/EventMessage.cs +++ b/src/Core/AdminConsole/Models/Data/EventMessage.cs @@ -39,4 +39,5 @@ public class EventMessage : IEvent public Guid? SecretId { get; set; } public Guid? ProjectId { get; set; } public Guid? ServiceAccountId { get; set; } + public Guid? GrantedServiceAccountId { get; set; } } diff --git a/src/Core/AdminConsole/Models/Data/EventTableEntity.cs b/src/Core/AdminConsole/Models/Data/EventTableEntity.cs index 4ba50aee0d..1c3023f2cf 100644 --- a/src/Core/AdminConsole/Models/Data/EventTableEntity.cs +++ b/src/Core/AdminConsole/Models/Data/EventTableEntity.cs @@ -37,6 +37,7 @@ public class AzureEvent : ITableEntity public Guid? SecretId { get; set; } public Guid? ProjectId { get; set; } public Guid? ServiceAccountId { get; set; } + public Guid? GrantedServiceAccountId { get; set; } public EventTableEntity ToEventTableEntity() { @@ -68,6 +69,7 @@ public class AzureEvent : ITableEntity SecretId = SecretId, ServiceAccountId = ServiceAccountId, ProjectId = ProjectId, + GrantedServiceAccountId = GrantedServiceAccountId }; } } @@ -99,6 +101,7 @@ public class EventTableEntity : IEvent SecretId = e.SecretId; ProjectId = e.ProjectId; ServiceAccountId = e.ServiceAccountId; + GrantedServiceAccountId = e.GrantedServiceAccountId; } public string PartitionKey { get; set; } @@ -127,6 +130,7 @@ public class EventTableEntity : IEvent public Guid? SecretId { get; set; } public Guid? ProjectId { get; set; } public Guid? ServiceAccountId { get; set; } + public Guid? GrantedServiceAccountId { get; set; } public AzureEvent ToAzureEvent() { @@ -157,7 +161,8 @@ public class EventTableEntity : IEvent DomainName = DomainName, SecretId = SecretId, ProjectId = ProjectId, - ServiceAccountId = ServiceAccountId + ServiceAccountId = ServiceAccountId, + GrantedServiceAccountId = GrantedServiceAccountId }; } @@ -232,6 +237,15 @@ public class EventTableEntity : IEvent }); } + if (e.GrantedServiceAccountId.HasValue) + { + entities.Add(new EventTableEntity(e) + { + PartitionKey = pKey, + RowKey = $"GrantedServiceAccountId={e.GrantedServiceAccountId}__Date={dateKey}__Uniquifier={uniquifier}" + }); + } + return entities; } diff --git a/src/Core/AdminConsole/Models/Data/IEvent.cs b/src/Core/AdminConsole/Models/Data/IEvent.cs index 750fb2e2eb..3188c905e4 100644 --- a/src/Core/AdminConsole/Models/Data/IEvent.cs +++ b/src/Core/AdminConsole/Models/Data/IEvent.cs @@ -28,4 +28,5 @@ public interface IEvent Guid? SecretId { get; set; } Guid? ProjectId { get; set; } Guid? ServiceAccountId { get; set; } + Guid? GrantedServiceAccountId { get; set; } } diff --git a/src/Core/AdminConsole/Models/Data/IProfileOrganizationDetails.cs b/src/Core/AdminConsole/Models/Data/IProfileOrganizationDetails.cs new file mode 100644 index 0000000000..820b65dbfd --- /dev/null +++ b/src/Core/AdminConsole/Models/Data/IProfileOrganizationDetails.cs @@ -0,0 +1,56 @@ +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.Billing.Enums; + +namespace Bit.Core.AdminConsole.Models.Data; + +/// +/// Interface defining common organization details properties shared between +/// regular organization users and provider organization users for profile endpoints. +/// +public interface IProfileOrganizationDetails +{ + Guid? UserId { get; set; } + Guid OrganizationId { get; set; } + string Name { get; set; } + bool Enabled { get; set; } + PlanType PlanType { get; set; } + bool UsePolicies { get; set; } + bool UseSso { get; set; } + bool UseKeyConnector { get; set; } + bool UseScim { get; set; } + bool UseGroups { get; set; } + bool UseDirectory { get; set; } + bool UseEvents { get; set; } + bool UseTotp { get; set; } + bool Use2fa { get; set; } + bool UseApi { get; set; } + bool UseResetPassword { get; set; } + bool SelfHost { get; set; } + bool UsersGetPremium { get; set; } + bool UseCustomPermissions { get; set; } + bool UseSecretsManager { get; set; } + int? Seats { get; set; } + short? MaxCollections { get; set; } + short? MaxStorageGb { get; set; } + string? Identifier { get; set; } + string? Key { get; set; } + string? ResetPasswordKey { get; set; } + string? PublicKey { get; set; } + string? PrivateKey { get; set; } + string? SsoExternalId { get; set; } + string? Permissions { get; set; } + Guid? ProviderId { get; set; } + string? ProviderName { get; set; } + ProviderType? ProviderType { get; set; } + bool? SsoEnabled { get; set; } + string? SsoConfig { get; set; } + bool UsePasswordManager { get; set; } + bool LimitCollectionCreation { get; set; } + bool LimitCollectionDeletion { get; set; } + bool AllowAdminAccessToAllCollectionItems { get; set; } + bool UseRiskInsights { get; set; } + bool LimitItemDeletion { get; set; } + bool UseAdminSponsoredFamilies { get; set; } + bool UseOrganizationDomains { get; set; } + bool UseAutomaticUserConfirmation { get; set; } +} diff --git a/src/Core/AdminConsole/Models/Data/Organizations/OrganizationAbility.cs b/src/Core/AdminConsole/Models/Data/Organizations/OrganizationAbility.cs index ae91f204e3..3c02a4f50b 100644 --- a/src/Core/AdminConsole/Models/Data/Organizations/OrganizationAbility.cs +++ b/src/Core/AdminConsole/Models/Data/Organizations/OrganizationAbility.cs @@ -28,6 +28,7 @@ public class OrganizationAbility UseRiskInsights = organization.UseRiskInsights; UseOrganizationDomains = organization.UseOrganizationDomains; UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies; + UseAutomaticUserConfirmation = organization.UseAutomaticUserConfirmation; } public Guid Id { get; set; } @@ -49,4 +50,5 @@ public class OrganizationAbility public bool UseRiskInsights { get; set; } public bool UseOrganizationDomains { get; set; } public bool UseAdminSponsoredFamilies { get; set; } + public bool UseAutomaticUserConfirmation { get; set; } } diff --git a/src/Core/AdminConsole/Models/Data/Organizations/OrganizationUsers/OrganizationUserOrganizationDetails.cs b/src/Core/AdminConsole/Models/Data/Organizations/OrganizationUsers/OrganizationUserOrganizationDetails.cs index b7e573c4e6..8d30bfc250 100644 --- a/src/Core/AdminConsole/Models/Data/Organizations/OrganizationUsers/OrganizationUserOrganizationDetails.cs +++ b/src/Core/AdminConsole/Models/Data/Organizations/OrganizationUsers/OrganizationUserOrganizationDetails.cs @@ -1,20 +1,18 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Text.Json.Serialization; +using System.Text.Json.Serialization; using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Models.Data; using Bit.Core.Billing.Enums; using Bit.Core.Utilities; namespace Bit.Core.Models.Data.Organizations.OrganizationUsers; -public class OrganizationUserOrganizationDetails +public class OrganizationUserOrganizationDetails : IProfileOrganizationDetails { public Guid OrganizationId { get; set; } public Guid? UserId { get; set; } public Guid OrganizationUserId { get; set; } [JsonConverter(typeof(HtmlEncodingStringConverter))] - public string Name { get; set; } + public string Name { get; set; } = null!; public bool UsePolicies { get; set; } public bool UseSso { get; set; } public bool UseKeyConnector { get; set; } @@ -33,24 +31,24 @@ public class OrganizationUserOrganizationDetails public int? Seats { get; set; } public short? MaxCollections { get; set; } public short? MaxStorageGb { get; set; } - public string Key { get; set; } + public string? Key { get; set; } public Enums.OrganizationUserStatusType Status { get; set; } public Enums.OrganizationUserType Type { get; set; } public bool Enabled { get; set; } public PlanType PlanType { get; set; } - public string SsoExternalId { get; set; } - public string Identifier { get; set; } - public string Permissions { get; set; } - public string ResetPasswordKey { get; set; } - public string PublicKey { get; set; } - public string PrivateKey { get; set; } + public string? SsoExternalId { get; set; } + public string? Identifier { get; set; } + public string? Permissions { get; set; } + public string? ResetPasswordKey { get; set; } + public string? PublicKey { get; set; } + public string? PrivateKey { get; set; } public Guid? ProviderId { get; set; } [JsonConverter(typeof(HtmlEncodingStringConverter))] - public string ProviderName { get; set; } + public string? ProviderName { get; set; } public ProviderType? ProviderType { get; set; } - public string FamilySponsorshipFriendlyName { get; set; } + public string? FamilySponsorshipFriendlyName { get; set; } public bool? SsoEnabled { get; set; } - public string SsoConfig { get; set; } + public string? SsoConfig { get; set; } public DateTime? FamilySponsorshipLastSyncDate { get; set; } public DateTime? FamilySponsorshipValidUntil { get; set; } public bool? FamilySponsorshipToDelete { get; set; } @@ -66,4 +64,5 @@ public class OrganizationUserOrganizationDetails public bool UseOrganizationDomains { get; set; } public bool UseAdminSponsoredFamilies { get; set; } public bool? IsAdminInitiated { get; set; } + public bool UseAutomaticUserConfirmation { get; set; } } diff --git a/src/Core/AdminConsole/Models/Data/Provider/ProviderUserOrganizationDetails.cs b/src/Core/AdminConsole/Models/Data/Provider/ProviderUserOrganizationDetails.cs index 04281d098e..0d48f5cfa9 100644 --- a/src/Core/AdminConsole/Models/Data/Provider/ProviderUserOrganizationDetails.cs +++ b/src/Core/AdminConsole/Models/Data/Provider/ProviderUserOrganizationDetails.cs @@ -1,19 +1,16 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Text.Json.Serialization; +using System.Text.Json.Serialization; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.Billing.Enums; using Bit.Core.Utilities; namespace Bit.Core.AdminConsole.Models.Data.Provider; -public class ProviderUserOrganizationDetails +public class ProviderUserOrganizationDetails : IProfileOrganizationDetails { public Guid OrganizationId { get; set; } public Guid? UserId { get; set; } [JsonConverter(typeof(HtmlEncodingStringConverter))] - public string Name { get; set; } + public string Name { get; set; } = null!; public bool UsePolicies { get; set; } public bool UseSso { get; set; } public bool UseKeyConnector { get; set; } @@ -28,20 +25,22 @@ public class ProviderUserOrganizationDetails public bool SelfHost { get; set; } public bool UsersGetPremium { get; set; } public bool UseCustomPermissions { get; set; } + public bool UseSecretsManager { get; set; } + public bool UsePasswordManager { get; set; } public int? Seats { get; set; } public short? MaxCollections { get; set; } public short? MaxStorageGb { get; set; } - public string Key { get; set; } + public string? Key { get; set; } public ProviderUserStatusType Status { get; set; } public ProviderUserType Type { get; set; } public bool Enabled { get; set; } - public string Identifier { get; set; } - public string PublicKey { get; set; } - public string PrivateKey { get; set; } + public string? Identifier { get; set; } + public string? PublicKey { get; set; } + public string? PrivateKey { get; set; } public Guid? ProviderId { get; set; } public Guid? ProviderUserId { get; set; } [JsonConverter(typeof(HtmlEncodingStringConverter))] - public string ProviderName { get; set; } + public string? ProviderName { get; set; } public PlanType PlanType { get; set; } public bool LimitCollectionCreation { get; set; } public bool LimitCollectionDeletion { get; set; } @@ -50,5 +49,11 @@ public class ProviderUserOrganizationDetails public bool UseRiskInsights { get; set; } public bool UseOrganizationDomains { get; set; } public bool UseAdminSponsoredFamilies { get; set; } - public ProviderType ProviderType { get; set; } + public ProviderType? ProviderType { get; set; } + public bool UseAutomaticUserConfirmation { get; set; } + public bool? SsoEnabled { get; set; } + public string? SsoConfig { get; set; } + public string? SsoExternalId { get; set; } + public string? Permissions { get; set; } + public string? ResetPasswordKey { get; set; } } diff --git a/src/Core/AdminConsole/Models/Slack/SlackApiResponse.cs b/src/Core/AdminConsole/Models/Slack/SlackApiResponse.cs index ede2123f7e..3c811e2b28 100644 --- a/src/Core/AdminConsole/Models/Slack/SlackApiResponse.cs +++ b/src/Core/AdminConsole/Models/Slack/SlackApiResponse.cs @@ -1,6 +1,4 @@ -#nullable enable - -using System.Text.Json.Serialization; +using System.Text.Json.Serialization; namespace Bit.Core.Models.Slack; @@ -35,6 +33,12 @@ public class SlackOAuthResponse : SlackApiResponse public SlackTeam Team { get; set; } = new(); } +public class SlackSendMessageResponse : SlackApiResponse +{ + [JsonPropertyName("channel")] + public string Channel { get; set; } = string.Empty; +} + public class SlackTeam { public string Id { get; set; } = string.Empty; diff --git a/src/Core/AdminConsole/Models/Teams/TeamsApiResponse.cs b/src/Core/AdminConsole/Models/Teams/TeamsApiResponse.cs new file mode 100644 index 0000000000..131e45264f --- /dev/null +++ b/src/Core/AdminConsole/Models/Teams/TeamsApiResponse.cs @@ -0,0 +1,41 @@ +using System.Text.Json.Serialization; + +namespace Bit.Core.Models.Teams; + +/// Represents the response returned by the Microsoft OAuth 2.0 token endpoint. +/// See Microsoft identity platform and OAuth 2.0 +/// authorization code flow. +public class TeamsOAuthResponse +{ + /// The access token issued by Microsoft, used to call the Microsoft Graph API. + [JsonPropertyName("access_token")] + public string AccessToken { get; set; } = string.Empty; +} + +/// Represents the response from the /me/joinedTeams Microsoft Graph API call. +/// See List joined teams - +/// Microsoft Graph v1.0. +public class JoinedTeamsResponse +{ + /// The collection of teams that the user has joined. + [JsonPropertyName("value")] + public List Value { get; set; } = []; +} + +/// Represents a Microsoft Teams team returned by the Graph API. +/// See Team resource type - +/// Microsoft Graph v1.0. +public class TeamInfo +{ + /// The unique identifier of the team. + [JsonPropertyName("id")] + public string Id { get; set; } = string.Empty; + + /// The name of the team. + [JsonPropertyName("displayName")] + public string DisplayName { get; set; } = string.Empty; + + /// The ID of the Microsoft Entra tenant for this team. + [JsonPropertyName("tenantId")] + public string TenantId { get; set; } = string.Empty; +} diff --git a/src/Core/AdminConsole/Models/Teams/TeamsBotCredentialProvider.cs b/src/Core/AdminConsole/Models/Teams/TeamsBotCredentialProvider.cs new file mode 100644 index 0000000000..eeb17131a3 --- /dev/null +++ b/src/Core/AdminConsole/Models/Teams/TeamsBotCredentialProvider.cs @@ -0,0 +1,28 @@ +using Microsoft.Bot.Connector.Authentication; + +namespace Bit.Core.AdminConsole.Models.Teams; + +public class TeamsBotCredentialProvider(string clientId, string clientSecret) : ICredentialProvider +{ + private const string _microsoftBotFrameworkIssuer = AuthenticationConstants.ToBotFromChannelTokenIssuer; + + public Task IsValidAppIdAsync(string appId) + { + return Task.FromResult(appId == clientId); + } + + public Task GetAppPasswordAsync(string appId) + { + return Task.FromResult(appId == clientId ? clientSecret : null); + } + + public Task IsAuthenticationDisabledAsync() + { + return Task.FromResult(false); + } + + public Task ValidateIssuerAsync(string issuer) + { + return Task.FromResult(issuer == _microsoftBotFrameworkIssuer); + } +} diff --git a/src/Core/AdminConsole/OrganizationAuth/UpdateOrganizationAuthRequestCommand.cs b/src/Core/AdminConsole/OrganizationAuth/UpdateOrganizationAuthRequestCommand.cs index af966a6e16..9c699a61cb 100644 --- a/src/Core/AdminConsole/OrganizationAuth/UpdateOrganizationAuthRequestCommand.cs +++ b/src/Core/AdminConsole/OrganizationAuth/UpdateOrganizationAuthRequestCommand.cs @@ -89,7 +89,7 @@ public class UpdateOrganizationAuthRequestCommand : IUpdateOrganizationAuthReque AuthRequestExpiresAfter = _globalSettings.PasswordlessAuth.AdminRequestExpiration } ); - processor.Process((Exception e) => _logger.LogError(e.Message)); + processor.Process((Exception e) => _logger.LogError("Error processing organization auth request: {Message}", e.Message)); await processor.Save((IEnumerable authRequests) => _authRequestRepository.UpdateManyAsync(authRequests)); await processor.SendPushNotifications((ar) => _pushNotificationService.PushAuthRequestResponseAsync(ar)); await processor.SendApprovalEmailsForProcessedRequests(SendApprovalEmail); @@ -114,7 +114,7 @@ public class UpdateOrganizationAuthRequestCommand : IUpdateOrganizationAuthReque // This should be impossible if (user == null) { - _logger.LogError($"User {authRequest.UserId} not found. Trusted device admin approval email not sent."); + _logger.LogError("User {UserId} not found. Trusted device admin approval email not sent.", authRequest.UserId); return; } diff --git a/src/Core/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommand.cs new file mode 100644 index 0000000000..5783301a0b --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommand.cs @@ -0,0 +1,79 @@ +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Platform.Push; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Microsoft.AspNetCore.Identity; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery; + +public class AdminRecoverAccountCommand(IOrganizationRepository organizationRepository, + IPolicyRepository policyRepository, + IUserRepository userRepository, + IMailService mailService, + IEventService eventService, + IPushNotificationService pushNotificationService, + IUserService userService, + TimeProvider timeProvider) : IAdminRecoverAccountCommand +{ + public async Task RecoverAccountAsync(Guid orgId, + OrganizationUser organizationUser, string newMasterPassword, string key) + { + // Org must be able to use reset password + var org = await organizationRepository.GetByIdAsync(orgId); + if (org == null || !org.UseResetPassword) + { + throw new BadRequestException("Organization does not allow password reset."); + } + + // Enterprise policy must be enabled + var resetPasswordPolicy = + await policyRepository.GetByOrganizationIdTypeAsync(orgId, PolicyType.ResetPassword); + if (resetPasswordPolicy == null || !resetPasswordPolicy.Enabled) + { + throw new BadRequestException("Organization does not have the password reset policy enabled."); + } + + // Org User must be confirmed and have a ResetPasswordKey + if (organizationUser == null || + organizationUser.Status != OrganizationUserStatusType.Confirmed || + organizationUser.OrganizationId != orgId || + string.IsNullOrEmpty(organizationUser.ResetPasswordKey) || + !organizationUser.UserId.HasValue) + { + throw new BadRequestException("Organization User not valid"); + } + + var user = await userService.GetUserByIdAsync(organizationUser.UserId.Value); + if (user == null) + { + throw new NotFoundException(); + } + + if (user.UsesKeyConnector) + { + throw new BadRequestException("Cannot reset password of a user with Key Connector."); + } + + var result = await userService.UpdatePasswordHash(user, newMasterPassword); + if (!result.Succeeded) + { + return result; + } + + user.RevisionDate = user.AccountRevisionDate = timeProvider.GetUtcNow().UtcDateTime; + user.LastPasswordChangeDate = user.RevisionDate; + user.ForcePasswordReset = true; + user.Key = key; + + await userRepository.ReplaceAsync(user); + await mailService.SendAdminResetPasswordEmailAsync(user.Email, user.Name, org.DisplayName()); + await eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_AdminResetPassword); + await pushNotificationService.PushLogOutAsync(user.Id); + + return IdentityResult.Success; + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/AccountRecovery/IAdminRecoverAccountCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/AccountRecovery/IAdminRecoverAccountCommand.cs new file mode 100644 index 0000000000..75babc643e --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/AccountRecovery/IAdminRecoverAccountCommand.cs @@ -0,0 +1,24 @@ +using Bit.Core.Entities; +using Bit.Core.Exceptions; +using Microsoft.AspNetCore.Identity; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery; + +/// +/// A command used to recover an organization user's account by an organization admin. +/// +public interface IAdminRecoverAccountCommand +{ + /// + /// Recovers an organization user's account by resetting their master password. + /// + /// The organization the user belongs to. + /// The organization user being recovered. + /// The user's new master password hash. + /// The user's new master-password-sealed user key. + /// An IdentityResult indicating success or failure. + /// When organization settings, policy, or user state is invalid. + /// When the user does not exist. + Task RecoverAccountAsync(Guid orgId, OrganizationUser organizationUser, + string newMasterPassword, string key); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommand.cs index c03341bbc0..595e487580 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommand.cs @@ -6,6 +6,7 @@ using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -24,7 +25,9 @@ public class VerifyOrganizationDomainCommand( IEventService eventService, IGlobalSettings globalSettings, ICurrentContext currentContext, + IFeatureService featureService, ISavePolicyCommand savePolicyCommand, + IVNextSavePolicyCommand vNextSavePolicyCommand, IMailService mailService, IOrganizationUserRepository organizationUserRepository, IOrganizationRepository organizationRepository, @@ -131,15 +134,26 @@ public class VerifyOrganizationDomainCommand( await SendVerifiedDomainUserEmailAsync(domain); } - private async Task EnableSingleOrganizationPolicyAsync(Guid organizationId, IActingUser actingUser) => - await savePolicyCommand.SaveAsync( - new PolicyUpdate - { - OrganizationId = organizationId, - Type = PolicyType.SingleOrg, - Enabled = true, - PerformedBy = actingUser - }); + private async Task EnableSingleOrganizationPolicyAsync(Guid organizationId, IActingUser actingUser) + { + var policyUpdate = new PolicyUpdate + { + OrganizationId = organizationId, + Type = PolicyType.SingleOrg, + Enabled = true, + PerformedBy = actingUser + }; + + if (featureService.IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor)) + { + var savePolicyModel = new SavePolicyModel(policyUpdate, actingUser); + await vNextSavePolicyCommand.SaveAsync(savePolicyModel); + } + else + { + await savePolicyCommand.SaveAsync(policyUpdate); + } + } private async Task SendVerifiedDomainUserEmailAsync(OrganizationDomain domain) { diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyValidator.cs index 6aef9f248b..d3df63b6ac 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyValidator.cs @@ -9,6 +9,10 @@ namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies; /// /// Defines behavior and functionality for a given PolicyType. /// +/// +/// All methods defined in this interface are for the PolicyService#SavePolicy method. This needs to be supported until +/// we successfully refactor policy validators over to policy validation handlers +/// public interface IPolicyValidator { /// diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs index e846e02e46..c1450c6ab5 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyRequirementQuery.cs @@ -1,6 +1,4 @@ -#nullable enable - -using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Repositories; @@ -20,7 +18,7 @@ public class PolicyRequirementQuery( throw new NotImplementedException("No Requirement Factory found for " + typeof(T)); } - var policyDetails = await GetPolicyDetails(userId); + var policyDetails = await GetPolicyDetails(userId, factory.PolicyType); var filteredPolicies = policyDetails .Where(p => p.PolicyType == factory.PolicyType) .Where(factory.Enforce); @@ -48,8 +46,8 @@ public class PolicyRequirementQuery( return eligibleOrganizationUserIds; } - private Task> GetPolicyDetails(Guid userId) - => policyRepository.GetPolicyDetailsByUserId(userId); + private async Task> GetPolicyDetails(Guid userId, PolicyType policyType) + => await policyRepository.GetPolicyDetailsByUserIdsAndPolicyType([userId], policyType); private async Task> GetOrganizationPolicyDetails(Guid organizationId, PolicyType policyType) => await policyRepository.GetPolicyDetailsByOrganizationIdAsync(organizationId, policyType); diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/VNextSavePolicyCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/VNextSavePolicyCommand.cs new file mode 100644 index 0000000000..5d40cb211f --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/VNextSavePolicyCommand.cs @@ -0,0 +1,195 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Services; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations; + +public class VNextSavePolicyCommand( + IApplicationCacheService applicationCacheService, + IEventService eventService, + IPolicyRepository policyRepository, + IEnumerable policyUpdateEventHandlers, + TimeProvider timeProvider, + IPolicyEventHandlerFactory policyEventHandlerFactory) + : IVNextSavePolicyCommand +{ + + public async Task SaveAsync(SavePolicyModel policyRequest) + { + var policyUpdateRequest = policyRequest.PolicyUpdate; + var organizationId = policyUpdateRequest.OrganizationId; + + await EnsureOrganizationCanUsePolicyAsync(organizationId); + + var savedPoliciesDict = await GetCurrentPolicyStateAsync(organizationId); + + var currentPolicy = savedPoliciesDict.GetValueOrDefault(policyUpdateRequest.Type); + + ValidatePolicyDependencies(policyUpdateRequest, currentPolicy, savedPoliciesDict); + + await ValidateTargetedPolicyAsync(policyRequest, currentPolicy); + + await ExecutePreUpsertSideEffectAsync(policyRequest, currentPolicy); + + var upsertedPolicy = await UpsertPolicyAsync(policyUpdateRequest); + + await eventService.LogPolicyEventAsync(upsertedPolicy, EventType.Policy_Updated); + + await ExecutePostUpsertSideEffectAsync(policyRequest, upsertedPolicy, currentPolicy); + + return upsertedPolicy; + } + + private async Task EnsureOrganizationCanUsePolicyAsync(Guid organizationId) + { + var org = await applicationCacheService.GetOrganizationAbilityAsync(organizationId); + if (org == null) + { + throw new BadRequestException("Organization not found"); + } + + if (!org.UsePolicies) + { + throw new BadRequestException("This organization cannot use policies."); + } + } + + private async Task UpsertPolicyAsync(PolicyUpdate policyUpdateRequest) + { + var policy = await policyRepository.GetByOrganizationIdTypeAsync(policyUpdateRequest.OrganizationId, policyUpdateRequest.Type) + ?? new Policy + { + OrganizationId = policyUpdateRequest.OrganizationId, + Type = policyUpdateRequest.Type, + CreationDate = timeProvider.GetUtcNow().UtcDateTime + }; + + policy.Enabled = policyUpdateRequest.Enabled; + policy.Data = policyUpdateRequest.Data; + policy.RevisionDate = timeProvider.GetUtcNow().UtcDateTime; + + await policyRepository.UpsertAsync(policy); + + return policy; + } + + private async Task ValidateTargetedPolicyAsync(SavePolicyModel policyRequest, + Policy? currentPolicy) + { + await ExecutePolicyEventAsync( + policyRequest.PolicyUpdate.Type, + async validator => + { + var validationError = await validator.ValidateAsync(policyRequest, currentPolicy); + if (!string.IsNullOrEmpty(validationError)) + { + throw new BadRequestException(validationError); + } + }); + } + + private void ValidatePolicyDependencies( + PolicyUpdate policyUpdateRequest, + Policy? currentPolicy, + Dictionary savedPoliciesDict) + { + var isCurrentlyEnabled = currentPolicy?.Enabled == true; + var isBeingEnabled = policyUpdateRequest.Enabled && !isCurrentlyEnabled; + var isBeingDisabled = !policyUpdateRequest.Enabled && isCurrentlyEnabled; + + if (isBeingEnabled) + { + ValidateEnablingRequirements(policyUpdateRequest.Type, savedPoliciesDict); + } + else if (isBeingDisabled) + { + ValidateDisablingRequirements(policyUpdateRequest.Type, savedPoliciesDict); + } + } + + private void ValidateDisablingRequirements( + PolicyType policyType, + Dictionary savedPoliciesDict) + { + var dependentPolicyTypes = policyUpdateEventHandlers + .OfType() + .Where(otherValidator => otherValidator.RequiredPolicies.Contains(policyType)) + .Select(otherValidator => otherValidator.Type) + .Where(otherPolicyType => savedPoliciesDict.TryGetValue(otherPolicyType, out var savedPolicy) && + savedPolicy.Enabled) + .ToList(); + + switch (dependentPolicyTypes) + { + case { Count: 1 }: + throw new BadRequestException($"Turn off the {dependentPolicyTypes.First().GetName()} policy because it requires the {policyType.GetName()} policy."); + case { Count: > 1 }: + throw new BadRequestException($"Turn off all of the policies that require the {policyType.GetName()} policy."); + } + } + + private void ValidateEnablingRequirements( + PolicyType policyType, + Dictionary savedPoliciesDict) + { + var result = policyEventHandlerFactory.GetHandler(policyType); + + result.Switch( + validator => + { + var missingRequiredPolicyTypes = validator.RequiredPolicies + .Where(requiredPolicyType => savedPoliciesDict.GetValueOrDefault(requiredPolicyType) is not { Enabled: true }) + .ToList(); + + if (missingRequiredPolicyTypes.Count != 0) + { + throw new BadRequestException($"Turn on the {missingRequiredPolicyTypes.First().GetName()} policy because it is required for the {policyType.GetName()} policy."); + } + }, + _ => { /* Policy has no required dependencies */ }); + } + + private async Task ExecutePreUpsertSideEffectAsync( + SavePolicyModel policyRequest, + Policy? currentPolicy) + { + await ExecutePolicyEventAsync( + policyRequest.PolicyUpdate.Type, + handler => handler.ExecutePreUpsertSideEffectAsync(policyRequest, currentPolicy)); + } + private async Task ExecutePostUpsertSideEffectAsync( + SavePolicyModel policyRequest, + Policy postUpsertedPolicyState, + Policy? previousPolicyState) + { + await ExecutePolicyEventAsync( + policyRequest.PolicyUpdate.Type, + handler => handler.ExecutePostUpsertSideEffectAsync( + policyRequest, + postUpsertedPolicyState, + previousPolicyState)); + } + + private async Task ExecutePolicyEventAsync(PolicyType type, Func func) where T : IPolicyUpdateEvent + { + var handler = policyEventHandlerFactory.GetHandler(type); + + await handler.Match( + async h => await func(h), + _ => Task.CompletedTask + ); + } + + private async Task> GetCurrentPolicyStateAsync(Guid organizationId) + { + var savedPolicies = await policyRepository.GetManyByOrganizationIdAsync(organizationId); + // Note: policies may be missing from this dict if they have never been enabled + var savedPoliciesDict = savedPolicies.ToDictionary(p => p.Type); + return savedPoliciesDict; + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/PolicyUpdate.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/PolicyUpdate.cs index d1a52f0080..cad786234c 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/PolicyUpdate.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/PolicyUpdate.cs @@ -16,6 +16,8 @@ public record PolicyUpdate public PolicyType Type { get; set; } public string? Data { get; set; } public bool Enabled { get; set; } + + [Obsolete("Please use SavePolicyModel.PerformedBy instead.")] public IActingUser? PerformedBy { get; set; } public T GetDataModel() where T : IPolicyDataModel, new() diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/SavePolicyModel.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/SavePolicyModel.cs index 7c8d5126e8..01168deea4 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/SavePolicyModel.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/SavePolicyModel.cs @@ -5,4 +5,18 @@ namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; public record SavePolicyModel(PolicyUpdate PolicyUpdate, IActingUser? PerformedBy, IPolicyMetadataModel Metadata) { + public SavePolicyModel(PolicyUpdate PolicyUpdate) + : this(PolicyUpdate, null, new EmptyMetadataModel()) + { + } + + public SavePolicyModel(PolicyUpdate PolicyUpdate, IActingUser performedBy) + : this(PolicyUpdate, performedBy, new EmptyMetadataModel()) + { + } + + public SavePolicyModel(PolicyUpdate PolicyUpdate, IPolicyMetadataModel metadata) + : this(PolicyUpdate, null, metadata) + { + } } diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs index 5433d70410..7c1987865a 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs @@ -1,5 +1,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; using Bit.Core.AdminConsole.Services; using Bit.Core.AdminConsole.Services.Implementations; @@ -13,13 +15,17 @@ public static class PolicyServiceCollectionExtensions { services.AddScoped(); services.AddScoped(); + services.AddScoped(); services.AddScoped(); + services.AddScoped(); services.AddPolicyValidators(); services.AddPolicyRequirements(); services.AddPolicySideEffects(); + services.AddPolicyUpdateEvents(); } + [Obsolete("Use AddPolicyUpdateEvents instead.")] private static void AddPolicyValidators(this IServiceCollection services) { services.AddScoped(); @@ -27,14 +33,29 @@ public static class PolicyServiceCollectionExtensions services.AddScoped(); services.AddScoped(); services.AddScoped(); + services.AddScoped(); services.AddScoped(); } + [Obsolete("Use AddPolicyUpdateEvents instead.")] private static void AddPolicySideEffects(this IServiceCollection services) { services.AddScoped(); } + private static void AddPolicyUpdateEvents(this IServiceCollection services) + { + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + } + private static void AddPolicyRequirements(this IServiceCollection services) { services.AddScoped, DisableSendPolicyRequirementFactory>(); diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IEnforceDependentPoliciesEvent.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IEnforceDependentPoliciesEvent.cs new file mode 100644 index 0000000000..0e2bdc3d69 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IEnforceDependentPoliciesEvent.cs @@ -0,0 +1,19 @@ +using Bit.Core.AdminConsole.Enums; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; + +/// +/// Represents all policies required to be enabled before the given policy can be enabled. +/// +/// +/// This interface is intended for policy event handlers that mandate the activation of other policies +/// as prerequisites for enabling the associated policy. +/// +public interface IEnforceDependentPoliciesEvent : IPolicyUpdateEvent +{ + /// + /// PolicyTypes that must be enabled before this policy can be enabled, if any. + /// These dependencies will be checked when this policy is enabled and when any required policy is disabled. + /// + public IEnumerable RequiredPolicies { get; } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IOnPolicyPostUpdateEvent.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IOnPolicyPostUpdateEvent.cs new file mode 100644 index 0000000000..08295bf7fb --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IOnPolicyPostUpdateEvent.cs @@ -0,0 +1,18 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; +public interface IOnPolicyPostUpdateEvent : IPolicyUpdateEvent +{ + /// + /// Performs side effects after a policy has been upserted. + /// For example, this can be used for cleanup tasks or notifications. + /// + /// The policy save request + /// The policy after it was upserted + /// The policy state before it was updated, if any + public Task ExecutePostUpsertSideEffectAsync( + SavePolicyModel policyRequest, + Policy postUpsertedPolicyState, + Policy? previousPolicyState); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IOnPolicyPreUpdateEvent.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IOnPolicyPreUpdateEvent.cs new file mode 100644 index 0000000000..4167a392e4 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IOnPolicyPreUpdateEvent.cs @@ -0,0 +1,23 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; + +/// +/// Represents all side effects that should be executed before a policy is upserted. +/// +/// +/// This should be added to policy handlers that need to perform side effects before policy upserts. +/// +public interface IOnPolicyPreUpdateEvent : IPolicyUpdateEvent +{ + /// + /// Performs side effects before a policy is upserted. + /// For example, this can be used to remove non-compliant users from the organization. + /// + /// The policy save request containing the policy update and metadata + /// The current policy, if any + public Task ExecutePreUpsertSideEffectAsync( + SavePolicyModel policyRequest, + Policy? currentPolicy); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IPolicyEventHandlerFactory.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IPolicyEventHandlerFactory.cs new file mode 100644 index 0000000000..f44ae867dd --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IPolicyEventHandlerFactory.cs @@ -0,0 +1,30 @@ +#nullable enable + +using Bit.Core.AdminConsole.Enums; +using OneOf; +using OneOf.Types; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; + +/// +/// Provides policy-specific event handlers used during the save workflow in . +/// +/// +/// Supported handlers: +/// - for dependency checks +/// - for custom validation +/// - for pre-save logic +/// - for post-save logic +/// +public interface IPolicyEventHandlerFactory +{ + /// + /// Gets the event handler for the given policy type and handler interface. + /// + /// Handler type implementing . + /// The policy type to resolve. + /// + /// — the handler if available, or None if not implemented. + /// + OneOf GetHandler(PolicyType policyType) where T : IPolicyUpdateEvent; +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IPolicyUpdateEvent.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IPolicyUpdateEvent.cs new file mode 100644 index 0000000000..a568658d4d --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IPolicyUpdateEvent.cs @@ -0,0 +1,17 @@ +using Bit.Core.AdminConsole.Enums; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; + +/// +/// Represents the policy to be upserted. +/// +/// +/// This is used for the VNextSavePolicyCommand. All policy handlers should implement this interface. +/// +public interface IPolicyUpdateEvent +{ + /// + /// The policy type that the associated handler will handle. + /// + public PolicyType Type { get; } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IPolicyValidationEvent.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IPolicyValidationEvent.cs new file mode 100644 index 0000000000..ee401ef813 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IPolicyValidationEvent.cs @@ -0,0 +1,24 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; + +/// +/// Represents all validations that need to be run to enable or disable the given policy. +/// +/// +/// This is used for the VNextSavePolicyCommand. This optional but should be implemented for all policies that have +/// certain requirements for the given organization. +/// +public interface IPolicyValidationEvent : IPolicyUpdateEvent +{ + /// + /// Performs any validations required to enable or disable the policy. + /// + /// The policy save request containing the policy update and metadata + /// The current policy, if any + public Task ValidateAsync( + SavePolicyModel policyRequest, + Policy? currentPolicy); + +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IVNextSavePolicyCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IVNextSavePolicyCommand.cs new file mode 100644 index 0000000000..93414539bb --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/Interfaces/IVNextSavePolicyCommand.cs @@ -0,0 +1,34 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Microsoft.Azure.NotificationHubs.Messaging; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; + +/// +/// Handles creating or updating organization policies with validation and side effect execution. +/// +/// +/// Workflow: +/// 1. Validates organization can use policies +/// 2. Validates required and dependent policies +/// 3. Runs policy-specific validation () +/// 4. Executes pre-save logic () +/// 5. Saves the policy +/// 6. Logs the event +/// 7. Executes post-save logic () +/// +public interface IVNextSavePolicyCommand +{ + /// + /// Performs the necessary validations, saves the policy and any side effects + /// + /// Policy data, acting user, and metadata. + /// The saved policy with updated revision and applied changes. + /// + /// Thrown if: + /// - The organization can’t use policies + /// - Dependent policies are missing or block changes + /// - Custom validation fails + /// + Task SaveAsync(SavePolicyModel policyRequest); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/PolicyEventHandlerHandlerFactory.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/PolicyEventHandlerHandlerFactory.cs new file mode 100644 index 0000000000..b1abfb2aaf --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyUpdateEvents/PolicyEventHandlerHandlerFactory.cs @@ -0,0 +1,33 @@ + +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; +using OneOf; +using OneOf.Types; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents; + +public class PolicyEventHandlerHandlerFactory( + IEnumerable allEventHandlers) : IPolicyEventHandlerFactory +{ + public OneOf GetHandler(PolicyType policyType) where T : IPolicyUpdateEvent + { + var tEventHandlers = allEventHandlers.OfType().ToList(); + + var matchingHandlers = tEventHandlers.Where(h => h.Type == policyType).ToList(); + + if (matchingHandlers.Count > 1) + { + throw new InvalidOperationException( + $"Multiple {nameof(IPolicyUpdateEvent)} handlers of type {typeof(T).Name} found for {nameof(PolicyType)} {policyType}. " + + $"Expected one {typeof(T).Name} handler per {nameof(PolicyType)}."); + } + + var policyTEventHandler = matchingHandlers.SingleOrDefault(); + if (policyTEventHandler is null) + { + return new None(); + } + + return policyTEventHandler; + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/AutomaticUserConfirmationPolicyEventHandler.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/AutomaticUserConfirmationPolicyEventHandler.cs new file mode 100644 index 0000000000..c0d302df02 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/AutomaticUserConfirmationPolicyEventHandler.cs @@ -0,0 +1,131 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Enums; +using Bit.Core.Repositories; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; + +/// +/// Represents an event handler for the Automatic User Confirmation policy. +/// +/// This class validates that the following conditions are met: +///
    +///
  • The Single organization policy is enabled
  • +///
  • All organization users are compliant with the Single organization policy
  • +///
  • No provider users exist
  • +///
+/// +/// This class also performs side effects when the policy is being enabled or disabled. They are: +///
    +///
  • Sets the UseAutomaticUserConfirmation organization feature to match the policy update
  • +///
+///
+public class AutomaticUserConfirmationPolicyEventHandler( + IOrganizationUserRepository organizationUserRepository, + IProviderUserRepository providerUserRepository, + IPolicyRepository policyRepository, + IOrganizationRepository organizationRepository, + TimeProvider timeProvider) + : IPolicyValidator, IPolicyValidationEvent, IOnPolicyPreUpdateEvent, IEnforceDependentPoliciesEvent +{ + public PolicyType Type => PolicyType.AutomaticUserConfirmation; + public async Task ExecutePreUpsertSideEffectAsync(SavePolicyModel policyRequest, Policy? currentPolicy) => + await OnSaveSideEffectsAsync(policyRequest.PolicyUpdate, currentPolicy); + + private const string _singleOrgPolicyNotEnabledErrorMessage = + "The Single organization policy must be enabled before enabling the Automatically confirm invited users policy."; + + private const string _usersNotCompliantWithSingleOrgErrorMessage = + "All organization users must be compliant with the Single organization policy before enabling the Automatically confirm invited users policy. Please remove users who are members of multiple organizations."; + + private const string _providerUsersExistErrorMessage = + "The organization has users with the Provider user type. Please remove provider users before enabling the Automatically confirm invited users policy."; + + public IEnumerable RequiredPolicies => [PolicyType.SingleOrg]; + + public async Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) + { + var isNotEnablingPolicy = policyUpdate is not { Enabled: true }; + var policyAlreadyEnabled = currentPolicy is { Enabled: true }; + if (isNotEnablingPolicy || policyAlreadyEnabled) + { + return string.Empty; + } + + return await ValidateEnablingPolicyAsync(policyUpdate.OrganizationId); + } + + public async Task ValidateAsync(SavePolicyModel savePolicyModel, Policy? currentPolicy) => + await ValidateAsync(savePolicyModel.PolicyUpdate, currentPolicy); + + public async Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) + { + var organization = await organizationRepository.GetByIdAsync(policyUpdate.OrganizationId); + + if (organization is not null) + { + organization.UseAutomaticUserConfirmation = policyUpdate.Enabled; + organization.RevisionDate = timeProvider.GetUtcNow().UtcDateTime; + await organizationRepository.UpsertAsync(organization); + } + } + + private async Task ValidateEnablingPolicyAsync(Guid organizationId) + { + var singleOrgValidationError = await ValidateSingleOrgPolicyComplianceAsync(organizationId); + if (!string.IsNullOrWhiteSpace(singleOrgValidationError)) + { + return singleOrgValidationError; + } + + var providerValidationError = await ValidateNoProviderUsersAsync(organizationId); + if (!string.IsNullOrWhiteSpace(providerValidationError)) + { + return providerValidationError; + } + + return string.Empty; + } + + private async Task ValidateSingleOrgPolicyComplianceAsync(Guid organizationId) + { + var singleOrgPolicy = await policyRepository.GetByOrganizationIdTypeAsync(organizationId, PolicyType.SingleOrg); + if (singleOrgPolicy is not { Enabled: true }) + { + return _singleOrgPolicyNotEnabledErrorMessage; + } + + return await ValidateUserComplianceWithSingleOrgAsync(organizationId); + } + + private async Task ValidateUserComplianceWithSingleOrgAsync(Guid organizationId) + { + var organizationUsers = (await organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId)) + .Where(ou => ou.Status != OrganizationUserStatusType.Invited && + ou.Status != OrganizationUserStatusType.Revoked && + ou.UserId.HasValue) + .ToList(); + + if (organizationUsers.Count == 0) + { + return string.Empty; + } + + var hasNonCompliantUser = (await organizationUserRepository.GetManyByManyUsersAsync( + organizationUsers.Select(ou => ou.UserId!.Value))) + .Any(uo => uo.OrganizationId != organizationId && + uo.Status != OrganizationUserStatusType.Invited); + + return hasNonCompliantUser ? _usersNotCompliantWithSingleOrgErrorMessage : string.Empty; + } + + private async Task ValidateNoProviderUsersAsync(Guid organizationId) + { + var providerUsers = await providerUserRepository.GetManyByOrganizationAsync(organizationId); + + return providerUsers.Count > 0 ? _providerUsersExistErrorMessage : string.Empty; + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/FreeFamiliesForEnterprisePolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/FreeFamiliesForEnterprisePolicyValidator.cs index 57db4962e3..52a7e3e880 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/FreeFamiliesForEnterprisePolicyValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/FreeFamiliesForEnterprisePolicyValidator.cs @@ -3,6 +3,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; @@ -12,11 +13,16 @@ public class FreeFamiliesForEnterprisePolicyValidator( IOrganizationSponsorshipRepository organizationSponsorshipRepository, IMailService mailService, IOrganizationRepository organizationRepository) - : IPolicyValidator + : IPolicyValidator, IOnPolicyPreUpdateEvent { public PolicyType Type => PolicyType.FreeFamiliesSponsorshipPolicy; public IEnumerable RequiredPolicies => []; + public async Task ExecutePreUpsertSideEffectAsync(SavePolicyModel policyRequest, Policy? currentPolicy) + { + await OnSaveSideEffectsAsync(policyRequest.PolicyUpdate, currentPolicy); + } + public async Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) { if (currentPolicy is not { Enabled: true } && policyUpdate is { Enabled: true }) diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/MaximumVaultTimeoutPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/MaximumVaultTimeoutPolicyValidator.cs index bfd4dcfe0d..796ed286d8 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/MaximumVaultTimeoutPolicyValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/MaximumVaultTimeoutPolicyValidator.cs @@ -3,10 +3,11 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; -public class MaximumVaultTimeoutPolicyValidator : IPolicyValidator +public class MaximumVaultTimeoutPolicyValidator : IPolicyValidator, IEnforceDependentPoliciesEvent { public PolicyType Type => PolicyType.MaximumVaultTimeout; public IEnumerable RequiredPolicies => [PolicyType.SingleOrg]; diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidator.cs index f4ef6021a7..0bee2a55af 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidator.cs @@ -1,24 +1,32 @@  using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Repositories; using Bit.Core.Services; namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; -/// -/// Please do not extend or expand this validator. We're currently in the process of refactoring our policy validator pattern. -/// This is a stop-gap solution for post-policy-save side effects, but it is not the long-term solution. -/// public class OrganizationDataOwnershipPolicyValidator( IPolicyRepository policyRepository, ICollectionRepository collectionRepository, IEnumerable> factories, IFeatureService featureService) - : OrganizationPolicyValidator(policyRepository, factories), IPostSavePolicySideEffect + : OrganizationPolicyValidator(policyRepository, factories), IPostSavePolicySideEffect, IOnPolicyPostUpdateEvent { + public PolicyType Type => PolicyType.OrganizationDataOwnership; + + public async Task ExecutePostUpsertSideEffectAsync( + SavePolicyModel policyRequest, + Policy postUpsertedPolicyState, + Policy? previousPolicyState) + { + await ExecuteSideEffectsAsync(policyRequest, postUpsertedPolicyState, previousPolicyState); + } + public async Task ExecuteSideEffectsAsync( SavePolicyModel policyRequest, Policy postUpdatedPolicy, @@ -68,5 +76,4 @@ public class OrganizationDataOwnershipPolicyValidator( userOrgIds, defaultCollectionName); } - } diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/RequireSsoPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/RequireSsoPolicyValidator.cs index 2082d4305f..adc2a3865a 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/RequireSsoPolicyValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/RequireSsoPolicyValidator.cs @@ -3,12 +3,13 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Repositories; namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; -public class RequireSsoPolicyValidator : IPolicyValidator +public class RequireSsoPolicyValidator : IPolicyValidator, IPolicyValidationEvent, IEnforceDependentPoliciesEvent { private readonly ISsoConfigRepository _ssoConfigRepository; @@ -20,6 +21,11 @@ public class RequireSsoPolicyValidator : IPolicyValidator public PolicyType Type => PolicyType.RequireSso; public IEnumerable RequiredPolicies => [PolicyType.SingleOrg]; + public async Task ValidateAsync(SavePolicyModel policyRequest, Policy? currentPolicy) + { + return await ValidateAsync(policyRequest.PolicyUpdate, currentPolicy); + } + public async Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) { if (policyUpdate is not { Enabled: true }) diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/ResetPasswordPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/ResetPasswordPolicyValidator.cs index 1126c4b922..9033a38ad0 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/ResetPasswordPolicyValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/ResetPasswordPolicyValidator.cs @@ -4,12 +4,13 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Repositories; namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; -public class ResetPasswordPolicyValidator : IPolicyValidator +public class ResetPasswordPolicyValidator : IPolicyValidator, IPolicyValidationEvent, IEnforceDependentPoliciesEvent { private readonly ISsoConfigRepository _ssoConfigRepository; public PolicyType Type => PolicyType.ResetPassword; @@ -20,6 +21,11 @@ public class ResetPasswordPolicyValidator : IPolicyValidator _ssoConfigRepository = ssoConfigRepository; } + public async Task ValidateAsync(SavePolicyModel policyRequest, Policy? currentPolicy) + { + return await ValidateAsync(policyRequest.PolicyUpdate, currentPolicy); + } + public async Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) { if (policyUpdate is not { Enabled: true } || diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidator.cs index 49467eaae4..c0378bf5f9 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidator.cs @@ -7,6 +7,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Repositories; using Bit.Core.Context; @@ -17,7 +18,7 @@ using Bit.Core.Services; namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; -public class SingleOrgPolicyValidator : IPolicyValidator +public class SingleOrgPolicyValidator : IPolicyValidator, IPolicyValidationEvent, IOnPolicyPreUpdateEvent { public PolicyType Type => PolicyType.SingleOrg; private const string OrganizationNotFoundErrorMessage = "Organization not found."; @@ -57,6 +58,16 @@ public class SingleOrgPolicyValidator : IPolicyValidator public IEnumerable RequiredPolicies => []; + public async Task ValidateAsync(SavePolicyModel policyRequest, Policy? currentPolicy) + { + return await ValidateAsync(policyRequest.PolicyUpdate, currentPolicy); + } + + public async Task ExecutePreUpsertSideEffectAsync(SavePolicyModel policyRequest, Policy? currentPolicy) + { + await OnSaveSideEffectsAsync(policyRequest.PolicyUpdate, currentPolicy); + } + public async Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) { if (currentPolicy is not { Enabled: true } && policyUpdate is { Enabled: true }) diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidator.cs index 5ce72df6c1..7f3ebcccfb 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidator.cs @@ -6,6 +6,7 @@ using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; using Bit.Core.Context; using Bit.Core.Enums; @@ -16,7 +17,7 @@ using Bit.Core.Services; namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; -public class TwoFactorAuthenticationPolicyValidator : IPolicyValidator +public class TwoFactorAuthenticationPolicyValidator : IPolicyValidator, IOnPolicyPreUpdateEvent { private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IMailService _mailService; @@ -46,6 +47,11 @@ public class TwoFactorAuthenticationPolicyValidator : IPolicyValidator _revokeNonCompliantOrganizationUserCommand = revokeNonCompliantOrganizationUserCommand; } + public async Task ExecutePreUpsertSideEffectAsync(SavePolicyModel policyRequest, Policy? currentPolicy) + { + await OnSaveSideEffectsAsync(policyRequest.PolicyUpdate, currentPolicy); + } + public async Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) { if (currentPolicy is not { Enabled: true } && policyUpdate is { Enabled: true }) diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/UriMatchDefaultPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/UriMatchDefaultPolicyValidator.cs new file mode 100644 index 0000000000..5bffd944c9 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/UriMatchDefaultPolicyValidator.cs @@ -0,0 +1,14 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; + +public class UriMatchDefaultPolicyValidator : IPolicyValidator, IEnforceDependentPoliciesEvent +{ + public PolicyType Type => PolicyType.UriMatchDefaults; + public IEnumerable RequiredPolicies => [PolicyType.SingleOrg]; + public Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => Task.FromResult(""); + public Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => Task.CompletedTask; +} diff --git a/src/Core/AdminConsole/Repositories/IEventRepository.cs b/src/Core/AdminConsole/Repositories/IEventRepository.cs index 281d6ec8c7..f0c185561b 100644 --- a/src/Core/AdminConsole/Repositories/IEventRepository.cs +++ b/src/Core/AdminConsole/Repositories/IEventRepository.cs @@ -27,6 +27,7 @@ public interface IEventRepository DateTime startDate, DateTime endDate, PageOptions pageOptions); Task> GetManyByCipherAsync(Cipher cipher, DateTime startDate, DateTime endDate, PageOptions pageOptions); + Task CreateAsync(IEvent e); Task CreateManyAsync(IEnumerable e); Task> GetManyByOrganizationServiceAccountAsync(Guid organizationId, Guid serviceAccountId, diff --git a/src/Core/AdminConsole/Repositories/IOrganizationIntegrationRepository.cs b/src/Core/AdminConsole/Repositories/IOrganizationIntegrationRepository.cs index 434c8ddee3..1d8b8be0ec 100644 --- a/src/Core/AdminConsole/Repositories/IOrganizationIntegrationRepository.cs +++ b/src/Core/AdminConsole/Repositories/IOrganizationIntegrationRepository.cs @@ -5,4 +5,6 @@ namespace Bit.Core.Repositories; public interface IOrganizationIntegrationRepository : IRepository { Task> GetManyByOrganizationAsync(Guid organizationId); + + Task GetByTeamsConfigurationTenantIdTeamId(string tenantId, string teamId); } diff --git a/src/Core/AdminConsole/Repositories/IOrganizationUserRepository.cs b/src/Core/AdminConsole/Repositories/IOrganizationUserRepository.cs index 37a830c92e..b17de3c51d 100644 --- a/src/Core/AdminConsole/Repositories/IOrganizationUserRepository.cs +++ b/src/Core/AdminConsole/Repositories/IOrganizationUserRepository.cs @@ -87,4 +87,13 @@ public interface IOrganizationUserRepository : IRepository> GetManyDetailsByRoleAsync(Guid organizationId, OrganizationUserType role); Task CreateManyAsync(IEnumerable organizationUserCollection); + + /// + /// It will only confirm if the user is in the `Accepted` state. + /// + /// This is an idempotent operation. + /// + /// Accepted OrganizationUser to confirm + /// True, if the user was updated. False, if not performed. + Task ConfirmOrganizationUserAsync(OrganizationUser organizationUser); } diff --git a/src/Core/AdminConsole/Repositories/IPolicyRepository.cs b/src/Core/AdminConsole/Repositories/IPolicyRepository.cs index 9f5c7f3fc4..d479809b89 100644 --- a/src/Core/AdminConsole/Repositories/IPolicyRepository.cs +++ b/src/Core/AdminConsole/Repositories/IPolicyRepository.cs @@ -20,17 +20,6 @@ public interface IPolicyRepository : IRepository Task GetByOrganizationIdTypeAsync(Guid organizationId, PolicyType type); Task> GetManyByOrganizationIdAsync(Guid organizationId); Task> GetManyByUserIdAsync(Guid userId); - /// - /// Gets all PolicyDetails for a user for all policy types. - /// - /// - /// Each PolicyDetail represents an OrganizationUser and a Policy which *may* be enforced - /// against them. It only returns PolicyDetails for policies that are enabled and where the organization's plan - /// supports policies. It also excludes "revoked invited" users who are not subject to policy enforcement. - /// This is consumed by to create requirements for specific policy types. - /// You probably do not want to call it directly. - /// - Task> GetPolicyDetailsByUserId(Guid userId); /// /// Retrieves of the specified diff --git a/src/Core/AdminConsole/Repositories/TableStorage/EventRepository.cs b/src/Core/AdminConsole/Repositories/TableStorage/EventRepository.cs index c9c803b5b2..169b36bf69 100644 --- a/src/Core/AdminConsole/Repositories/TableStorage/EventRepository.cs +++ b/src/Core/AdminConsole/Repositories/TableStorage/EventRepository.cs @@ -77,12 +77,18 @@ public class EventRepository : IEventRepository return await GetManyAsync(partitionKey, $"CipherId={cipher.Id}__Date={{0}}", startDate, endDate, pageOptions); } - public async Task> GetManyByOrganizationServiceAccountAsync(Guid organizationId, - Guid serviceAccountId, DateTime startDate, DateTime endDate, PageOptions pageOptions) + public async Task> GetManyByOrganizationServiceAccountAsync( + Guid organizationId, + Guid serviceAccountId, + DateTime startDate, + DateTime endDate, + PageOptions pageOptions) { + return await GetManyServiceAccountAsync( + $"OrganizationId={organizationId}", + serviceAccountId.ToString(), + startDate, endDate, pageOptions); - return await GetManyAsync($"OrganizationId={organizationId}", - $"ServiceAccountId={serviceAccountId}__Date={{0}}", startDate, endDate, pageOptions); } public async Task CreateAsync(IEvent e) @@ -141,6 +147,40 @@ public class EventRepository : IEventRepository } } + public async Task> GetManyServiceAccountAsync( + string partitionKey, + string serviceAccountId, + DateTime startDate, + DateTime endDate, + PageOptions pageOptions) + { + var start = CoreHelpers.DateTimeToTableStorageKey(startDate); + var end = CoreHelpers.DateTimeToTableStorageKey(endDate); + var filter = MakeFilterForServiceAccount(partitionKey, serviceAccountId, startDate, endDate); + + var result = new PagedResult(); + var query = _tableClient.QueryAsync(filter, pageOptions.PageSize); + + await using (var enumerator = query.AsPages(pageOptions.ContinuationToken, + pageOptions.PageSize).GetAsyncEnumerator()) + { + if (await enumerator.MoveNextAsync()) + { + result.ContinuationToken = enumerator.Current.ContinuationToken; + + var events = enumerator.Current.Values + .Select(e => e.ToEventTableEntity()) + .ToList(); + + events = events.OrderByDescending(e => e.Date).ToList(); + + result.Data.AddRange(events); + } + } + + return result; + } + public async Task> GetManyAsync(string partitionKey, string rowKey, DateTime startDate, DateTime endDate, PageOptions pageOptions) { @@ -172,4 +212,27 @@ public class EventRepository : IEventRepository { return $"PartitionKey eq '{partitionKey}' and RowKey le '{rowStart}' and RowKey ge '{rowEnd}'"; } + + private string MakeFilterForServiceAccount( + string partitionKey, + string machineAccountId, + DateTime startDate, + DateTime endDate) + { + var start = CoreHelpers.DateTimeToTableStorageKey(startDate); + var end = CoreHelpers.DateTimeToTableStorageKey(endDate); + + var rowKey1Start = $"ServiceAccountId={machineAccountId}__Date={start}"; + var rowKey1End = $"ServiceAccountId={machineAccountId}__Date={end}"; + + var rowKey2Start = $"GrantedServiceAccountId={machineAccountId}__Date={start}"; + var rowKey2End = $"GrantedServiceAccountId={machineAccountId}__Date={end}"; + + var left = $"PartitionKey eq '{partitionKey}' and RowKey le '{rowKey1Start}' and RowKey ge '{rowKey1End}'"; + var right = $"PartitionKey eq '{partitionKey}' and RowKey le '{rowKey2Start}' and RowKey ge '{rowKey2End}'"; + + return $"({left}) or ({right})"; + } + + } diff --git a/src/Core/AdminConsole/Services/IEventIntegrationPublisher.cs b/src/Core/AdminConsole/Services/IEventIntegrationPublisher.cs index b80b518223..4d95707e90 100644 --- a/src/Core/AdminConsole/Services/IEventIntegrationPublisher.cs +++ b/src/Core/AdminConsole/Services/IEventIntegrationPublisher.cs @@ -5,5 +5,5 @@ namespace Bit.Core.Services; public interface IEventIntegrationPublisher : IAsyncDisposable { Task PublishAsync(IIntegrationMessage message); - Task PublishEventAsync(string body); + Task PublishEventAsync(string body, string? organizationId); } diff --git a/src/Core/AdminConsole/Services/IEventService.cs b/src/Core/AdminConsole/Services/IEventService.cs index 80e8e63d8c..795c06e254 100644 --- a/src/Core/AdminConsole/Services/IEventService.cs +++ b/src/Core/AdminConsole/Services/IEventService.cs @@ -4,6 +4,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Interfaces; +using Bit.Core.Auth.Identity; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.SecretsManager.Entities; @@ -37,4 +38,7 @@ public interface IEventService Task LogServiceAccountSecretsEventAsync(Guid serviceAccountId, IEnumerable secrets, EventType type, DateTime? date = null); Task LogUserProjectsEventAsync(Guid userId, IEnumerable projects, EventType type, DateTime? date = null); Task LogServiceAccountProjectsEventAsync(Guid serviceAccountId, IEnumerable projects, EventType type, DateTime? date = null); + Task LogServiceAccountPeopleEventAsync(Guid userId, UserServiceAccountAccessPolicy policy, EventType type, IdentityClientType identityClientType, DateTime? date = null); + Task LogServiceAccountGroupEventAsync(Guid userId, GroupServiceAccountAccessPolicy policy, EventType type, IdentityClientType identityClientType, DateTime? date = null); + Task LogServiceAccountEventAsync(Guid userId, List serviceAccount, EventType type, IdentityClientType identityClientType, DateTime? date = null); } diff --git a/src/Core/AdminConsole/Services/IProviderService.cs b/src/Core/AdminConsole/Services/IProviderService.cs index 66c49d90c6..2b954346ae 100644 --- a/src/Core/AdminConsole/Services/IProviderService.cs +++ b/src/Core/AdminConsole/Services/IProviderService.cs @@ -3,7 +3,7 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Models.Business.Provider; -using Bit.Core.Billing.Models; +using Bit.Core.Billing.Payment.Models; using Bit.Core.Entities; using Bit.Core.Models.Business; @@ -11,8 +11,7 @@ namespace Bit.Core.AdminConsole.Services; public interface IProviderService { - Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo, - TokenizedPaymentSource tokenizedPaymentSource = null); + Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TokenizedPaymentMethod paymentMethod, BillingAddress billingAddress); Task UpdateAsync(Provider provider, bool updateBilling = false); Task> InviteUserAsync(ProviderUserInvite invite); diff --git a/src/Core/AdminConsole/Services/ISlackService.cs b/src/Core/AdminConsole/Services/ISlackService.cs index 6c6a846f0d..60d3da8af4 100644 --- a/src/Core/AdminConsole/Services/ISlackService.cs +++ b/src/Core/AdminConsole/Services/ISlackService.cs @@ -1,11 +1,61 @@ -namespace Bit.Core.Services; +using Bit.Core.Models.Slack; +namespace Bit.Core.Services; + +/// Defines operations for interacting with Slack, including OAuth authentication, channel discovery, +/// and sending messages. public interface ISlackService { + /// Note: This API is not currently used (yet) by any server code. It is here to provide functionality if + /// the UI needs to be able to look up channels for a user. + /// Retrieves the ID of a Slack channel by name. + /// See conversations.list API. + /// A valid Slack OAuth access token. + /// The name of the channel to look up. + /// The channel ID if found; otherwise, an empty string. Task GetChannelIdAsync(string token, string channelName); + + /// Note: This API is not currently used (yet) by any server code. It is here to provide functionality if + /// the UI needs to be able to look up channels for a user. + /// Retrieves the IDs of multiple Slack channels by name. + /// See conversations.list API. + /// A valid Slack OAuth access token. + /// A list of channel names to look up. + /// A list of matching channel IDs. Channels that cannot be found are omitted. Task> GetChannelIdsAsync(string token, List channelNames); + + /// Note: This API is not currently used (yet) by any server code. It is here to provide functionality if + /// the UI needs to be able to look up a user by their email address. + /// Retrieves the DM channel ID for a Slack user by email. + /// See users.lookupByEmail API and + /// conversations.open API. + /// A valid Slack OAuth access token. + /// The email address of the user to open a DM with. + /// The DM channel ID if successful; otherwise, an empty string. Task GetDmChannelByEmailAsync(string token, string email); - string GetRedirectUrl(string redirectUrl); + + /// Builds the Slack OAuth 2.0 authorization URL for the app. + /// See Slack OAuth v2 documentation. + /// The absolute redirect URI that Slack will call after user authorization. + /// Must match the URI registered with the app configuration. + /// A state token used to correlate the request and callback and prevent CSRF attacks. + /// The full authorization URL to which the user should be redirected to begin the sign-in process. + string GetRedirectUrl(string callbackUrl, string state); + + /// Exchanges a Slack OAuth code for an access token. + /// See oauth.v2.access API. + /// The authorization code returned by Slack via the callback URL after user authorization. + /// The redirect URI that was used in the authorization request. + /// A valid Slack access token if successful; otherwise, an empty string. Task ObtainTokenViaOAuth(string code, string redirectUrl); - Task SendSlackMessageByChannelIdAsync(string token, string message, string channelId); + + /// Sends a message to a Slack channel by ID. + /// See chat.postMessage API. + /// This is used primarily by the to send events to the + /// Slack channel. + /// A valid Slack OAuth access token. + /// The message text to send. + /// The channel ID to send the message to. + /// The response from Slack after sending the message. + Task SendSlackMessageByChannelIdAsync(string token, string message, string channelId); } diff --git a/src/Core/AdminConsole/Services/ITeamsService.cs b/src/Core/AdminConsole/Services/ITeamsService.cs new file mode 100644 index 0000000000..e3757987c3 --- /dev/null +++ b/src/Core/AdminConsole/Services/ITeamsService.cs @@ -0,0 +1,49 @@ +using Bit.Core.Models.Teams; + +namespace Bit.Core.Services; + +/// +/// Service that provides functionality relating to the Microsoft Teams integration including OAuth, +/// team discovery and sending a message to a channel in Teams. +/// +public interface ITeamsService +{ + /// + /// Generate the Microsoft Teams OAuth 2.0 authorization URL used to begin the sign-in flow. + /// + /// The absolute redirect URI that Microsoft will call after user authorization. + /// Must match the URI registered with the app configuration. + /// A state token used to correlate the request and callback and prevent CSRF attacks. + /// The full authorization URL to which the user should be redirected to begin the sign-in process. + string GetRedirectUrl(string callbackUrl, string state); + + /// + /// Exchange the OAuth code for a Microsoft Graph API access token. + /// + /// The code returned from Microsoft via the OAuth callback Url. + /// The same redirect URI that was passed to the authorization request. + /// A valid Microsoft Graph access token if the exchange succeeds; otherwise, an empty string. + Task ObtainTokenViaOAuth(string code, string redirectUrl); + + /// + /// Get the Teams to which the authenticated user belongs via Microsoft Graph API. + /// + /// A valid Microsoft Graph access token for the user (obtained via OAuth). + /// A read-only list of objects representing the user’s joined teams. + /// Returns an empty list if the request fails or if the token is invalid. + Task> GetJoinedTeamsAsync(string accessToken); + + /// + /// Send a message to a specific channel in Teams. + /// + /// This is used primarily by the to send events to the + /// Teams channel. + /// The service URI associated with the Microsoft Bot Framework connector for the target + /// team. Obtained via the bot framework callback. + /// The conversation or channel ID where the message should be delivered. Obtained via + /// the bot framework callback. + /// The message text to post to the channel. + /// A task that completes when the message has been sent. Errors during message delivery are surfaced + /// as exceptions from the underlying connector client. + Task SendMessageToChannelAsync(Uri serviceUri, string channelId, string message); +} diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusEventListenerService.cs b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusEventListenerService.cs index 91f8fac888..a589211687 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusEventListenerService.cs +++ b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusEventListenerService.cs @@ -14,13 +14,14 @@ public class AzureServiceBusEventListenerService : EventLoggingL TConfiguration configuration, IEventMessageHandler handler, IAzureServiceBusService serviceBusService, + ServiceBusProcessorOptions serviceBusOptions, ILoggerFactory loggerFactory) : base(handler, CreateLogger(loggerFactory, configuration)) { _processor = serviceBusService.CreateProcessor( topicName: configuration.EventTopicName, subscriptionName: configuration.EventSubscriptionName, - new ServiceBusProcessorOptions()); + options: serviceBusOptions); } protected override async Task ExecuteAsync(CancellationToken cancellationToken) diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusIntegrationListenerService.cs b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusIntegrationListenerService.cs index e415430965..633a53296b 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusIntegrationListenerService.cs +++ b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusIntegrationListenerService.cs @@ -18,6 +18,7 @@ public class AzureServiceBusIntegrationListenerService : Backgro TConfiguration configuration, IIntegrationHandler handler, IAzureServiceBusService serviceBusService, + ServiceBusProcessorOptions serviceBusOptions, ILoggerFactory loggerFactory) { _handler = handler; @@ -29,7 +30,7 @@ public class AzureServiceBusIntegrationListenerService : Backgro _processor = _serviceBusService.CreateProcessor( topicName: configuration.IntegrationTopicName, subscriptionName: configuration.IntegrationSubscriptionName, - options: new ServiceBusProcessorOptions()); + options: serviceBusOptions); } protected override async Task ExecuteAsync(CancellationToken cancellationToken) diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusService.cs b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusService.cs index 4887aa3a7f..953a9bb56e 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusService.cs +++ b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusService.cs @@ -30,7 +30,8 @@ public class AzureServiceBusService : IAzureServiceBusService var serviceBusMessage = new ServiceBusMessage(json) { Subject = message.IntegrationType.ToRoutingKey(), - MessageId = message.MessageId + MessageId = message.MessageId, + PartitionKey = message.OrganizationId }; await _integrationSender.SendMessageAsync(serviceBusMessage); @@ -44,18 +45,20 @@ public class AzureServiceBusService : IAzureServiceBusService { Subject = message.IntegrationType.ToRoutingKey(), ScheduledEnqueueTime = message.DelayUntilDate ?? DateTime.UtcNow, - MessageId = message.MessageId + MessageId = message.MessageId, + PartitionKey = message.OrganizationId }; await _integrationSender.SendMessageAsync(serviceBusMessage); } - public async Task PublishEventAsync(string body) + public async Task PublishEventAsync(string body, string? organizationId) { var message = new ServiceBusMessage(body) { ContentType = "application/json", - MessageId = Guid.NewGuid().ToString() + MessageId = Guid.NewGuid().ToString(), + PartitionKey = organizationId }; await _eventSender.SendMessageAsync(message); diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationEventWriteService.cs b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationEventWriteService.cs index 309b4a8409..4ac97df763 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationEventWriteService.cs +++ b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationEventWriteService.cs @@ -14,15 +14,21 @@ public class EventIntegrationEventWriteService : IEventWriteService, IAsyncDispo public async Task CreateAsync(IEvent e) { var body = JsonSerializer.Serialize(e); - await _eventIntegrationPublisher.PublishEventAsync(body: body); + await _eventIntegrationPublisher.PublishEventAsync(body: body, organizationId: e.OrganizationId?.ToString()); } public async Task CreateManyAsync(IEnumerable events) { - var body = JsonSerializer.Serialize(events); - await _eventIntegrationPublisher.PublishEventAsync(body: body); - } + var eventList = events as IList ?? events.ToList(); + if (eventList.Count == 0) + { + return; + } + var organizationId = eventList[0].OrganizationId?.ToString(); + var body = JsonSerializer.Serialize(eventList); + await _eventIntegrationPublisher.PublishEventAsync(body: body, organizationId: organizationId); + } public async ValueTask DisposeAsync() { await _eventIntegrationPublisher.DisposeAsync(); diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationHandler.cs b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationHandler.cs index 0a8ab67554..8423652eb8 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationHandler.cs +++ b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationHandler.cs @@ -57,6 +57,7 @@ public class EventIntegrationHandler( { IntegrationType = integrationType, MessageId = messageId.ToString(), + OrganizationId = organizationId.ToString(), Configuration = config, RenderedTemplate = renderedTemplate, RetryCount = 0, diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventRouteService.cs b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventRouteService.cs deleted file mode 100644 index a542e75a7b..0000000000 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventRouteService.cs +++ /dev/null @@ -1,34 +0,0 @@ -using Bit.Core.Models.Data; -using Microsoft.Extensions.DependencyInjection; - -namespace Bit.Core.Services; - -public class EventRouteService( - [FromKeyedServices("broadcast")] IEventWriteService broadcastEventWriteService, - [FromKeyedServices("storage")] IEventWriteService storageEventWriteService, - IFeatureService _featureService) : IEventWriteService -{ - public async Task CreateAsync(IEvent e) - { - if (_featureService.IsEnabled(FeatureFlagKeys.EventBasedOrganizationIntegrations)) - { - await broadcastEventWriteService.CreateAsync(e); - } - else - { - await storageEventWriteService.CreateAsync(e); - } - } - - public async Task CreateManyAsync(IEnumerable e) - { - if (_featureService.IsEnabled(FeatureFlagKeys.EventBasedOrganizationIntegrations)) - { - await broadcastEventWriteService.CreateManyAsync(e); - } - else - { - await storageEventWriteService.CreateManyAsync(e); - } - } -} diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/README.md b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/README.md index de7ce3f7fd..7570d47211 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/README.md +++ b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/README.md @@ -203,31 +203,17 @@ Currently, there are integrations / handlers for Slack, webhooks, and HTTP Event - The top-level object that enables a specific integration for the organization. - Includes any properties that apply to the entire integration across all events. - - For Slack, it consists of the token: `{ "Token": "xoxb-token-from-slack" }`. - - For webhooks, it is optional. Webhooks can either be configured at this level or the configuration level, - but the configuration level takes precedence. However, even though it is optional, an organization must - have a webhook `OrganizationIntegration` (even will a `null` `Configuration`) to enable configuration - via `OrganizationIntegrationConfiguration`. - - For HEC, it consists of the scheme, token, and URI: - -```json - { - "Scheme": "Bearer", - "Token": "Auth-token-from-HEC-service", - "Uri": "https://example.com/api" - } -``` + - For example, Slack stores the token in the `Configuration` which applies to every event, but stores the +channel id in the `Configuration` of the `OrganizationIntegrationConfiguration`. The token applies to the entire Slack +integration, but the channel could be configured differently depending on event type. + - See the table below for more examples / details on what is stored at which level. ### `OrganizationIntegrationConfiguration` - This contains the configurations specific to each `EventType` for the integration. - `Configuration` contains the event-specific configuration. - - For Slack, this would contain what channel to send the message to: `{ "channelId": "C123456" }` - - For webhooks, this is the URL the request should be sent to: `{ "url": "https://api.example.com" }` - - Optionally this also can include a `Scheme` and `Token` if this webhook needs Authentication. - - As stated above, all of this information can be specified here or at the `OrganizationIntegration` - level, but any properties declared here will take precedence over the ones above. - - For HEC, this must be null. HEC is configured only at the `OrganizationIntegration` level. + - Any properties at this level override the `Configuration` form the `OrganizationIntegration`. + - See the table below for examples of specific integrations. - `Template` contains a template string that is expected to be filled in with the contents of the actual event. - The tokens in the string are wrapped in `#` characters. For instance, the UserId would be `#UserId#`. - The `IntegrationTemplateProcessor` does the actual work of replacing these tokens with introspected values from @@ -245,6 +231,23 @@ Currently, there are integrations / handlers for Slack, webhooks, and HTTP Event - An array of `OrganizationIntegrationConfigurationDetails` is what the `EventIntegrationHandler` fetches from the database to determine what to publish at the integration level. +### Existing integrations and the configurations at each level + +The following table illustrates how each integration is configured and what exactly is stored in the `Configuration` +property at each level (`OrganizationIntegration` or `OrganizationIntegrationConfiguration`). Under +`OrganizationIntegration` the valid `OrganizationIntegrationStatus` are in bold, with an example of what would be +stored at each status. + +| **Integration** | **OrganizationIntegration** | **OrganizationIntegrationConfiguration** | +|------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------| +| CloudBillingSync | **Not Applicable** (not yet used) | **Not Applicable** (not yet used) | +| Scim | **Not Applicable** (not yet used) | **Not Applicable** (not yet used) | +| Slack | **Initiated**: `null`
**Completed**:
`{ "Token": "xoxb-token-from-slack" }` | `{ "channelId": "C123456" }` | +| Webhook | `null` or `{ "Scheme": "Bearer", "Token": "AUTH-TOKEN", "Uri": "https://example.com" }` | `null` or `{ "Scheme": "Bearer", "Token":"AUTH-TOKEN", "Uri": "https://example.com" }`

Whatever is defined at this level takes precedence | +| Hec | `{ "Scheme": "Bearer", "Token": "AUTH-TOKEN", "Uri": "https://example.com" }` | Always `null` | +| Datadog | `{ "ApiKey": "TheKey12345", "Uri": "https://api.us5.datadoghq.com/api/v1/events"}` | Always `null` | +| Teams | **Initiated**: `null`
**In Progress**:
`{ "TenantID": "tenant", "Teams": ["Id": "team", DisplayName: "MyTeam"]}`
**Completed**:
`{ "TenantID": "tenant", "Teams": ["Id": "team", DisplayName: "MyTeam"], "ServiceUrl":"https://example.com", ChannelId: "channel-1234"}` | Always `null` | + ## Filtering In addition to the ability to configure integrations mentioned above, organization admins can @@ -349,10 +352,20 @@ and event type. - This will be the deserialized version of the `MergedConfiguration` in `OrganizationIntegrationConfigurationDetails`. +A new row with the new integration should be added to this doc in the table above [Existing integrations +and the configurations at each level](#existing-integrations-and-the-configurations-at-each-level). + ## Request Models 1. Add a new case to the switch method in `OrganizationIntegrationRequestModel.Validate`. + - Additionally, add tests in `OrganizationIntegrationRequestModelTests` 2. Add a new case to the switch method in `OrganizationIntegrationConfigurationRequestModel.IsValidForType`. + - Additionally, add / update tests in `OrganizationIntegrationConfigurationRequestModelTests` + +## Response Model + +1. Add a new case to the switch method in `OrganizationIntegrationResponseModel.Status`. + - Additionally, add / update tests in `OrganizationIntegrationResponseModelTests` ## Integration Handler diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/RabbitMqService.cs b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/RabbitMqService.cs index 3e20e34200..8976530cf4 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/RabbitMqService.cs +++ b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/RabbitMqService.cs @@ -122,7 +122,7 @@ public class RabbitMqService : IRabbitMqService body: body); } - public async Task PublishEventAsync(string body) + public async Task PublishEventAsync(string body, string? organizationId) { await using var channel = await CreateChannelAsync(); var properties = new BasicProperties diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/SlackIntegrationHandler.cs b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/SlackIntegrationHandler.cs index 2d29494afc..16c756c8c4 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/SlackIntegrationHandler.cs +++ b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/SlackIntegrationHandler.cs @@ -6,14 +6,43 @@ public class SlackIntegrationHandler( ISlackService slackService) : IntegrationHandlerBase { + private static readonly HashSet _retryableErrors = new(StringComparer.Ordinal) + { + "internal_error", + "message_limit_exceeded", + "rate_limited", + "ratelimited", + "service_unavailable" + }; + public override async Task HandleAsync(IntegrationMessage message) { - await slackService.SendSlackMessageByChannelIdAsync( + var slackResponse = await slackService.SendSlackMessageByChannelIdAsync( message.Configuration.Token, message.RenderedTemplate, message.Configuration.ChannelId ); - return new IntegrationHandlerResult(success: true, message: message); + if (slackResponse is null) + { + return new IntegrationHandlerResult(success: false, message: message) + { + FailureReason = "Slack response was null" + }; + } + + if (slackResponse.Ok) + { + return new IntegrationHandlerResult(success: true, message: message); + } + + var result = new IntegrationHandlerResult(success: false, message: message) { FailureReason = slackResponse.Error }; + + if (_retryableErrors.Contains(slackResponse.Error)) + { + result.Retryable = true; + } + + return result; } } diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/SlackService.cs b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/SlackService.cs index f17185c4d3..7eec2ec374 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/SlackService.cs +++ b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/SlackService.cs @@ -1,5 +1,6 @@ using System.Net.Http.Headers; using System.Net.Http.Json; +using System.Text.Json; using System.Web; using Bit.Core.Models.Slack; using Bit.Core.Settings; @@ -19,6 +20,7 @@ public class SlackService( private readonly string _slackApiBaseUrl = globalSettings.Slack.ApiBaseUrl; public const string HttpClientName = "SlackServiceHttpClient"; + private const string _slackOAuthBaseUri = "https://slack.com/oauth/v2/authorize"; public async Task GetChannelIdAsync(string token, string channelName) { @@ -70,32 +72,47 @@ public class SlackService( public async Task GetDmChannelByEmailAsync(string token, string email) { var userId = await GetUserIdByEmailAsync(token, email); - return await OpenDmChannel(token, userId); + return await OpenDmChannelAsync(token, userId); } - public string GetRedirectUrl(string redirectUrl) + public string GetRedirectUrl(string callbackUrl, string state) { - return $"https://slack.com/oauth/v2/authorize?client_id={_clientId}&scope={_scopes}&redirect_uri={redirectUrl}"; + var builder = new UriBuilder(_slackOAuthBaseUri); + var query = HttpUtility.ParseQueryString(builder.Query); + + query["client_id"] = _clientId; + query["scope"] = _scopes; + query["redirect_uri"] = callbackUrl; + query["state"] = state; + + builder.Query = query.ToString(); + return builder.ToString(); } public async Task ObtainTokenViaOAuth(string code, string redirectUrl) { + if (string.IsNullOrEmpty(code) || string.IsNullOrWhiteSpace(redirectUrl)) + { + logger.LogError("Error obtaining token via OAuth: Code and/or RedirectUrl were empty"); + return string.Empty; + } + var tokenResponse = await _httpClient.PostAsync($"{_slackApiBaseUrl}/oauth.v2.access", - new FormUrlEncodedContent(new[] - { + new FormUrlEncodedContent([ new KeyValuePair("client_id", _clientId), new KeyValuePair("client_secret", _clientSecret), new KeyValuePair("code", code), new KeyValuePair("redirect_uri", redirectUrl) - })); + ])); SlackOAuthResponse? result; try { result = await tokenResponse.Content.ReadFromJsonAsync(); } - catch + catch (JsonException ex) { + logger.LogError(ex, "Error parsing SlackOAuthResponse: invalid JSON"); result = null; } @@ -113,14 +130,25 @@ public class SlackService( return result.AccessToken; } - public async Task SendSlackMessageByChannelIdAsync(string token, string message, string channelId) + public async Task SendSlackMessageByChannelIdAsync(string token, string message, + string channelId) { var payload = JsonContent.Create(new { channel = channelId, text = message }); var request = new HttpRequestMessage(HttpMethod.Post, $"{_slackApiBaseUrl}/chat.postMessage"); request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", token); request.Content = payload; - await _httpClient.SendAsync(request); + var response = await _httpClient.SendAsync(request); + + try + { + return await response.Content.ReadFromJsonAsync(); + } + catch (JsonException ex) + { + logger.LogError(ex, "Error parsing Slack message response: invalid JSON"); + return null; + } } private async Task GetUserIdByEmailAsync(string token, string email) @@ -128,7 +156,16 @@ public class SlackService( var request = new HttpRequestMessage(HttpMethod.Get, $"{_slackApiBaseUrl}/users.lookupByEmail?email={email}"); request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", token); var response = await _httpClient.SendAsync(request); - var result = await response.Content.ReadFromJsonAsync(); + SlackUserResponse? result; + try + { + result = await response.Content.ReadFromJsonAsync(); + } + catch (JsonException ex) + { + logger.LogError(ex, "Error parsing SlackUserResponse: invalid JSON"); + result = null; + } if (result is null) { @@ -144,7 +181,7 @@ public class SlackService( return result.User.Id; } - private async Task OpenDmChannel(string token, string userId) + private async Task OpenDmChannelAsync(string token, string userId) { if (string.IsNullOrEmpty(userId)) return string.Empty; @@ -154,7 +191,16 @@ public class SlackService( request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", token); request.Content = payload; var response = await _httpClient.SendAsync(request); - var result = await response.Content.ReadFromJsonAsync(); + SlackDmResponse? result; + try + { + result = await response.Content.ReadFromJsonAsync(); + } + catch (JsonException ex) + { + logger.LogError(ex, "Error parsing SlackDmResponse: invalid JSON"); + result = null; + } if (result is null) { diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/TeamsIntegrationHandler.cs b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/TeamsIntegrationHandler.cs new file mode 100644 index 0000000000..41d60bd69c --- /dev/null +++ b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/TeamsIntegrationHandler.cs @@ -0,0 +1,41 @@ +using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Microsoft.Rest; + +namespace Bit.Core.Services; + +public class TeamsIntegrationHandler( + ITeamsService teamsService) + : IntegrationHandlerBase +{ + public override async Task HandleAsync( + IntegrationMessage message) + { + try + { + await teamsService.SendMessageToChannelAsync( + serviceUri: message.Configuration.ServiceUrl, + message: message.RenderedTemplate, + channelId: message.Configuration.ChannelId + ); + + return new IntegrationHandlerResult(success: true, message: message); + } + catch (HttpOperationException ex) + { + var result = new IntegrationHandlerResult(success: false, message: message); + var statusCode = (int)ex.Response.StatusCode; + result.Retryable = statusCode is 429 or >= 500 and < 600; + result.FailureReason = ex.Message; + + return result; + } + catch (Exception ex) + { + var result = new IntegrationHandlerResult(success: false, message: message); + result.Retryable = false; + result.FailureReason = ex.Message; + + return result; + } + } +} diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/TeamsService.cs b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/TeamsService.cs new file mode 100644 index 0000000000..f9911760bb --- /dev/null +++ b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/TeamsService.cs @@ -0,0 +1,182 @@ +using System.Net.Http.Headers; +using System.Net.Http.Json; +using System.Text.Json; +using System.Web; +using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.Models.Teams; +using Bit.Core.Repositories; +using Bit.Core.Settings; +using Microsoft.Bot.Builder; +using Microsoft.Bot.Builder.Teams; +using Microsoft.Bot.Connector; +using Microsoft.Bot.Connector.Authentication; +using Microsoft.Bot.Schema; +using Microsoft.Extensions.Logging; +using TeamInfo = Bit.Core.Models.Teams.TeamInfo; + +namespace Bit.Core.Services; + +public class TeamsService( + IHttpClientFactory httpClientFactory, + IOrganizationIntegrationRepository integrationRepository, + GlobalSettings globalSettings, + ILogger logger) : ActivityHandler, ITeamsService +{ + private readonly HttpClient _httpClient = httpClientFactory.CreateClient(HttpClientName); + private readonly string _clientId = globalSettings.Teams.ClientId; + private readonly string _clientSecret = globalSettings.Teams.ClientSecret; + private readonly string _scopes = globalSettings.Teams.Scopes; + private readonly string _graphBaseUrl = globalSettings.Teams.GraphBaseUrl; + private readonly string _loginBaseUrl = globalSettings.Teams.LoginBaseUrl; + + public const string HttpClientName = "TeamsServiceHttpClient"; + + public string GetRedirectUrl(string redirectUrl, string state) + { + var query = HttpUtility.ParseQueryString(string.Empty); + query["client_id"] = _clientId; + query["response_type"] = "code"; + query["redirect_uri"] = redirectUrl; + query["response_mode"] = "query"; + query["scope"] = string.Join(" ", _scopes); + query["state"] = state; + + return $"{_loginBaseUrl}/common/oauth2/v2.0/authorize?{query}"; + } + + public async Task ObtainTokenViaOAuth(string code, string redirectUrl) + { + if (string.IsNullOrEmpty(code) || string.IsNullOrWhiteSpace(redirectUrl)) + { + logger.LogError("Error obtaining token via OAuth: Code and/or RedirectUrl were empty"); + return string.Empty; + } + + var request = new HttpRequestMessage(HttpMethod.Post, + $"{_loginBaseUrl}/common/oauth2/v2.0/token"); + + request.Content = new FormUrlEncodedContent(new Dictionary + { + { "client_id", _clientId }, + { "client_secret", _clientSecret }, + { "code", code }, + { "redirect_uri", redirectUrl }, + { "grant_type", "authorization_code" } + }); + + using var response = await _httpClient.SendAsync(request); + if (!response.IsSuccessStatusCode) + { + var errorText = await response.Content.ReadAsStringAsync(); + logger.LogError("Teams OAuth token exchange failed: {errorText}", errorText); + return string.Empty; + } + + TeamsOAuthResponse? result; + try + { + result = await response.Content.ReadFromJsonAsync(); + } + catch + { + result = null; + } + + if (result is null) + { + logger.LogError("Error obtaining token via OAuth: Unknown error"); + return string.Empty; + } + + return result.AccessToken; + } + + public async Task> GetJoinedTeamsAsync(string accessToken) + { + using var request = new HttpRequestMessage( + HttpMethod.Get, + $"{_graphBaseUrl}/me/joinedTeams"); + request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", accessToken); + + using var response = await _httpClient.SendAsync(request); + if (!response.IsSuccessStatusCode) + { + var errorText = await response.Content.ReadAsStringAsync(); + logger.LogError("Get Teams request failed: {errorText}", errorText); + return new List(); + } + + var result = await response.Content.ReadFromJsonAsync(); + + return result?.Value ?? []; + } + + public async Task SendMessageToChannelAsync(Uri serviceUri, string channelId, string message) + { + var credentials = new MicrosoftAppCredentials(_clientId, _clientSecret); + using var connectorClient = new ConnectorClient(serviceUri, credentials); + + var activity = new Activity + { + Type = ActivityTypes.Message, + Text = message + }; + + await connectorClient.Conversations.SendToConversationAsync(channelId, activity); + } + + protected override async Task OnInstallationUpdateAddAsync(ITurnContext turnContext, + CancellationToken cancellationToken) + { + var conversationId = turnContext.Activity.Conversation.Id; + var serviceUrl = turnContext.Activity.ServiceUrl; + var teamId = turnContext.Activity.TeamsGetTeamInfo().AadGroupId; + var tenantId = turnContext.Activity.Conversation.TenantId; + + if (!string.IsNullOrWhiteSpace(conversationId) && + !string.IsNullOrWhiteSpace(serviceUrl) && + Uri.TryCreate(serviceUrl, UriKind.Absolute, out var parsedUri) && + !string.IsNullOrWhiteSpace(teamId) && + !string.IsNullOrWhiteSpace(tenantId)) + { + await HandleIncomingAppInstallAsync( + conversationId: conversationId, + serviceUrl: parsedUri, + teamId: teamId, + tenantId: tenantId + ); + } + + await base.OnInstallationUpdateAddAsync(turnContext, cancellationToken); + } + + internal async Task HandleIncomingAppInstallAsync( + string conversationId, + Uri serviceUrl, + string teamId, + string tenantId) + { + var integration = await integrationRepository.GetByTeamsConfigurationTenantIdTeamId( + tenantId: tenantId, + teamId: teamId); + + if (integration?.Configuration is null) + { + return; + } + + var teamsConfig = JsonSerializer.Deserialize(integration.Configuration); + if (teamsConfig is null || teamsConfig.IsCompleted) + { + return; + } + + integration.Configuration = JsonSerializer.Serialize(teamsConfig with + { + ChannelId = conversationId, + ServiceUrl = serviceUrl + }); + + await integrationRepository.UpsertAsync(integration); + } +} diff --git a/src/Core/AdminConsole/Services/Implementations/EventService.cs b/src/Core/AdminConsole/Services/Implementations/EventService.cs index e0e0e040f1..77d481890e 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventService.cs +++ b/src/Core/AdminConsole/Services/Implementations/EventService.cs @@ -6,6 +6,7 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Interfaces; using Bit.Core.AdminConsole.Models.Data.Provider; using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Auth.Identity; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -516,6 +517,135 @@ public class EventService : IEventService await _eventWriteService.CreateManyAsync(eventMessages); } + + public async Task LogServiceAccountPeopleEventAsync(Guid userId, UserServiceAccountAccessPolicy policy, EventType type, IdentityClientType identityClientType, DateTime? date = null) + { + var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + var eventMessages = new List(); + var orgUser = await _organizationUserRepository.GetByIdAsync((Guid)policy.OrganizationUserId); + + if (!CanUseEvents(orgAbilities, orgUser.OrganizationId)) + { + return; + } + + var (actingUserId, serviceAccountId) = MapIdentityClientType(userId, identityClientType); + + if (actingUserId is null && serviceAccountId is null) + { + return; + } + + if (policy.OrganizationUserId != null) + { + var e = new EventMessage(_currentContext) + { + OrganizationId = orgUser.OrganizationId, + Type = type, + GrantedServiceAccountId = policy.GrantedServiceAccountId, + ServiceAccountId = serviceAccountId, + UserId = policy.OrganizationUserId, + ActingUserId = actingUserId, + Date = date.GetValueOrDefault(DateTime.UtcNow) + }; + eventMessages.Add(e); + + await _eventWriteService.CreateManyAsync(eventMessages); + } + } + + public async Task LogServiceAccountGroupEventAsync(Guid userId, GroupServiceAccountAccessPolicy policy, EventType type, IdentityClientType identityClientType, DateTime? date = null) + { + var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + var eventMessages = new List(); + + if (!CanUseEvents(orgAbilities, policy.Group.OrganizationId)) + { + return; + } + + var (actingUserId, serviceAccountId) = MapIdentityClientType(userId, identityClientType); + + if (actingUserId is null && serviceAccountId is null) + { + return; + } + + if (policy.GroupId != null) + { + var e = new EventMessage(_currentContext) + { + OrganizationId = policy.Group.OrganizationId, + Type = type, + GrantedServiceAccountId = policy.GrantedServiceAccountId, + ServiceAccountId = serviceAccountId, + GroupId = policy.GroupId, + ActingUserId = actingUserId, + Date = date.GetValueOrDefault(DateTime.UtcNow) + }; + eventMessages.Add(e); + + await _eventWriteService.CreateManyAsync(eventMessages); + } + } + + public async Task LogServiceAccountEventAsync(Guid userId, List serviceAccounts, EventType type, IdentityClientType identityClientType, DateTime? date = null) + { + var orgAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + var eventMessages = new List(); + + foreach (var serviceAccount in serviceAccounts) + { + if (!CanUseEvents(orgAbilities, serviceAccount.OrganizationId)) + { + continue; + } + + var (actingUserId, serviceAccountId) = MapIdentityClientType(userId, identityClientType); + + if (actingUserId is null && serviceAccountId is null) + { + continue; + } + + if (serviceAccount != null) + { + var e = new EventMessage(_currentContext) + { + OrganizationId = serviceAccount.OrganizationId, + Type = type, + GrantedServiceAccountId = serviceAccount.Id, + ServiceAccountId = serviceAccountId, + ActingUserId = actingUserId, + Date = date.GetValueOrDefault(DateTime.UtcNow) + }; + eventMessages.Add(e); + } + } + + if (eventMessages.Any()) + { + await _eventWriteService.CreateManyAsync(eventMessages); + } + } + + private (Guid? actingUserId, Guid? serviceAccountId) MapIdentityClientType( + Guid userId, IdentityClientType identityClientType) + { + if (identityClientType == IdentityClientType.Organization) + { + return (null, null); + } + + return identityClientType switch + { + IdentityClientType.User => (userId, null), + IdentityClientType.ServiceAccount => (null, userId), + _ => throw new InvalidOperationException("Unknown identity client type.") + }; + } + + private async Task GetProviderIdAsync(Guid? orgId) { if (_currentContext == null || !orgId.HasValue) diff --git a/src/Core/AdminConsole/Services/NoopImplementations/NoopEventService.cs b/src/Core/AdminConsole/Services/NoopImplementations/NoopEventService.cs index e8dd495205..6ecea7d234 100644 --- a/src/Core/AdminConsole/Services/NoopImplementations/NoopEventService.cs +++ b/src/Core/AdminConsole/Services/NoopImplementations/NoopEventService.cs @@ -1,6 +1,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Interfaces; +using Bit.Core.Auth.Identity; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.SecretsManager.Entities; @@ -139,4 +140,19 @@ public class NoopEventService : IEventService { return Task.FromResult(0); } + + public Task LogServiceAccountPeopleEventAsync(Guid userId, UserServiceAccountAccessPolicy policy, EventType type, IdentityClientType identityClientType, DateTime? date = null) + { + return Task.FromResult(0); + } + + public Task LogServiceAccountGroupEventAsync(Guid userId, GroupServiceAccountAccessPolicy policy, EventType type, IdentityClientType identityClientType, DateTime? date = null) + { + return Task.FromResult(0); + } + + public Task LogServiceAccountEventAsync(Guid userId, List serviceAccount, EventType type, IdentityClientType identityClientType, DateTime? date = null) + { + return Task.FromResult(0); + } } diff --git a/src/Core/AdminConsole/Services/NoopImplementations/NoopProviderService.cs b/src/Core/AdminConsole/Services/NoopImplementations/NoopProviderService.cs index 2bf4a54a87..3782b30e3f 100644 --- a/src/Core/AdminConsole/Services/NoopImplementations/NoopProviderService.cs +++ b/src/Core/AdminConsole/Services/NoopImplementations/NoopProviderService.cs @@ -3,7 +3,7 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Models.Business.Provider; -using Bit.Core.Billing.Models; +using Bit.Core.Billing.Payment.Models; using Bit.Core.Entities; using Bit.Core.Models.Business; @@ -11,7 +11,7 @@ namespace Bit.Core.AdminConsole.Services.NoopImplementations; public class NoopProviderService : IProviderService { - public Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TaxInfo taxInfo, TokenizedPaymentSource tokenizedPaymentSource = null) => throw new NotImplementedException(); + public Task CompleteSetupAsync(Provider provider, Guid ownerUserId, string token, string key, TokenizedPaymentMethod paymentMethod, BillingAddress billingAddress) => throw new NotImplementedException(); public Task UpdateAsync(Provider provider, bool updateBilling = false) => throw new NotImplementedException(); diff --git a/src/Core/AdminConsole/Services/NoopImplementations/NoopSlackService.cs b/src/Core/AdminConsole/Services/NoopImplementations/NoopSlackService.cs index c34c073e87..a54df94814 100644 --- a/src/Core/AdminConsole/Services/NoopImplementations/NoopSlackService.cs +++ b/src/Core/AdminConsole/Services/NoopImplementations/NoopSlackService.cs @@ -1,4 +1,5 @@ -using Bit.Core.Services; +using Bit.Core.Models.Slack; +using Bit.Core.Services; namespace Bit.Core.AdminConsole.Services.NoopImplementations; @@ -19,14 +20,15 @@ public class NoopSlackService : ISlackService return Task.FromResult(string.Empty); } - public string GetRedirectUrl(string redirectUrl) + public string GetRedirectUrl(string callbackUrl, string state) { return string.Empty; } - public Task SendSlackMessageByChannelIdAsync(string token, string message, string channelId) + public Task SendSlackMessageByChannelIdAsync(string token, string message, + string channelId) { - return Task.FromResult(0); + return Task.FromResult(null); } public Task ObtainTokenViaOAuth(string code, string redirectUrl) diff --git a/src/Core/AdminConsole/Services/NoopImplementations/NoopTeamsService.cs b/src/Core/AdminConsole/Services/NoopImplementations/NoopTeamsService.cs new file mode 100644 index 0000000000..fafb23f570 --- /dev/null +++ b/src/Core/AdminConsole/Services/NoopImplementations/NoopTeamsService.cs @@ -0,0 +1,27 @@ +using Bit.Core.Models.Teams; +using Bit.Core.Services; + +namespace Bit.Core.AdminConsole.Services.NoopImplementations; + +public class NoopTeamsService : ITeamsService +{ + public string GetRedirectUrl(string callbackUrl, string state) + { + return string.Empty; + } + + public Task ObtainTokenViaOAuth(string code, string redirectUrl) + { + return Task.FromResult(string.Empty); + } + + public Task> GetJoinedTeamsAsync(string accessToken) + { + return Task.FromResult>(Array.Empty()); + } + + public Task SendMessageToChannelAsync(Uri serviceUri, string channelId, string message) + { + return Task.CompletedTask; + } +} diff --git a/src/Core/AdminConsole/Services/OrganizationFactory.cs b/src/Core/AdminConsole/Services/OrganizationFactory.cs index afb3931ec4..f5df3327b1 100644 --- a/src/Core/AdminConsole/Services/OrganizationFactory.cs +++ b/src/Core/AdminConsole/Services/OrganizationFactory.cs @@ -61,6 +61,7 @@ public static class OrganizationFactory claimsPrincipal.GetValue(OrganizationLicenseConstants.UseOrganizationDomains), UseAdminSponsoredFamilies = claimsPrincipal.GetValue(OrganizationLicenseConstants.UseAdminSponsoredFamilies), + UseAutomaticUserConfirmation = claimsPrincipal.GetValue(OrganizationLicenseConstants.UseAutomaticUserConfirmation), }; public static Organization Create( @@ -110,5 +111,6 @@ public static class OrganizationFactory UseRiskInsights = license.UseRiskInsights, UseOrganizationDomains = license.UseOrganizationDomains, UseAdminSponsoredFamilies = license.UseAdminSponsoredFamilies, + UseAutomaticUserConfirmation = license.UseAutomaticUserConfirmation }; } diff --git a/src/Core/AdminConsole/Utilities/PolicyDataValidator.cs b/src/Core/AdminConsole/Utilities/PolicyDataValidator.cs new file mode 100644 index 0000000000..84e63f2a20 --- /dev/null +++ b/src/Core/AdminConsole/Utilities/PolicyDataValidator.cs @@ -0,0 +1,81 @@ +using System.Text.Json; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.Exceptions; +using Bit.Core.Utilities; + +namespace Bit.Core.AdminConsole.Utilities; + +public static class PolicyDataValidator +{ + /// + /// Validates and serializes policy data based on the policy type. + /// + /// The policy data to validate + /// The type of policy + /// Serialized JSON string if data is valid, null if data is null or empty + /// Thrown when data validation fails + public static string? ValidateAndSerialize(Dictionary? data, PolicyType policyType) + { + if (data == null || data.Count == 0) + { + return null; + } + + try + { + var json = JsonSerializer.Serialize(data); + + switch (policyType) + { + case PolicyType.MasterPassword: + CoreHelpers.LoadClassFromJsonData(json); + break; + case PolicyType.SendOptions: + CoreHelpers.LoadClassFromJsonData(json); + break; + case PolicyType.ResetPassword: + CoreHelpers.LoadClassFromJsonData(json); + break; + } + + return json; + } + catch (JsonException ex) + { + var fieldInfo = !string.IsNullOrEmpty(ex.Path) ? $": field '{ex.Path}' has invalid type" : ""; + throw new BadRequestException($"Invalid data for {policyType} policy{fieldInfo}."); + } + } + + /// + /// Validates and deserializes policy metadata based on the policy type. + /// + /// The policy metadata to validate + /// The type of policy + /// Deserialized metadata model, or EmptyMetadataModel if metadata is null, empty, or validation fails + public static IPolicyMetadataModel ValidateAndDeserializeMetadata(Dictionary? metadata, PolicyType policyType) + { + if (metadata == null || metadata.Count == 0) + { + return new EmptyMetadataModel(); + } + + try + { + var json = JsonSerializer.Serialize(metadata); + + return policyType switch + { + PolicyType.OrganizationDataOwnership => + CoreHelpers.LoadClassFromJsonData(json), + _ => new EmptyMetadataModel() + }; + } + catch (JsonException) + { + return new EmptyMetadataModel(); + } + } +} diff --git a/src/Core/Auth/Entities/WebAuthnCredential.cs b/src/Core/Auth/Entities/WebAuthnCredential.cs index ecc763088d..595ecfc041 100644 --- a/src/Core/Auth/Entities/WebAuthnCredential.cs +++ b/src/Core/Auth/Entities/WebAuthnCredential.cs @@ -22,13 +22,30 @@ public class WebAuthnCredential : ITableObject [MaxLength(20)] public string Type { get; set; } public Guid AaGuid { get; set; } + + /// + /// User key encrypted with this WebAuthn credential's public key (EncryptedPublicKey field). + /// [MaxLength(2000)] public string EncryptedUserKey { get; set; } + + /// + /// Private key encrypted with an external key for secure storage. + /// [MaxLength(2000)] public string EncryptedPrivateKey { get; set; } + + /// + /// Public key encrypted with the user key for key rotation. + /// [MaxLength(2000)] public string EncryptedPublicKey { get; set; } + + /// + /// Indicates whether this credential supports PRF (Pseudo-Random Function) extension. + /// public bool SupportsPrf { get; set; } + public DateTime CreationDate { get; internal set; } = DateTime.UtcNow; public DateTime RevisionDate { get; internal set; } = DateTime.UtcNow; diff --git a/src/Core/Auth/Identity/TokenProviders/EmailTokenProvider.cs b/src/Core/Auth/Identity/TokenProviders/EmailTokenProvider.cs index 70aba8ef75..f6ef3a5dd0 100644 --- a/src/Core/Auth/Identity/TokenProviders/EmailTokenProvider.cs +++ b/src/Core/Auth/Identity/TokenProviders/EmailTokenProvider.cs @@ -65,7 +65,7 @@ public class EmailTokenProvider : IUserTwoFactorTokenProvider } var code = Encoding.UTF8.GetString(cachedValue); - var valid = string.Equals(token, code); + var valid = CoreHelpers.FixedTimeEquals(token, code); if (valid) { await _distributedCache.RemoveAsync(cacheKey); diff --git a/src/Core/Auth/Identity/TokenProviders/OtpTokenProvider/OtpTokenProvider.cs b/src/Core/Auth/Identity/TokenProviders/OtpTokenProvider/OtpTokenProvider.cs index b6280e13fe..ae394f817e 100644 --- a/src/Core/Auth/Identity/TokenProviders/OtpTokenProvider/OtpTokenProvider.cs +++ b/src/Core/Auth/Identity/TokenProviders/OtpTokenProvider/OtpTokenProvider.cs @@ -64,7 +64,7 @@ public class OtpTokenProvider( } var code = Encoding.UTF8.GetString(cachedValue); - var valid = string.Equals(token, code); + var valid = CoreHelpers.FixedTimeEquals(token, code); if (valid) { await _distributedCache.RemoveAsync(cacheKey); diff --git a/src/Core/Auth/Models/Api/Response/UserDecryptionOptions.cs b/src/Core/Auth/Models/Api/Response/UserDecryptionOptions.cs index bd8542e8bf..aa8a298200 100644 --- a/src/Core/Auth/Models/Api/Response/UserDecryptionOptions.cs +++ b/src/Core/Auth/Models/Api/Response/UserDecryptionOptions.cs @@ -1,5 +1,5 @@ using System.Text.Json.Serialization; -using Bit.Core.KeyManagement.Models.Response; +using Bit.Core.KeyManagement.Models.Api.Response; using Bit.Core.Models.Api; namespace Bit.Core.Auth.Models.Api.Response; diff --git a/src/Core/Auth/Services/Implementations/SsoConfigService.cs b/src/Core/Auth/Services/Implementations/SsoConfigService.cs index fe8d9bdd6e..1a35585b2c 100644 --- a/src/Core/Auth/Services/Implementations/SsoConfigService.cs +++ b/src/Core/Auth/Services/Implementations/SsoConfigService.cs @@ -3,9 +3,11 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.Entities; using Bit.Core.Auth.Enums; @@ -24,7 +26,9 @@ public class SsoConfigService : ISsoConfigService private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IEventService _eventService; + private readonly IFeatureService _featureService; private readonly ISavePolicyCommand _savePolicyCommand; + private readonly IVNextSavePolicyCommand _vNextSavePolicyCommand; public SsoConfigService( ISsoConfigRepository ssoConfigRepository, @@ -32,14 +36,18 @@ public class SsoConfigService : ISsoConfigService IOrganizationRepository organizationRepository, IOrganizationUserRepository organizationUserRepository, IEventService eventService, - ISavePolicyCommand savePolicyCommand) + IFeatureService featureService, + ISavePolicyCommand savePolicyCommand, + IVNextSavePolicyCommand vNextSavePolicyCommand) { _ssoConfigRepository = ssoConfigRepository; _policyRepository = policyRepository; _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; _eventService = eventService; + _featureService = featureService; _savePolicyCommand = savePolicyCommand; + _vNextSavePolicyCommand = vNextSavePolicyCommand; } public async Task SaveAsync(SsoConfig config, Organization organization) @@ -67,13 +75,12 @@ public class SsoConfigService : ISsoConfigService // Automatically enable account recovery, SSO required, and single org policies if trusted device encryption is selected if (config.GetData().MemberDecryptionType == MemberDecryptionType.TrustedDeviceEncryption) { - - await _savePolicyCommand.SaveAsync(new() + var singleOrgPolicy = new PolicyUpdate { OrganizationId = config.OrganizationId, Type = PolicyType.SingleOrg, Enabled = true - }); + }; var resetPasswordPolicy = new PolicyUpdate { @@ -82,14 +89,27 @@ public class SsoConfigService : ISsoConfigService Enabled = true, }; resetPasswordPolicy.SetDataModel(new ResetPasswordDataModel { AutoEnrollEnabled = true }); - await _savePolicyCommand.SaveAsync(resetPasswordPolicy); - await _savePolicyCommand.SaveAsync(new() + var requireSsoPolicy = new PolicyUpdate { OrganizationId = config.OrganizationId, Type = PolicyType.RequireSso, Enabled = true - }); + }; + + if (_featureService.IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor)) + { + var performedBy = new SystemUser(EventSystemUser.Unknown); + await _vNextSavePolicyCommand.SaveAsync(new SavePolicyModel(singleOrgPolicy, performedBy)); + await _vNextSavePolicyCommand.SaveAsync(new SavePolicyModel(resetPasswordPolicy, performedBy)); + await _vNextSavePolicyCommand.SaveAsync(new SavePolicyModel(requireSsoPolicy, performedBy)); + } + else + { + await _savePolicyCommand.SaveAsync(singleOrgPolicy); + await _savePolicyCommand.SaveAsync(resetPasswordPolicy); + await _savePolicyCommand.SaveAsync(requireSsoPolicy); + } } await LogEventsAsync(config, oldConfig); diff --git a/src/Core/Billing/Commands/BillingCommandResult.cs b/src/Core/Billing/Commands/BillingCommandResult.cs index 3238ab4107..db260e7038 100644 --- a/src/Core/Billing/Commands/BillingCommandResult.cs +++ b/src/Core/Billing/Commands/BillingCommandResult.cs @@ -1,5 +1,4 @@ -#nullable enable -using OneOf; +using OneOf; namespace Bit.Core.Billing.Commands; @@ -20,18 +19,38 @@ public record Unhandled(Exception? Exception = null, string Response = "Somethin /// ///
/// The successful result type of the operation. -public class BillingCommandResult : OneOfBase +public class BillingCommandResult(OneOf input) + : OneOfBase(input) { - private BillingCommandResult(OneOf input) : base(input) { } - public static implicit operator BillingCommandResult(T output) => new(output); public static implicit operator BillingCommandResult(BadRequest badRequest) => new(badRequest); public static implicit operator BillingCommandResult(Conflict conflict) => new(conflict); public static implicit operator BillingCommandResult(Unhandled unhandled) => new(unhandled); + public BillingCommandResult Map(Func f) + => Match( + value => new BillingCommandResult(f(value)), + badRequest => new BillingCommandResult(badRequest), + conflict => new BillingCommandResult(conflict), + unhandled => new BillingCommandResult(unhandled)); + public Task TapAsync(Func f) => Match( f, _ => Task.CompletedTask, _ => Task.CompletedTask, _ => Task.CompletedTask); } + +public static class BillingCommandResultExtensions +{ + public static async Task> AndThenAsync( + this Task> task, Func>> binder) + { + var result = await task; + return await result.Match( + binder, + badRequest => Task.FromResult(new BillingCommandResult(badRequest)), + conflict => Task.FromResult(new BillingCommandResult(conflict)), + unhandled => Task.FromResult(new BillingCommandResult(unhandled))); + } +} diff --git a/src/Core/Billing/Constants/BitPayConstants.cs b/src/Core/Billing/Constants/BitPayConstants.cs new file mode 100644 index 0000000000..a1b2ff6f5b --- /dev/null +++ b/src/Core/Billing/Constants/BitPayConstants.cs @@ -0,0 +1,14 @@ +namespace Bit.Core.Billing.Constants; + +public static class BitPayConstants +{ + public static class InvoiceStatuses + { + public const string Complete = "complete"; + } + + public static class PosDataKeys + { + public const string AccountCredit = "accountCredit:1"; + } +} diff --git a/src/Core/Billing/Constants/StripeConstants.cs b/src/Core/Billing/Constants/StripeConstants.cs index 131adfedf8..11f043fc69 100644 --- a/src/Core/Billing/Constants/StripeConstants.cs +++ b/src/Core/Billing/Constants/StripeConstants.cs @@ -22,6 +22,8 @@ public static class StripeConstants { public const string LegacyMSPDiscount = "msp-discount-35"; public const string SecretsManagerStandalone = "sm-standalone"; + public const string Milestone2SubscriptionDiscount = "milestone-2c"; + public const string Milestone3SubscriptionDiscount = "milestone-3"; public static class MSPDiscounts { diff --git a/src/Core/Billing/Enums/PlanCadenceType.cs b/src/Core/Billing/Enums/PlanCadenceType.cs new file mode 100644 index 0000000000..9e6fa69832 --- /dev/null +++ b/src/Core/Billing/Enums/PlanCadenceType.cs @@ -0,0 +1,7 @@ +namespace Bit.Core.Billing.Enums; + +public enum PlanCadenceType +{ + Annually, + Monthly +} diff --git a/src/Core/Billing/Enums/PlanType.cs b/src/Core/Billing/Enums/PlanType.cs index e88a73af16..0f910c4980 100644 --- a/src/Core/Billing/Enums/PlanType.cs +++ b/src/Core/Billing/Enums/PlanType.cs @@ -18,8 +18,8 @@ public enum PlanType : byte EnterpriseAnnually2019 = 5, [Display(Name = "Custom")] Custom = 6, - [Display(Name = "Families")] - FamiliesAnnually = 7, + [Display(Name = "Families 2025")] + FamiliesAnnually2025 = 7, [Display(Name = "Teams (Monthly) 2020")] TeamsMonthly2020 = 8, [Display(Name = "Teams (Annually) 2020")] @@ -48,4 +48,6 @@ public enum PlanType : byte EnterpriseAnnually = 20, [Display(Name = "Teams Starter")] TeamsStarter = 21, + [Display(Name = "Families")] + FamiliesAnnually = 22, } diff --git a/src/Core/Billing/Extensions/BillingExtensions.cs b/src/Core/Billing/Extensions/BillingExtensions.cs index 7f81bfd33f..2dae0c2025 100644 --- a/src/Core/Billing/Extensions/BillingExtensions.cs +++ b/src/Core/Billing/Extensions/BillingExtensions.cs @@ -15,7 +15,7 @@ public static class BillingExtensions => planType switch { PlanType.Custom or PlanType.Free => ProductTierType.Free, - PlanType.FamiliesAnnually or PlanType.FamiliesAnnually2019 => ProductTierType.Families, + PlanType.FamiliesAnnually or PlanType.FamiliesAnnually2025 or PlanType.FamiliesAnnually2019 => ProductTierType.Families, PlanType.TeamsStarter or PlanType.TeamsStarter2023 => ProductTierType.TeamsStarter, _ when planType.ToString().Contains("Teams") => ProductTierType.Teams, _ when planType.ToString().Contains("Enterprise") => ProductTierType.Enterprise, diff --git a/src/Core/Billing/Extensions/InvoiceExtensions.cs b/src/Core/Billing/Extensions/InvoiceExtensions.cs index bb9f7588bf..d62959c09a 100644 --- a/src/Core/Billing/Extensions/InvoiceExtensions.cs +++ b/src/Core/Billing/Extensions/InvoiceExtensions.cs @@ -64,10 +64,12 @@ public static class InvoiceExtensions } } + var tax = invoice.TotalTaxes?.Sum(invoiceTotalTax => invoiceTotalTax.Amount) ?? 0; + // Add fallback tax from invoice-level tax if present and not already included - if (invoice.Tax.HasValue && invoice.Tax.Value > 0) + if (tax > 0) { - var taxAmount = invoice.Tax.Value / 100m; + var taxAmount = tax / 100m; items.Add($"1 × Tax (at ${taxAmount:F2} / month)"); } diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs index b4e37f0151..d6593f5365 100644 --- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs +++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs @@ -9,7 +9,7 @@ using Bit.Core.Billing.Premium.Commands; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Implementations; -using Bit.Core.Billing.Tax.Commands; +using Bit.Core.Billing.Subscriptions.Commands; using Bit.Core.Billing.Tax.Services; using Bit.Core.Billing.Tax.Services.Implementations; @@ -28,11 +28,13 @@ public static class ServiceCollectionExtensions services.AddTransient(); services.AddLicenseServices(); services.AddPricingClient(); - services.AddTransient(); services.AddPaymentOperations(); services.AddOrganizationLicenseCommandsQueries(); services.AddPremiumCommands(); + services.AddTransient(); services.AddTransient(); + services.AddTransient(); + services.AddTransient(); } private static void AddOrganizationLicenseCommandsQueries(this IServiceCollection services) @@ -46,5 +48,6 @@ public static class ServiceCollectionExtensions { services.AddScoped(); services.AddScoped(); + services.AddTransient(); } } diff --git a/src/Core/Billing/Extensions/SubscriptionExtensions.cs b/src/Core/Billing/Extensions/SubscriptionExtensions.cs new file mode 100644 index 0000000000..383bd32d53 --- /dev/null +++ b/src/Core/Billing/Extensions/SubscriptionExtensions.cs @@ -0,0 +1,25 @@ +using Stripe; + +namespace Bit.Core.Billing.Extensions; + +public static class SubscriptionExtensions +{ + /* + * For the time being, this is the simplest migration approach from v45 to v48 as + * we do not support multi-cadence subscriptions. Each subscription item should be on the + * same billing cycle. If this changes, we'll need a significantly more robust approach. + * + * Because we can't guarantee a subscription will have items, this has to be nullable. + */ + public static (DateTime? Start, DateTime? End)? GetCurrentPeriod(this Subscription subscription) + { + var item = subscription.Items?.FirstOrDefault(); + return item is null ? null : (item.CurrentPeriodStart, item.CurrentPeriodEnd); + } + + public static DateTime? GetCurrentPeriodStart(this Subscription subscription) => + subscription.Items?.FirstOrDefault()?.CurrentPeriodStart; + + public static DateTime? GetCurrentPeriodEnd(this Subscription subscription) => + subscription.Items?.FirstOrDefault()?.CurrentPeriodEnd; +} diff --git a/src/Core/Billing/Extensions/UpcomingInvoiceOptionsExtensions.cs b/src/Core/Billing/Extensions/UpcomingInvoiceOptionsExtensions.cs deleted file mode 100644 index d00b5b46a4..0000000000 --- a/src/Core/Billing/Extensions/UpcomingInvoiceOptionsExtensions.cs +++ /dev/null @@ -1,35 +0,0 @@ -using Stripe; - -namespace Bit.Core.Billing.Extensions; - -public static class UpcomingInvoiceOptionsExtensions -{ - /// - /// Attempts to enable automatic tax for given upcoming invoice options. - /// - /// - /// The existing customer to which the upcoming invoice belongs. - /// The existing subscription to which the upcoming invoice belongs. - /// Returns true when successful, false when conditions are not met. - public static bool EnableAutomaticTax( - this UpcomingInvoiceOptions options, - Customer customer, - Subscription subscription) - { - if (subscription != null && subscription.AutomaticTax.Enabled) - { - return false; - } - - // We might only need to check the automatic tax status. - if (!customer.HasRecognizedTaxLocation() && string.IsNullOrWhiteSpace(customer.Address?.Country)) - { - return false; - } - - options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true }; - options.SubscriptionDefaultTaxRates = []; - - return true; - } -} diff --git a/src/Core/Billing/Licenses/LicenseConstants.cs b/src/Core/Billing/Licenses/LicenseConstants.cs index cdfac76614..79ac94be62 100644 --- a/src/Core/Billing/Licenses/LicenseConstants.cs +++ b/src/Core/Billing/Licenses/LicenseConstants.cs @@ -43,6 +43,7 @@ public static class OrganizationLicenseConstants public const string Trial = nameof(Trial); public const string UseAdminSponsoredFamilies = nameof(UseAdminSponsoredFamilies); public const string UseOrganizationDomains = nameof(UseOrganizationDomains); + public const string UseAutomaticUserConfirmation = nameof(UseAutomaticUserConfirmation); } public static class UserLicenseConstants diff --git a/src/Core/Billing/Licenses/Services/Implementations/OrganizationLicenseClaimsFactory.cs b/src/Core/Billing/Licenses/Services/Implementations/OrganizationLicenseClaimsFactory.cs index 1e049d7f03..e9aadbe758 100644 --- a/src/Core/Billing/Licenses/Services/Implementations/OrganizationLicenseClaimsFactory.cs +++ b/src/Core/Billing/Licenses/Services/Implementations/OrganizationLicenseClaimsFactory.cs @@ -56,6 +56,7 @@ public class OrganizationLicenseClaimsFactory : ILicenseClaimsFactory SecretsManager != null; + public bool AutomaticUserConfirmation { get; init; } + public bool HasNonSeatBasedPasswordManagerPlan() => PasswordManager is { StripePlanId: not null and not "", StripeSeatPlanId: null or "" }; diff --git a/src/Core/Billing/Models/StaticStore/Plans/Families2025Plan.cs b/src/Core/Billing/Models/StaticStore/Plans/Families2025Plan.cs new file mode 100644 index 0000000000..77e238e98e --- /dev/null +++ b/src/Core/Billing/Models/StaticStore/Plans/Families2025Plan.cs @@ -0,0 +1,47 @@ +using Bit.Core.Billing.Enums; +using Bit.Core.Models.StaticStore; + +namespace Bit.Core.Billing.Models.StaticStore.Plans; + +public record Families2025Plan : Plan +{ + public Families2025Plan() + { + Type = PlanType.FamiliesAnnually2025; + ProductTier = ProductTierType.Families; + Name = "Families 2025"; + IsAnnual = true; + NameLocalizationKey = "planNameFamilies"; + DescriptionLocalizationKey = "planDescFamilies"; + + TrialPeriodDays = 7; + + HasSelfHost = true; + HasTotp = true; + UsersGetPremium = true; + + UpgradeSortOrder = 1; + DisplaySortOrder = 1; + + PasswordManager = new Families2025PasswordManagerFeatures(); + } + + private record Families2025PasswordManagerFeatures : PasswordManagerPlanFeatures + { + public Families2025PasswordManagerFeatures() + { + BaseSeats = 6; + BaseStorageGb = 1; + MaxSeats = 6; + + HasAdditionalStorageOption = true; + + StripePlanId = "2020-families-org-annually"; + StripeStoragePlanId = "personal-storage-gb-annually"; + BasePrice = 40; + AdditionalStoragePricePerGb = 4; + + AllowSeatAutoscale = false; + } + } +} diff --git a/src/Core/Billing/Models/StaticStore/Plans/FamiliesPlan.cs b/src/Core/Billing/Models/StaticStore/Plans/FamiliesPlan.cs index 8c71e50fa4..b2edc1168b 100644 --- a/src/Core/Billing/Models/StaticStore/Plans/FamiliesPlan.cs +++ b/src/Core/Billing/Models/StaticStore/Plans/FamiliesPlan.cs @@ -23,12 +23,12 @@ public record FamiliesPlan : Plan UpgradeSortOrder = 1; DisplaySortOrder = 1; - PasswordManager = new TeamsPasswordManagerFeatures(); + PasswordManager = new FamiliesPasswordManagerFeatures(); } - private record TeamsPasswordManagerFeatures : PasswordManagerPlanFeatures + private record FamiliesPasswordManagerFeatures : PasswordManagerPlanFeatures { - public TeamsPasswordManagerFeatures() + public FamiliesPasswordManagerFeatures() { BaseSeats = 6; BaseStorageGb = 1; diff --git a/src/Core/Billing/Organizations/Commands/PreviewOrganizationTaxCommand.cs b/src/Core/Billing/Organizations/Commands/PreviewOrganizationTaxCommand.cs new file mode 100644 index 0000000000..89d301c22a --- /dev/null +++ b/src/Core/Billing/Organizations/Commands/PreviewOrganizationTaxCommand.cs @@ -0,0 +1,402 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Commands; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Organizations.Models; +using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Pricing; +using Bit.Core.Enums; +using Bit.Core.Services; +using Bit.Core.Utilities; +using Microsoft.Extensions.Logging; +using OneOf; +using Stripe; + +namespace Bit.Core.Billing.Organizations.Commands; + +using static Core.Constants; +using static StripeConstants; + +public interface IPreviewOrganizationTaxCommand +{ + Task> Run( + OrganizationSubscriptionPurchase purchase, + BillingAddress billingAddress); + + Task> Run( + Organization organization, + OrganizationSubscriptionPlanChange planChange, + BillingAddress billingAddress); + + Task> Run( + Organization organization, + OrganizationSubscriptionUpdate update); +} + +public class PreviewOrganizationTaxCommand( + ILogger logger, + IPricingClient pricingClient, + IStripeAdapter stripeAdapter) + : BaseBillingCommand(logger), IPreviewOrganizationTaxCommand +{ + public Task> Run( + OrganizationSubscriptionPurchase purchase, + BillingAddress billingAddress) + => HandleAsync<(decimal, decimal)>(async () => + { + var plan = await pricingClient.GetPlanOrThrow(purchase.PlanType); + + var options = GetBaseOptions(billingAddress, purchase.Tier != ProductTierType.Families); + + var items = new List(); + + switch (purchase) + { + case { PasswordManager.Sponsored: true }: + var sponsoredPlan = StaticStore.GetSponsoredPlan(PlanSponsorshipType.FamiliesForEnterprise); + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = sponsoredPlan.StripePlanId, + Quantity = 1 + }); + break; + + case { SecretsManager.Standalone: true }: + items.AddRange([ + new InvoiceSubscriptionDetailsItemOptions + { + Price = plan.PasswordManager.StripeSeatPlanId, + Quantity = purchase.PasswordManager.Seats + }, + new InvoiceSubscriptionDetailsItemOptions + { + Price = plan.SecretsManager.StripeSeatPlanId, + Quantity = purchase.SecretsManager.Seats + } + ]); + options.Discounts = + [ + new InvoiceDiscountOptions + { + Coupon = CouponIDs.SecretsManagerStandalone + } + ]; + break; + + default: + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = plan.HasNonSeatBasedPasswordManagerPlan() + ? plan.PasswordManager.StripePlanId + : plan.PasswordManager.StripeSeatPlanId, + Quantity = purchase.PasswordManager.Seats + }); + + if (purchase.PasswordManager.AdditionalStorage > 0) + { + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = plan.PasswordManager.StripeStoragePlanId, + Quantity = purchase.PasswordManager.AdditionalStorage + }); + } + + if (purchase.SecretsManager is { Seats: > 0 }) + { + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = plan.SecretsManager.StripeSeatPlanId, + Quantity = purchase.SecretsManager.Seats + }); + + if (purchase.SecretsManager.AdditionalServiceAccounts > 0) + { + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = plan.SecretsManager.StripeServiceAccountPlanId, + Quantity = purchase.SecretsManager.AdditionalServiceAccounts + }); + } + } + + break; + } + + options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items }; + + var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + return GetAmounts(invoice); + }); + + public Task> Run( + Organization organization, + OrganizationSubscriptionPlanChange planChange, + BillingAddress billingAddress) + => HandleAsync<(decimal, decimal)>(async () => + { + if (organization.PlanType.GetProductTier() == ProductTierType.Free) + { + var options = GetBaseOptions(billingAddress, planChange.Tier != ProductTierType.Families); + + var newPlan = await pricingClient.GetPlanOrThrow(planChange.PlanType); + + var quantity = newPlan.HasNonSeatBasedPasswordManagerPlan() ? 1 : 2; + + var items = new List + { + new () + { + Price = newPlan.HasNonSeatBasedPasswordManagerPlan() + ? newPlan.PasswordManager.StripePlanId + : newPlan.PasswordManager.StripeSeatPlanId, + Quantity = quantity + } + }; + + if (organization.UseSecretsManager && planChange.Tier != ProductTierType.Families) + { + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = newPlan.SecretsManager.StripeSeatPlanId, + Quantity = 2 + }); + } + + options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items }; + + var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + return GetAmounts(invoice); + } + else + { + if (organization is not + { + GatewayCustomerId: not null, + GatewaySubscriptionId: not null + }) + { + return new BadRequest("Organization does not have a subscription."); + } + + var options = GetBaseOptions(billingAddress, planChange.Tier != ProductTierType.Families); + + var subscription = await stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId, + new SubscriptionGetOptions { Expand = ["customer"] }); + + if (subscription.Customer.Discount != null) + { + options.Discounts = + [ + new InvoiceDiscountOptions { Coupon = subscription.Customer.Discount.Coupon.Id } + ]; + } + + var currentPlan = await pricingClient.GetPlanOrThrow(organization.PlanType); + var newPlan = await pricingClient.GetPlanOrThrow(planChange.PlanType); + + var subscriptionItemsByPriceId = + subscription.Items.ToDictionary(subscriptionItem => subscriptionItem.Price.Id); + + var items = new List(); + + var passwordManagerSeats = subscriptionItemsByPriceId[ + currentPlan.HasNonSeatBasedPasswordManagerPlan() + ? currentPlan.PasswordManager.StripePlanId + : currentPlan.PasswordManager.StripeSeatPlanId]; + + var quantity = currentPlan.HasNonSeatBasedPasswordManagerPlan() && + !newPlan.HasNonSeatBasedPasswordManagerPlan() + ? (long)organization.Seats! + : passwordManagerSeats.Quantity; + + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = newPlan.HasNonSeatBasedPasswordManagerPlan() + ? newPlan.PasswordManager.StripePlanId + : newPlan.PasswordManager.StripeSeatPlanId, + Quantity = quantity + }); + + var hasStorage = + subscriptionItemsByPriceId.TryGetValue(newPlan.PasswordManager.StripeStoragePlanId, + out var storage); + + if (hasStorage && storage != null) + { + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = newPlan.PasswordManager.StripeStoragePlanId, + Quantity = storage.Quantity + }); + } + + var hasSecretsManagerSeats = subscriptionItemsByPriceId.TryGetValue( + newPlan.SecretsManager.StripeSeatPlanId, + out var secretsManagerSeats); + + if (hasSecretsManagerSeats && secretsManagerSeats != null) + { + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = newPlan.SecretsManager.StripeSeatPlanId, + Quantity = secretsManagerSeats.Quantity + }); + + var hasServiceAccounts = + subscriptionItemsByPriceId.TryGetValue(newPlan.SecretsManager.StripeServiceAccountPlanId, + out var serviceAccounts); + + if (hasServiceAccounts && serviceAccounts != null) + { + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = newPlan.SecretsManager.StripeServiceAccountPlanId, + Quantity = serviceAccounts.Quantity + }); + } + } + + options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items }; + + var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + return GetAmounts(invoice); + } + }); + + public Task> Run( + Organization organization, + OrganizationSubscriptionUpdate update) + => HandleAsync<(decimal, decimal)>(async () => + { + if (organization is not + { + GatewayCustomerId: not null, + GatewaySubscriptionId: not null + }) + { + return new BadRequest("Organization does not have a subscription."); + } + + var subscription = await stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId, + new SubscriptionGetOptions { Expand = ["customer.tax_ids"] }); + + var options = GetBaseOptions(subscription.Customer, + organization.GetProductUsageType() == ProductUsageType.Business); + + if (subscription.Customer.Discount != null) + { + options.Discounts = + [ + new InvoiceDiscountOptions { Coupon = subscription.Customer.Discount.Coupon.Id } + ]; + } + + var currentPlan = await pricingClient.GetPlanOrThrow(organization.PlanType); + + var items = new List(); + + if (update.PasswordManager?.Seats != null) + { + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = currentPlan.HasNonSeatBasedPasswordManagerPlan() + ? currentPlan.PasswordManager.StripePlanId + : currentPlan.PasswordManager.StripeSeatPlanId, + Quantity = update.PasswordManager.Seats + }); + } + + if (update.PasswordManager?.AdditionalStorage is > 0) + { + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = currentPlan.PasswordManager.StripeStoragePlanId, + Quantity = update.PasswordManager.AdditionalStorage + }); + } + + if (update.SecretsManager?.Seats is > 0) + { + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = currentPlan.SecretsManager.StripeSeatPlanId, + Quantity = update.SecretsManager.Seats + }); + + if (update.SecretsManager.AdditionalServiceAccounts is > 0) + { + items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = currentPlan.SecretsManager.StripeServiceAccountPlanId, + Quantity = update.SecretsManager.AdditionalServiceAccounts + }); + } + } + + options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = items }; + + var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + return GetAmounts(invoice); + }); + + private static (decimal, decimal) GetAmounts(Invoice invoice) => ( + Convert.ToDecimal(invoice.TotalTaxes.Sum(invoiceTotalTax => invoiceTotalTax.Amount)) / 100, + Convert.ToDecimal(invoice.Total) / 100); + + private static InvoiceCreatePreviewOptions GetBaseOptions( + OneOf addressChoice, + bool businessUse) + { + var country = addressChoice.Match( + customer => customer.Address.Country, + billingAddress => billingAddress.Country + ); + + var postalCode = addressChoice.Match( + customer => customer.Address.PostalCode, + billingAddress => billingAddress.PostalCode); + + var options = new InvoiceCreatePreviewOptions + { + AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true }, + Currency = "usd", + CustomerDetails = new InvoiceCustomerDetailsOptions + { + Address = new AddressOptions { Country = country, PostalCode = postalCode }, + TaxExempt = businessUse && country != CountryAbbreviations.UnitedStates + ? TaxExempt.Reverse + : TaxExempt.None + } + }; + + var taxId = addressChoice.Match( + customer => + { + var taxId = customer.TaxIds?.FirstOrDefault(); + return taxId != null ? new TaxID(taxId.Type, taxId.Value) : null; + }, + billingAddress => billingAddress.TaxId); + + if (taxId == null) + { + return options; + } + + options.CustomerDetails.TaxIds = + [ + new InvoiceCustomerDetailsTaxIdOptions { Type = taxId.Code, Value = taxId.Value } + ]; + + if (taxId.Code == TaxIdType.SpanishNIF) + { + options.CustomerDetails.TaxIds.Add(new InvoiceCustomerDetailsTaxIdOptions + { + Type = TaxIdType.EUVAT, + Value = $"ES{taxId.Value}" + }); + } + + return options; + } +} diff --git a/src/Core/Billing/Organizations/Commands/UpdateOrganizationLicenseCommand.cs b/src/Core/Billing/Organizations/Commands/UpdateOrganizationLicenseCommand.cs index fde95f2e70..1dfd786210 100644 --- a/src/Core/Billing/Organizations/Commands/UpdateOrganizationLicenseCommand.cs +++ b/src/Core/Billing/Organizations/Commands/UpdateOrganizationLicenseCommand.cs @@ -1,5 +1,7 @@ using System.Text.Json; using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Licenses; +using Bit.Core.Billing.Licenses.Extensions; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Services; using Bit.Core.Exceptions; @@ -52,6 +54,12 @@ public class UpdateOrganizationLicenseCommand : IUpdateOrganizationLicenseComman throw new BadRequestException(exception); } + var useAutomaticUserConfirmation = claimsPrincipal? + .GetValue(OrganizationLicenseConstants.UseAutomaticUserConfirmation) ?? false; + + selfHostedOrganization.UseAutomaticUserConfirmation = useAutomaticUserConfirmation; + license.UseAutomaticUserConfirmation = useAutomaticUserConfirmation; + await WriteLicenseFileAsync(selfHostedOrganization, license); await UpdateOrganizationAsync(selfHostedOrganization, license); } diff --git a/src/Core/Billing/Organizations/Models/OrganizationLicense.cs b/src/Core/Billing/Organizations/Models/OrganizationLicense.cs index 83789be2f3..7ccbacc938 100644 --- a/src/Core/Billing/Organizations/Models/OrganizationLicense.cs +++ b/src/Core/Billing/Organizations/Models/OrganizationLicense.cs @@ -153,6 +153,7 @@ public class OrganizationLicense : ILicense public LicenseType? LicenseType { get; set; } public bool UseOrganizationDomains { get; set; } public bool UseAdminSponsoredFamilies { get; set; } + public bool UseAutomaticUserConfirmation { get; set; } public string Hash { get; set; } public string Signature { get; set; } public string Token { get; set; } @@ -226,7 +227,8 @@ public class OrganizationLicense : ILicense // any new fields added need to be added here so that they're ignored !p.Name.Equals(nameof(UseRiskInsights)) && !p.Name.Equals(nameof(UseAdminSponsoredFamilies)) && - !p.Name.Equals(nameof(UseOrganizationDomains))) + !p.Name.Equals(nameof(UseOrganizationDomains)) && + !p.Name.Equals(nameof(UseAutomaticUserConfirmation))) .OrderBy(p => p.Name) .Select(p => $"{p.Name}:{Core.Utilities.CoreHelpers.FormatLicenseSignatureValue(p.GetValue(this, null))}") .Aggregate((c, n) => $"{c}|{n}"); @@ -421,6 +423,7 @@ public class OrganizationLicense : ILicense var smServiceAccounts = claimsPrincipal.GetValue(nameof(SmServiceAccounts)); var useAdminSponsoredFamilies = claimsPrincipal.GetValue(nameof(UseAdminSponsoredFamilies)); var useOrganizationDomains = claimsPrincipal.GetValue(nameof(UseOrganizationDomains)); + var useAutomaticUserConfirmation = claimsPrincipal.GetValue(nameof(UseAutomaticUserConfirmation)); return issued <= DateTime.UtcNow && expires >= DateTime.UtcNow && @@ -450,7 +453,8 @@ public class OrganizationLicense : ILicense smSeats == organization.SmSeats && smServiceAccounts == organization.SmServiceAccounts && useAdminSponsoredFamilies == organization.UseAdminSponsoredFamilies && - useOrganizationDomains == organization.UseOrganizationDomains; + useOrganizationDomains == organization.UseOrganizationDomains && + useAutomaticUserConfirmation == organization.UseAutomaticUserConfirmation; } diff --git a/src/Core/Billing/Organizations/Models/OrganizationMetadata.cs b/src/Core/Billing/Organizations/Models/OrganizationMetadata.cs index 2bcd213dbf..fedd0ad78c 100644 --- a/src/Core/Billing/Organizations/Models/OrganizationMetadata.cs +++ b/src/Core/Billing/Organizations/Models/OrganizationMetadata.cs @@ -1,28 +1,10 @@ namespace Bit.Core.Billing.Organizations.Models; public record OrganizationMetadata( - bool IsEligibleForSelfHost, - bool IsManaged, bool IsOnSecretsManagerStandalone, - bool IsSubscriptionUnpaid, - bool HasSubscription, - bool HasOpenInvoice, - bool IsSubscriptionCanceled, - DateTime? InvoiceDueDate, - DateTime? InvoiceCreatedDate, - DateTime? SubPeriodEndDate, int OrganizationOccupiedSeats) { public static OrganizationMetadata Default => new OrganizationMetadata( false, - false, - false, - false, - false, - false, - false, - null, - null, - null, 0); } diff --git a/src/Core/Billing/Organizations/Models/OrganizationSale.cs b/src/Core/Billing/Organizations/Models/OrganizationSale.cs index f1f3a636b7..a984d5fe71 100644 --- a/src/Core/Billing/Organizations/Models/OrganizationSale.cs +++ b/src/Core/Billing/Organizations/Models/OrganizationSale.cs @@ -9,7 +9,7 @@ namespace Bit.Core.Billing.Organizations.Models; public class OrganizationSale { - private OrganizationSale() { } + internal OrganizationSale() { } public void Deconstruct( out Organization organization, diff --git a/src/Core/Billing/Organizations/Models/OrganizationSubscriptionPlanChange.cs b/src/Core/Billing/Organizations/Models/OrganizationSubscriptionPlanChange.cs new file mode 100644 index 0000000000..7781f91960 --- /dev/null +++ b/src/Core/Billing/Organizations/Models/OrganizationSubscriptionPlanChange.cs @@ -0,0 +1,23 @@ +using Bit.Core.Billing.Enums; + +namespace Bit.Core.Billing.Organizations.Models; + +public record OrganizationSubscriptionPlanChange +{ + public ProductTierType Tier { get; init; } + public PlanCadenceType Cadence { get; init; } + + public PlanType PlanType => + // ReSharper disable once SwitchExpressionHandlesSomeKnownEnumValuesWithExceptionInDefault + Tier switch + { + ProductTierType.Families => PlanType.FamiliesAnnually, + ProductTierType.Teams => Cadence == PlanCadenceType.Monthly + ? PlanType.TeamsMonthly + : PlanType.TeamsAnnually, + ProductTierType.Enterprise => Cadence == PlanCadenceType.Monthly + ? PlanType.EnterpriseMonthly + : PlanType.EnterpriseAnnually, + _ => throw new InvalidOperationException("Cannot change an Organization subscription to a tier that isn't Families, Teams or Enterprise.") + }; +} diff --git a/src/Core/Billing/Organizations/Models/OrganizationSubscriptionPurchase.cs b/src/Core/Billing/Organizations/Models/OrganizationSubscriptionPurchase.cs new file mode 100644 index 0000000000..6691d69848 --- /dev/null +++ b/src/Core/Billing/Organizations/Models/OrganizationSubscriptionPurchase.cs @@ -0,0 +1,39 @@ +using Bit.Core.Billing.Enums; + +namespace Bit.Core.Billing.Organizations.Models; + +public record OrganizationSubscriptionPurchase +{ + public ProductTierType Tier { get; init; } + public PlanCadenceType Cadence { get; init; } + public required PasswordManagerSelections PasswordManager { get; init; } + public SecretsManagerSelections? SecretsManager { get; init; } + + public PlanType PlanType => + // ReSharper disable once SwitchExpressionHandlesSomeKnownEnumValuesWithExceptionInDefault + Tier switch + { + ProductTierType.Families => PlanType.FamiliesAnnually, + ProductTierType.Teams => Cadence == PlanCadenceType.Monthly + ? PlanType.TeamsMonthly + : PlanType.TeamsAnnually, + ProductTierType.Enterprise => Cadence == PlanCadenceType.Monthly + ? PlanType.EnterpriseMonthly + : PlanType.EnterpriseAnnually, + _ => throw new InvalidOperationException("Cannot purchase an Organization subscription that isn't Families, Teams or Enterprise.") + }; + + public record PasswordManagerSelections + { + public int Seats { get; init; } + public int AdditionalStorage { get; init; } + public bool Sponsored { get; init; } + } + + public record SecretsManagerSelections + { + public int Seats { get; init; } + public int AdditionalServiceAccounts { get; init; } + public bool Standalone { get; init; } + } +} diff --git a/src/Core/Billing/Organizations/Models/OrganizationSubscriptionUpdate.cs b/src/Core/Billing/Organizations/Models/OrganizationSubscriptionUpdate.cs new file mode 100644 index 0000000000..810f292c81 --- /dev/null +++ b/src/Core/Billing/Organizations/Models/OrganizationSubscriptionUpdate.cs @@ -0,0 +1,19 @@ +namespace Bit.Core.Billing.Organizations.Models; + +public record OrganizationSubscriptionUpdate +{ + public PasswordManagerSelections? PasswordManager { get; init; } + public SecretsManagerSelections? SecretsManager { get; init; } + + public record PasswordManagerSelections + { + public int? Seats { get; init; } + public int? AdditionalStorage { get; init; } + } + + public record SecretsManagerSelections + { + public int? Seats { get; init; } + public int? AdditionalServiceAccounts { get; init; } + } +} diff --git a/src/Core/Billing/Organizations/Queries/GetOrganizationMetadataQuery.cs b/src/Core/Billing/Organizations/Queries/GetOrganizationMetadataQuery.cs new file mode 100644 index 0000000000..493bae2872 --- /dev/null +++ b/src/Core/Billing/Organizations/Queries/GetOrganizationMetadataQuery.cs @@ -0,0 +1,93 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Organizations.Models; +using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; +using Bit.Core.Repositories; +using Bit.Core.Settings; +using Stripe; + +namespace Bit.Core.Billing.Organizations.Queries; + +public interface IGetOrganizationMetadataQuery +{ + Task Run(Organization organization); +} + +public class GetOrganizationMetadataQuery( + IGlobalSettings globalSettings, + IOrganizationRepository organizationRepository, + IPricingClient pricingClient, + ISubscriberService subscriberService) : IGetOrganizationMetadataQuery +{ + public async Task Run(Organization organization) + { + if (globalSettings.SelfHosted) + { + return OrganizationMetadata.Default; + } + + var orgOccupiedSeats = await organizationRepository.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id); + + if (string.IsNullOrWhiteSpace(organization.GatewaySubscriptionId)) + { + return OrganizationMetadata.Default with + { + OrganizationOccupiedSeats = orgOccupiedSeats.Total + }; + } + + var customer = await subscriberService.GetCustomer(organization); + + var subscription = await subscriberService.GetSubscription(organization, new SubscriptionGetOptions + { + Expand = ["discounts.coupon.applies_to"] + }); + + if (customer == null || subscription == null) + { + return OrganizationMetadata.Default with + { + OrganizationOccupiedSeats = orgOccupiedSeats.Total + }; + } + + var isOnSecretsManagerStandalone = await IsOnSecretsManagerStandalone(organization, customer, subscription); + + return new OrganizationMetadata( + isOnSecretsManagerStandalone, + orgOccupiedSeats.Total); + } + + private async Task IsOnSecretsManagerStandalone( + Organization organization, + Customer? customer, + Subscription? subscription) + { + if (customer == null || subscription == null) + { + return false; + } + + var plan = await pricingClient.GetPlanOrThrow(organization.PlanType); + + if (!plan.SupportsSecretsManager) + { + return false; + } + + var coupon = subscription.Discounts?.FirstOrDefault(discount => + discount.Coupon?.Id == StripeConstants.CouponIDs.SecretsManagerStandalone)?.Coupon; + + if (coupon == null) + { + return false; + } + + var subscriptionProductIds = subscription.Items.Data.Select(item => item.Plan.ProductId); + + var couponAppliesTo = coupon.AppliesTo?.Products; + + return subscriptionProductIds.Intersect(couponAppliesTo ?? []).Any(); + } +} diff --git a/src/Core/Billing/Organizations/Queries/GetOrganizationWarningsQuery.cs b/src/Core/Billing/Organizations/Queries/GetOrganizationWarningsQuery.cs index f33814f1cf..01e520ea41 100644 --- a/src/Core/Billing/Organizations/Queries/GetOrganizationWarningsQuery.cs +++ b/src/Core/Billing/Organizations/Queries/GetOrganizationWarningsQuery.cs @@ -2,11 +2,11 @@ using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Organizations.Models; +using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Services; using Bit.Core.Context; using Bit.Core.Services; @@ -30,8 +30,8 @@ public interface IGetOrganizationWarningsQuery public class GetOrganizationWarningsQuery( ICurrentContext currentContext, + IHasPaymentMethodQuery hasPaymentMethodQuery, IProviderRepository providerRepository, - ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, ISubscriberService subscriberService) : IGetOrganizationWarningsQuery { @@ -81,15 +81,7 @@ public class GetOrganizationWarningsQuery( return null; } - var customer = subscription.Customer; - - var hasUnverifiedBankAccount = await HasUnverifiedBankAccountAsync(organization); - - var hasPaymentMethod = - !string.IsNullOrEmpty(customer.InvoiceSettings.DefaultPaymentMethodId) || - !string.IsNullOrEmpty(customer.DefaultSourceId) || - hasUnverifiedBankAccount || - customer.Metadata.ContainsKey(MetadataKeys.BraintreeCustomerId); + var hasPaymentMethod = await hasPaymentMethodQuery.Run(organization); if (hasPaymentMethod) { @@ -170,17 +162,23 @@ public class GetOrganizationWarningsQuery( if (subscription is { Status: SubscriptionStatus.Trialing or SubscriptionStatus.Active, - LatestInvoice: null or { Status: InvoiceStatus.Paid } - } && (subscription.CurrentPeriodEnd - now).TotalDays <= 14) + LatestInvoice: null or { Status: InvoiceStatus.Paid }, + Items.Data.Count: > 0 + }) { - return new ResellerRenewalWarning + var currentPeriodEnd = subscription.GetCurrentPeriodEnd(); + + if (currentPeriodEnd != null && (currentPeriodEnd.Value - now).TotalDays <= 14) { - Type = "upcoming", - Upcoming = new ResellerRenewalWarning.UpcomingRenewal + return new ResellerRenewalWarning { - RenewalDate = subscription.CurrentPeriodEnd - } - }; + Type = "upcoming", + Upcoming = new ResellerRenewalWarning.UpcomingRenewal + { + RenewalDate = currentPeriodEnd.Value + } + }; + } } if (subscription is @@ -287,22 +285,4 @@ public class GetOrganizationWarningsQuery( _ => null }; } - - private async Task HasUnverifiedBankAccountAsync( - Organization organization) - { - var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(organization.Id); - - if (string.IsNullOrEmpty(setupIntentId)) - { - return false; - } - - var setupIntent = await stripeAdapter.SetupIntentGet(setupIntentId, new SetupIntentGetOptions - { - Expand = ["payment_method"] - }); - - return setupIntent.IsUnverifiedBankAccount(); - } } diff --git a/src/Core/Billing/Organizations/Services/OrganizationBillingService.cs b/src/Core/Billing/Organizations/Services/OrganizationBillingService.cs index ce8a9a877b..b10f04d766 100644 --- a/src/Core/Billing/Organizations/Services/OrganizationBillingService.cs +++ b/src/Core/Billing/Organizations/Services/OrganizationBillingService.cs @@ -6,6 +6,7 @@ using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Sales; using Bit.Core.Billing.Organizations.Models; +using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Billing.Tax.Models; @@ -27,6 +28,7 @@ namespace Bit.Core.Billing.Organizations.Services; public class OrganizationBillingService( IBraintreeGateway braintreeGateway, IGlobalSettings globalSettings, + IHasPaymentMethodQuery hasPaymentMethodQuery, ILogger logger, IOrganizationRepository organizationRepository, IPricingClient pricingClient, @@ -43,19 +45,14 @@ public class OrganizationBillingService( ? await CreateCustomerAsync(organization, customerSetup, subscriptionSetup.PlanType) : await GetCustomerWhileEnsuringCorrectTaxExemptionAsync(organization, subscriptionSetup); - var subscription = await CreateSubscriptionAsync(organization.Id, customer, subscriptionSetup); + var subscription = await CreateSubscriptionAsync(organization, customer, subscriptionSetup, customerSetup?.Coupon); if (subscription.Status is StripeConstants.SubscriptionStatus.Trialing or StripeConstants.SubscriptionStatus.Active) { organization.Enabled = true; - organization.ExpirationDate = subscription.CurrentPeriodEnd; + organization.ExpirationDate = subscription.GetCurrentPeriodEnd(); + await organizationRepository.ReplaceAsync(organization); } - - organization.Gateway = GatewayType.Stripe; - organization.GatewayCustomerId = customer.Id; - organization.GatewaySubscriptionId = subscription.Id; - - await organizationRepository.ReplaceAsync(organization); } public async Task GetMetadata(Guid organizationId) @@ -72,56 +69,39 @@ public class OrganizationBillingService( return OrganizationMetadata.Default; } - var isEligibleForSelfHost = await IsEligibleForSelfHostAsync(organization); - - var isManaged = organization.Status == OrganizationStatusType.Managed; var orgOccupiedSeats = await organizationRepository.GetOccupiedSeatCountByOrganizationIdAsync(organization.Id); + if (string.IsNullOrWhiteSpace(organization.GatewaySubscriptionId)) { return OrganizationMetadata.Default with { - IsEligibleForSelfHost = isEligibleForSelfHost, - IsManaged = isManaged, OrganizationOccupiedSeats = orgOccupiedSeats.Total }; } - var customer = await subscriberService.GetCustomer(organization, - new CustomerGetOptions { Expand = ["discount.coupon.applies_to"] }); + var customer = await subscriberService.GetCustomer(organization); - var subscription = await subscriberService.GetSubscription(organization); + var subscription = await subscriberService.GetSubscription(organization, new SubscriptionGetOptions + { + Expand = ["discounts.coupon.applies_to"] + }); if (customer == null || subscription == null) { return OrganizationMetadata.Default with { - IsEligibleForSelfHost = isEligibleForSelfHost, - IsManaged = isManaged + OrganizationOccupiedSeats = orgOccupiedSeats.Total }; } var isOnSecretsManagerStandalone = await IsOnSecretsManagerStandalone(organization, customer, subscription); - var invoice = !string.IsNullOrEmpty(subscription.LatestInvoiceId) - ? await stripeAdapter.InvoiceGetAsync(subscription.LatestInvoiceId, new InvoiceGetOptions()) - : null; - return new OrganizationMetadata( - isEligibleForSelfHost, - isManaged, isOnSecretsManagerStandalone, - subscription.Status == StripeConstants.SubscriptionStatus.Unpaid, - true, - invoice?.Status == StripeConstants.InvoiceStatus.Open, - subscription.Status == StripeConstants.SubscriptionStatus.Canceled, - invoice?.DueDate, - invoice?.Created, - subscription.CurrentPeriodEnd, orgOccupiedSeats.Total); } - public async Task - UpdatePaymentMethod( + public async Task UpdatePaymentMethod( Organization organization, TokenizedPaymentSource tokenizedPaymentSource, TaxInformation taxInformation) @@ -209,7 +189,6 @@ public class OrganizationBillingService( var customerCreateOptions = new CustomerCreateOptions { - Coupon = customerSetup.Coupon, Description = organization.DisplayBusinessName(), Email = organization.BillingEmail, Expand = ["tax", "tax_ids"], @@ -272,8 +251,6 @@ public class OrganizationBillingService( ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately }; - - if (planType.GetProductTier() is not ProductTierType.Free and not ProductTierType.Families && customerSetup.TaxInformation.Country != Core.Constants.CountryAbbreviations.UnitedStates) { @@ -297,7 +274,7 @@ public class OrganizationBillingService( customerCreateOptions.TaxIdData = [ - new() { Type = taxIdType, Value = customerSetup.TaxInformation.TaxId } + new CustomerTaxIdDataOptions { Type = taxIdType, Value = customerSetup.TaxInformation.TaxId } ]; if (taxIdType == StripeConstants.TaxIdType.SpanishNIF) @@ -352,7 +329,13 @@ public class OrganizationBillingService( try { - return await stripeAdapter.CustomerCreateAsync(customerCreateOptions); + var customer = await stripeAdapter.CustomerCreateAsync(customerCreateOptions); + + organization.Gateway = GatewayType.Stripe; + organization.GatewayCustomerId = customer.Id; + await organizationRepository.ReplaceAsync(organization); + + return customer; } catch (StripeException stripeException) when (stripeException.StripeError?.Code == StripeConstants.ErrorCodes.CustomerTaxLocationInvalid) @@ -397,9 +380,10 @@ public class OrganizationBillingService( } private async Task CreateSubscriptionAsync( - Guid organizationId, + Organization organization, Customer customer, - SubscriptionSetup subscriptionSetup) + SubscriptionSetup subscriptionSetup, + string? coupon) { var plan = await pricingClient.GetPlanOrThrow(subscriptionSetup.PlanType); @@ -462,10 +446,11 @@ public class OrganizationBillingService( { CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically, Customer = customer.Id, + Discounts = !string.IsNullOrEmpty(coupon) ? [new SubscriptionDiscountOptions { Coupon = coupon }] : null, Items = subscriptionItemOptionsList, Metadata = new Dictionary { - ["organizationId"] = organizationId.ToString(), + ["organizationId"] = organization.Id.ToString(), ["trialInitiationPath"] = !string.IsNullOrEmpty(subscriptionSetup.InitiationPath) && subscriptionSetup.InitiationPath.Contains("trial from marketing website") ? "marketing-initiated" @@ -475,9 +460,11 @@ public class OrganizationBillingService( TrialPeriodDays = subscriptionSetup.SkipTrial ? 0 : plan.TrialPeriodDays }; - // Only set trial_settings.end_behavior.missing_payment_method to "cancel" if there is no payment method - if (string.IsNullOrEmpty(customer.InvoiceSettings?.DefaultPaymentMethodId) && - !customer.Metadata.ContainsKey(BraintreeCustomerIdKey)) + var hasPaymentMethod = await hasPaymentMethodQuery.Run(organization); + + // Only set trial_settings.end_behavior.missing_payment_method to "cancel" + // if there is no payment method AND there's an actual trial period + if (!hasPaymentMethod && subscriptionCreateOptions.TrialPeriodDays > 0) { subscriptionCreateOptions.TrialSettings = new SubscriptionTrialSettingsOptions { @@ -492,7 +479,13 @@ public class OrganizationBillingService( { subscriptionCreateOptions.AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }; } - return await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); + + var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); + + organization.GatewaySubscriptionId = subscription.Id; + await organizationRepository.ReplaceAsync(organization); + + return subscription; } private async Task GetCustomerWhileEnsuringCorrectTaxExemptionAsync( @@ -534,16 +527,6 @@ public class OrganizationBillingService( return customer; } - private async Task IsEligibleForSelfHostAsync( - Organization organization) - { - var plans = await pricingClient.ListPlans(); - - var eligibleSelfHostPlans = plans.Where(plan => plan.HasSelfHost).Select(plan => plan.Type); - - return eligibleSelfHostPlans.Contains(organization.PlanType); - } - private async Task IsOnSecretsManagerStandalone( Organization organization, Customer? customer, @@ -561,16 +544,17 @@ public class OrganizationBillingService( return false; } - var hasCoupon = customer.Discount?.Coupon?.Id == StripeConstants.CouponIDs.SecretsManagerStandalone; + var coupon = subscription.Discounts?.FirstOrDefault(discount => + discount.Coupon?.Id == StripeConstants.CouponIDs.SecretsManagerStandalone)?.Coupon; - if (!hasCoupon) + if (coupon == null) { return false; } var subscriptionProductIds = subscription.Items.Data.Select(item => item.Plan.ProductId); - var couponAppliesTo = customer.Discount?.Coupon?.AppliesTo?.Products; + var couponAppliesTo = coupon.AppliesTo?.Products; return subscriptionProductIds.Intersect(couponAppliesTo ?? []).Any(); } diff --git a/src/Core/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommand.cs b/src/Core/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommand.cs index a86f0e3ada..cc07f1b5db 100644 --- a/src/Core/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommand.cs +++ b/src/Core/Billing/Payment/Commands/CreateBitPayInvoiceForCreditCommand.cs @@ -1,6 +1,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Commands; +using Bit.Core.Billing.Constants; using Bit.Core.Billing.Payment.Clients; using Bit.Core.Entities; using Bit.Core.Settings; @@ -9,6 +10,8 @@ using Microsoft.Extensions.Logging; namespace Bit.Core.Billing.Payment.Commands; +using static BitPayConstants; + public interface ICreateBitPayInvoiceForCreditCommand { Task> Run( @@ -31,6 +34,8 @@ public class CreateBitPayInvoiceForCreditCommand( { var (name, email, posData) = GetSubscriberInformation(subscriber); + var notificationUrl = $"{globalSettings.BitPay.NotificationUrl}?key={globalSettings.BitPay.WebhookKey}"; + var invoice = new Invoice { Buyer = new Buyer { Email = email, Name = name }, @@ -38,7 +43,7 @@ public class CreateBitPayInvoiceForCreditCommand( ExtendedNotifications = true, FullNotifications = true, ItemDesc = "Bitwarden", - NotificationUrl = globalSettings.BitPay.NotificationUrl, + NotificationUrl = notificationUrl, PosData = posData, Price = Convert.ToDouble(amount), RedirectUrl = redirectUrl @@ -51,10 +56,10 @@ public class CreateBitPayInvoiceForCreditCommand( private static (string? Name, string? Email, string POSData) GetSubscriberInformation( ISubscriber subscriber) => subscriber switch { - User user => (user.Email, user.Email, $"userId:{user.Id},accountCredit:1"), + User user => (user.Email, user.Email, $"userId:{user.Id},{PosDataKeys.AccountCredit}"), Organization organization => (organization.Name, organization.BillingEmail, - $"organizationId:{organization.Id},accountCredit:1"), - Provider provider => (provider.Name, provider.BillingEmail, $"providerId:{provider.Id},accountCredit:1"), + $"organizationId:{organization.Id},{PosDataKeys.AccountCredit}"), + Provider provider => (provider.Name, provider.BillingEmail, $"providerId:{provider.Id},{PosDataKeys.AccountCredit}"), _ => throw new ArgumentOutOfRangeException(nameof(subscriber)) }; } diff --git a/src/Core/Billing/Payment/Models/NonTokenizedPaymentMethod.cs b/src/Core/Billing/Payment/Models/NonTokenizedPaymentMethod.cs new file mode 100644 index 0000000000..5e8ec0484c --- /dev/null +++ b/src/Core/Billing/Payment/Models/NonTokenizedPaymentMethod.cs @@ -0,0 +1,11 @@ +namespace Bit.Core.Billing.Payment.Models; + +public record NonTokenizedPaymentMethod +{ + public NonTokenizablePaymentMethodType Type { get; set; } +} + +public enum NonTokenizablePaymentMethodType +{ + AccountCredit, +} diff --git a/src/Core/Billing/Payment/Models/PaymentMethod.cs b/src/Core/Billing/Payment/Models/PaymentMethod.cs new file mode 100644 index 0000000000..b0733da414 --- /dev/null +++ b/src/Core/Billing/Payment/Models/PaymentMethod.cs @@ -0,0 +1,71 @@ +using System.Text.Json; +using System.Text.Json.Serialization; +using OneOf; + +namespace Bit.Core.Billing.Payment.Models; + +[JsonConverter(typeof(PaymentMethodJsonConverter))] +public class PaymentMethod(OneOf input) + : OneOfBase(input) +{ + public static implicit operator PaymentMethod(TokenizedPaymentMethod tokenized) => new(tokenized); + public static implicit operator PaymentMethod(NonTokenizedPaymentMethod nonTokenized) => new(nonTokenized); + public bool IsTokenized => IsT0; + public TokenizedPaymentMethod AsTokenized => AsT0; + public bool IsNonTokenized => IsT1; + public NonTokenizedPaymentMethod AsNonTokenized => AsT1; +} + +internal class PaymentMethodJsonConverter : JsonConverter +{ + public override PaymentMethod Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + var element = JsonElement.ParseValue(ref reader); + + if (!element.TryGetProperty("type", out var typeProperty)) + { + throw new JsonException("PaymentMethod requires a 'type' property"); + } + + var type = typeProperty.GetString(); + + + if (Enum.TryParse(type, true, out var tokenizedType) && + Enum.IsDefined(typeof(TokenizablePaymentMethodType), tokenizedType)) + { + var token = element.TryGetProperty("token", out var tokenProperty) ? tokenProperty.GetString() : null; + if (string.IsNullOrEmpty(token)) + { + throw new JsonException("TokenizedPaymentMethod requires a 'token' property"); + } + + return new TokenizedPaymentMethod { Type = tokenizedType, Token = token }; + } + + if (Enum.TryParse(type, true, out var nonTokenizedType) && + Enum.IsDefined(typeof(NonTokenizablePaymentMethodType), nonTokenizedType)) + { + return new NonTokenizedPaymentMethod { Type = nonTokenizedType }; + } + + throw new JsonException($"Unknown payment method type: {type}"); + } + + public override void Write(Utf8JsonWriter writer, PaymentMethod value, JsonSerializerOptions options) + { + writer.WriteStartObject(); + + value.Switch( + tokenized => + { + writer.WriteString("type", + tokenized.Type.ToString().ToLowerInvariant() + ); + writer.WriteString("token", tokenized.Token); + }, + nonTokenized => { writer.WriteString("type", nonTokenized.Type.ToString().ToLowerInvariant()); } + ); + + writer.WriteEndObject(); + } +} diff --git a/src/Core/Billing/Payment/Queries/HasPaymentMethodQuery.cs b/src/Core/Billing/Payment/Queries/HasPaymentMethodQuery.cs new file mode 100644 index 0000000000..ec77ee0712 --- /dev/null +++ b/src/Core/Billing/Payment/Queries/HasPaymentMethodQuery.cs @@ -0,0 +1,58 @@ +using Bit.Core.Billing.Caches; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Services; +using Bit.Core.Entities; +using Bit.Core.Services; +using Stripe; + +namespace Bit.Core.Billing.Payment.Queries; + +using static StripeConstants; + +public interface IHasPaymentMethodQuery +{ + Task Run(ISubscriber subscriber); +} + +public class HasPaymentMethodQuery( + ISetupIntentCache setupIntentCache, + IStripeAdapter stripeAdapter, + ISubscriberService subscriberService) : IHasPaymentMethodQuery +{ + public async Task Run(ISubscriber subscriber) + { + var hasUnverifiedBankAccount = await HasUnverifiedBankAccountAsync(subscriber); + + var customer = await subscriberService.GetCustomer(subscriber); + + if (customer == null) + { + return hasUnverifiedBankAccount; + } + + return + !string.IsNullOrEmpty(customer.InvoiceSettings.DefaultPaymentMethodId) || + !string.IsNullOrEmpty(customer.DefaultSourceId) || + hasUnverifiedBankAccount || + customer.Metadata.ContainsKey(MetadataKeys.BraintreeCustomerId); + } + + private async Task HasUnverifiedBankAccountAsync( + ISubscriber subscriber) + { + var setupIntentId = await setupIntentCache.GetSetupIntentIdForSubscriber(subscriber.Id); + + if (string.IsNullOrEmpty(setupIntentId)) + { + return false; + } + + var setupIntent = await stripeAdapter.SetupIntentGet(setupIntentId, new SetupIntentGetOptions + { + Expand = ["payment_method"] + }); + + return setupIntent.IsUnverifiedBankAccount(); + } +} diff --git a/src/Core/Billing/Payment/Registrations.cs b/src/Core/Billing/Payment/Registrations.cs index 478673d2fc..89d3778ccd 100644 --- a/src/Core/Billing/Payment/Registrations.cs +++ b/src/Core/Billing/Payment/Registrations.cs @@ -19,5 +19,6 @@ public static class Registrations services.AddTransient(); services.AddTransient(); services.AddTransient(); + services.AddTransient(); } } diff --git a/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs b/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs index 1227cdc034..1f752a007b 100644 --- a/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs +++ b/src/Core/Billing/Premium/Commands/CreatePremiumCloudHostedSubscriptionCommand.cs @@ -1,7 +1,11 @@ using Bit.Core.Billing.Caches; using Bit.Core.Billing.Commands; using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Payment.Commands; using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Payment.Queries; +using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Enums; @@ -14,10 +18,12 @@ using Microsoft.Extensions.Logging; using OneOf.Types; using Stripe; using Customer = Stripe.Customer; +using PaymentMethod = Bit.Core.Billing.Payment.Models.PaymentMethod; using Subscription = Stripe.Subscription; namespace Bit.Core.Billing.Premium.Commands; +using static StripeConstants; using static Utilities; /// @@ -29,14 +35,14 @@ public interface ICreatePremiumCloudHostedSubscriptionCommand /// /// Creates a premium cloud-hosted subscription for the specified user. /// - /// The user to create the premium subscription for. Must not already be a premium user. + /// The user to create the premium subscription for. Must not yet be a premium user. /// The tokenized payment method containing the payment type and token for billing. /// The billing address information required for tax calculation and customer creation. /// Additional storage in GB beyond the base 1GB included with premium (must be >= 0). /// A billing command result indicating success or failure with appropriate error details. Task> Run( User user, - TokenizedPaymentMethod paymentMethod, + PaymentMethod paymentMethod, BillingAddress billingAddress, short additionalStorageGb); } @@ -49,7 +55,10 @@ public class CreatePremiumCloudHostedSubscriptionCommand( ISubscriberService subscriberService, IUserService userService, IPushNotificationService pushNotificationService, - ILogger logger) + ILogger logger, + IPricingClient pricingClient, + IHasPaymentMethodQuery hasPaymentMethodQuery, + IUpdatePaymentMethodCommand updatePaymentMethodCommand) : BaseBillingCommand(logger), ICreatePremiumCloudHostedSubscriptionCommand { private static readonly List _expand = ["tax"]; @@ -57,7 +66,7 @@ public class CreatePremiumCloudHostedSubscriptionCommand( public Task> Run( User user, - TokenizedPaymentMethod paymentMethod, + PaymentMethod paymentMethod, BillingAddress billingAddress, short additionalStorageGb) => HandleAsync(async () => { @@ -71,26 +80,62 @@ public class CreatePremiumCloudHostedSubscriptionCommand( return new BadRequest("Additional storage must be greater than 0."); } - var customer = string.IsNullOrEmpty(user.GatewayCustomerId) - ? await CreateCustomerAsync(user, paymentMethod, billingAddress) - : await subscriberService.GetCustomerOrThrow(user, new CustomerGetOptions { Expand = _expand }); + Customer? customer; + + /* + * For a new customer purchasing a new subscription, we attach the payment method while creating the customer. + */ + if (string.IsNullOrEmpty(user.GatewayCustomerId)) + { + customer = await CreateCustomerAsync(user, paymentMethod, billingAddress); + } + /* + * An existing customer without a payment method starting a new subscription indicates a user who previously + * purchased account credit but chose to use a tokenizable payment method to pay for the subscription. In this case, + * we need to add the payment method to their customer first. If the incoming payment method is account credit, + * we can just go straight to fetching the customer since there's no payment method to apply. + */ + else if (paymentMethod.IsTokenized && !await hasPaymentMethodQuery.Run(user)) + { + await updatePaymentMethodCommand.Run(user, paymentMethod.AsTokenized, billingAddress); + customer = await subscriberService.GetCustomerOrThrow(user, new CustomerGetOptions { Expand = _expand }); + } + else + { + customer = await subscriberService.GetCustomerOrThrow(user, new CustomerGetOptions { Expand = _expand }); + } customer = await ReconcileBillingLocationAsync(customer, billingAddress); var subscription = await CreateSubscriptionAsync(user.Id, customer, additionalStorageGb > 0 ? additionalStorageGb : null); - switch (paymentMethod) - { - case { Type: TokenizablePaymentMethodType.PayPal } - when subscription.Status == StripeConstants.SubscriptionStatus.Incomplete: - case { Type: not TokenizablePaymentMethodType.PayPal } - when subscription.Status == StripeConstants.SubscriptionStatus.Active: + paymentMethod.Switch( + tokenized => + { + // ReSharper disable once SwitchStatementHandlesSomeKnownEnumValuesWithDefault + switch (tokenized) { - user.Premium = true; - user.PremiumExpirationDate = subscription.CurrentPeriodEnd; - break; + case { Type: TokenizablePaymentMethodType.PayPal } + when subscription.Status == SubscriptionStatus.Incomplete: + case { Type: not TokenizablePaymentMethodType.PayPal } + when subscription.Status == SubscriptionStatus.Active: + { + user.Premium = true; + user.PremiumExpirationDate = subscription.GetCurrentPeriodEnd(); + break; + } } - } + }, + _ => + { + if (subscription.Status != SubscriptionStatus.Active) + { + return; + } + + user.Premium = true; + user.PremiumExpirationDate = subscription.GetCurrentPeriodEnd(); + }); user.Gateway = GatewayType.Stripe; user.GatewayCustomerId = customer.Id; @@ -106,9 +151,15 @@ public class CreatePremiumCloudHostedSubscriptionCommand( }); private async Task CreateCustomerAsync(User user, - TokenizedPaymentMethod paymentMethod, + PaymentMethod paymentMethod, BillingAddress billingAddress) { + if (paymentMethod.IsNonTokenized) + { + _logger.LogError("Cannot create customer for user ({UserID}) using non-tokenized payment method. The customer should already exist", user.Id); + throw new BillingException(); + } + var subscriberName = user.SubscriberName(); var customerCreateOptions = new CustomerCreateOptions { @@ -139,24 +190,25 @@ public class CreatePremiumCloudHostedSubscriptionCommand( }, Metadata = new Dictionary { - [StripeConstants.MetadataKeys.Region] = globalSettings.BaseServiceUri.CloudRegion, - [StripeConstants.MetadataKeys.UserId] = user.Id.ToString() + [MetadataKeys.Region] = globalSettings.BaseServiceUri.CloudRegion, + [MetadataKeys.UserId] = user.Id.ToString() }, Tax = new CustomerTaxOptions { - ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately + ValidateLocation = ValidateTaxLocationTiming.Immediately } }; var braintreeCustomerId = ""; - // ReSharper disable once SwitchStatementHandlesSomeKnownEnumValuesWithDefault - switch (paymentMethod.Type) + // We have checked that the payment method is tokenized, so we can safely cast it. + var tokenizedPaymentMethod = paymentMethod.AsTokenized; + switch (tokenizedPaymentMethod.Type) { case TokenizablePaymentMethodType.BankAccount: { var setupIntent = - (await stripeAdapter.SetupIntentList(new SetupIntentListOptions { PaymentMethod = paymentMethod.Token })) + (await stripeAdapter.SetupIntentList(new SetupIntentListOptions { PaymentMethod = tokenizedPaymentMethod.Token })) .FirstOrDefault(); if (setupIntent == null) @@ -170,19 +222,19 @@ public class CreatePremiumCloudHostedSubscriptionCommand( } case TokenizablePaymentMethodType.Card: { - customerCreateOptions.PaymentMethod = paymentMethod.Token; - customerCreateOptions.InvoiceSettings.DefaultPaymentMethod = paymentMethod.Token; + customerCreateOptions.PaymentMethod = tokenizedPaymentMethod.Token; + customerCreateOptions.InvoiceSettings.DefaultPaymentMethod = tokenizedPaymentMethod.Token; break; } case TokenizablePaymentMethodType.PayPal: { - braintreeCustomerId = await subscriberService.CreateBraintreeCustomer(user, paymentMethod.Token); + braintreeCustomerId = await subscriberService.CreateBraintreeCustomer(user, tokenizedPaymentMethod.Token); customerCreateOptions.Metadata[BraintreeCustomerIdKey] = braintreeCustomerId; break; } default: { - _logger.LogError("Cannot create customer for user ({UserID}) using payment method type ({PaymentMethodType}) as it is not supported", user.Id, paymentMethod.Type.ToString()); + _logger.LogError("Cannot create customer for user ({UserID}) using payment method type ({PaymentMethodType}) as it is not supported", user.Id, tokenizedPaymentMethod.Type.ToString()); throw new BillingException(); } } @@ -200,7 +252,7 @@ public class CreatePremiumCloudHostedSubscriptionCommand( async Task Revert() { // ReSharper disable once SwitchStatementMissingSomeEnumCasesNoDefault - switch (paymentMethod.Type) + switch (tokenizedPaymentMethod.Type) { case TokenizablePaymentMethodType.BankAccount: { @@ -243,7 +295,7 @@ public class CreatePremiumCloudHostedSubscriptionCommand( Expand = _expand, Tax = new CustomerTaxOptions { - ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately + ValidateLocation = ValidateTaxLocationTiming.Immediately } }; return await stripeAdapter.CustomerUpdateAsync(customer.Id, options); @@ -254,11 +306,13 @@ public class CreatePremiumCloudHostedSubscriptionCommand( Customer customer, int? storage) { + var premiumPlan = await pricingClient.GetAvailablePremiumPlan(); + var subscriptionItemOptionsList = new List { new () { - Price = StripeConstants.Prices.PremiumAnnually, + Price = premiumPlan.Seat.StripePriceId, Quantity = 1 } }; @@ -267,7 +321,7 @@ public class CreatePremiumCloudHostedSubscriptionCommand( { subscriptionItemOptionsList.Add(new SubscriptionItemOptions { - Price = StripeConstants.Prices.StoragePlanPersonal, + Price = premiumPlan.Storage.StripePriceId, Quantity = storage }); } @@ -280,15 +334,15 @@ public class CreatePremiumCloudHostedSubscriptionCommand( { Enabled = true }, - CollectionMethod = StripeConstants.CollectionMethod.ChargeAutomatically, + CollectionMethod = CollectionMethod.ChargeAutomatically, Customer = customer.Id, Items = subscriptionItemOptionsList, Metadata = new Dictionary { - [StripeConstants.MetadataKeys.UserId] = userId.ToString() + [MetadataKeys.UserId] = userId.ToString() }, PaymentBehavior = usingPayPal - ? StripeConstants.PaymentBehavior.DefaultIncomplete + ? PaymentBehavior.DefaultIncomplete : null, OffSession = true }; diff --git a/src/Core/Billing/Premium/Commands/PreviewPremiumTaxCommand.cs b/src/Core/Billing/Premium/Commands/PreviewPremiumTaxCommand.cs new file mode 100644 index 0000000000..5f09b8b77b --- /dev/null +++ b/src/Core/Billing/Premium/Commands/PreviewPremiumTaxCommand.cs @@ -0,0 +1,66 @@ +using Bit.Core.Billing.Commands; +using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Pricing; +using Bit.Core.Services; +using Microsoft.Extensions.Logging; +using Stripe; + +namespace Bit.Core.Billing.Premium.Commands; + +public interface IPreviewPremiumTaxCommand +{ + Task> Run( + int additionalStorage, + BillingAddress billingAddress); +} + +public class PreviewPremiumTaxCommand( + ILogger logger, + IPricingClient pricingClient, + IStripeAdapter stripeAdapter) : BaseBillingCommand(logger), IPreviewPremiumTaxCommand +{ + public Task> Run( + int additionalStorage, + BillingAddress billingAddress) + => HandleAsync<(decimal, decimal)>(async () => + { + var premiumPlan = await pricingClient.GetAvailablePremiumPlan(); + + var options = new InvoiceCreatePreviewOptions + { + AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true }, + CustomerDetails = new InvoiceCustomerDetailsOptions + { + Address = new AddressOptions + { + Country = billingAddress.Country, + PostalCode = billingAddress.PostalCode + } + }, + Currency = "usd", + SubscriptionDetails = new InvoiceSubscriptionDetailsOptions + { + Items = + [ + new InvoiceSubscriptionDetailsItemOptions { Price = premiumPlan.Seat.StripePriceId, Quantity = 1 } + ] + } + }; + + if (additionalStorage > 0) + { + options.SubscriptionDetails.Items.Add(new InvoiceSubscriptionDetailsItemOptions + { + Price = premiumPlan.Storage.StripePriceId, + Quantity = additionalStorage + }); + } + + var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); + return GetAmounts(invoice); + }); + + private static (decimal, decimal) GetAmounts(Invoice invoice) => ( + Convert.ToDecimal(invoice.TotalTaxes.Sum(invoiceTotalTax => invoiceTotalTax.Amount)) / 100, + Convert.ToDecimal(invoice.Total) / 100); +} diff --git a/src/Core/Billing/Pricing/IPricingClient.cs b/src/Core/Billing/Pricing/IPricingClient.cs index bc3f142dda..18588ae432 100644 --- a/src/Core/Billing/Pricing/IPricingClient.cs +++ b/src/Core/Billing/Pricing/IPricingClient.cs @@ -3,12 +3,14 @@ using Bit.Core.Exceptions; using Bit.Core.Models.StaticStore; using Bit.Core.Utilities; -#nullable enable - namespace Bit.Core.Billing.Pricing; +using OrganizationPlan = Plan; +using PremiumPlan = Premium.Plan; + public interface IPricingClient { + // TODO: Rename with Organization focus. /// /// Retrieve a Bitwarden plan by its . If the feature flag 'use-pricing-service' is enabled, /// this will trigger a request to the Bitwarden Pricing Service. Otherwise, it will use the existing . @@ -16,8 +18,9 @@ public interface IPricingClient /// The type of plan to retrieve. /// A Bitwarden record or null in the case the plan could not be found or the method was executed from a self-hosted instance. /// Thrown when the request to the Pricing Service fails unexpectedly. - Task GetPlan(PlanType planType); + Task GetPlan(PlanType planType); + // TODO: Rename with Organization focus. /// /// Retrieve a Bitwarden plan by its . If the feature flag 'use-pricing-service' is enabled, /// this will trigger a request to the Bitwarden Pricing Service. Otherwise, it will use the existing . @@ -26,13 +29,17 @@ public interface IPricingClient /// A Bitwarden record. /// Thrown when the for the provided could not be found or the method was executed from a self-hosted instance. /// Thrown when the request to the Pricing Service fails unexpectedly. - Task GetPlanOrThrow(PlanType planType); + Task GetPlanOrThrow(PlanType planType); + // TODO: Rename with Organization focus. /// /// Retrieve all the Bitwarden plans. If the feature flag 'use-pricing-service' is enabled, /// this will trigger a request to the Bitwarden Pricing Service. Otherwise, it will use the existing . /// /// A list of Bitwarden records or an empty list in the case the method is executed from a self-hosted instance. /// Thrown when the request to the Pricing Service fails unexpectedly. - Task> ListPlans(); + Task> ListPlans(); + + Task GetAvailablePremiumPlan(); + Task> ListPremiumPlans(); } diff --git a/src/Core/Billing/Pricing/Models/Feature.cs b/src/Core/Billing/Pricing/Organizations/Feature.cs similarity index 69% rename from src/Core/Billing/Pricing/Models/Feature.cs rename to src/Core/Billing/Pricing/Organizations/Feature.cs index ea9da5217d..df10d2bcf8 100644 --- a/src/Core/Billing/Pricing/Models/Feature.cs +++ b/src/Core/Billing/Pricing/Organizations/Feature.cs @@ -1,4 +1,4 @@ -namespace Bit.Core.Billing.Pricing.Models; +namespace Bit.Core.Billing.Pricing.Organizations; public class Feature { diff --git a/src/Core/Billing/Pricing/Models/Plan.cs b/src/Core/Billing/Pricing/Organizations/Plan.cs similarity index 94% rename from src/Core/Billing/Pricing/Models/Plan.cs rename to src/Core/Billing/Pricing/Organizations/Plan.cs index 5b4296474b..c533c271cb 100644 --- a/src/Core/Billing/Pricing/Models/Plan.cs +++ b/src/Core/Billing/Pricing/Organizations/Plan.cs @@ -1,4 +1,4 @@ -namespace Bit.Core.Billing.Pricing.Models; +namespace Bit.Core.Billing.Pricing.Organizations; public class Plan { diff --git a/src/Core/Billing/Pricing/PlanAdapter.cs b/src/Core/Billing/Pricing/Organizations/PlanAdapter.cs similarity index 92% rename from src/Core/Billing/Pricing/PlanAdapter.cs rename to src/Core/Billing/Pricing/Organizations/PlanAdapter.cs index 560987b891..37dc63cb47 100644 --- a/src/Core/Billing/Pricing/PlanAdapter.cs +++ b/src/Core/Billing/Pricing/Organizations/PlanAdapter.cs @@ -1,8 +1,6 @@ using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Pricing.Models; -using Plan = Bit.Core.Billing.Pricing.Models.Plan; -namespace Bit.Core.Billing.Pricing; +namespace Bit.Core.Billing.Pricing.Organizations; public record PlanAdapter : Core.Models.StaticStore.Plan { @@ -60,6 +58,7 @@ public record PlanAdapter : Core.Models.StaticStore.Plan "enterprise-monthly-2020" => PlanType.EnterpriseMonthly2020, "enterprise-monthly-2023" => PlanType.EnterpriseMonthly2023, "families" => PlanType.FamiliesAnnually, + "families-2025" => PlanType.FamiliesAnnually2025, "families-2019" => PlanType.FamiliesAnnually2019, "free" => PlanType.Free, "teams-annually" => PlanType.TeamsAnnually, @@ -79,7 +78,7 @@ public record PlanAdapter : Core.Models.StaticStore.Plan => planType switch { PlanType.Free => ProductTierType.Free, - PlanType.FamiliesAnnually or PlanType.FamiliesAnnually2019 => ProductTierType.Families, + PlanType.FamiliesAnnually or PlanType.FamiliesAnnually2025 or PlanType.FamiliesAnnually2019 => ProductTierType.Families, PlanType.TeamsStarter or PlanType.TeamsStarter2023 => ProductTierType.TeamsStarter, _ when planType.ToString().Contains("Teams") => ProductTierType.Teams, _ when planType.ToString().Contains("Enterprise") => ProductTierType.Enterprise, @@ -105,6 +104,14 @@ public record PlanAdapter : Core.Models.StaticStore.Plan var additionalStoragePricePerGb = plan.Storage?.Price ?? 0; var stripeStoragePlanId = plan.Storage?.StripePriceId; short? maxCollections = plan.AdditionalData.TryGetValue("passwordManager.maxCollections", out var value) ? short.Parse(value) : null; + var stripePremiumAccessPlanId = + plan.AdditionalData.TryGetValue("premiumAccessAddOnPriceId", out var premiumAccessAddOnPriceIdValue) + ? premiumAccessAddOnPriceIdValue + : null; + var premiumAccessOptionPrice = + plan.AdditionalData.TryGetValue("premiumAccessAddOnPriceAmount", out var premiumAccessAddOnPriceAmountValue) + ? decimal.Parse(premiumAccessAddOnPriceAmountValue) + : 0; return new PasswordManagerPlanFeatures { @@ -122,7 +129,9 @@ public record PlanAdapter : Core.Models.StaticStore.Plan HasAdditionalStorageOption = hasAdditionalStorageOption, AdditionalStoragePricePerGb = additionalStoragePricePerGb, StripeStoragePlanId = stripeStoragePlanId, - MaxCollections = maxCollections + MaxCollections = maxCollections, + StripePremiumAccessPlanId = stripePremiumAccessPlanId, + PremiumAccessOptionPrice = premiumAccessOptionPrice }; } diff --git a/src/Core/Billing/Pricing/Models/Purchasable.cs b/src/Core/Billing/Pricing/Organizations/Purchasable.cs similarity index 99% rename from src/Core/Billing/Pricing/Models/Purchasable.cs rename to src/Core/Billing/Pricing/Organizations/Purchasable.cs index 7cb4ee00c1..f6704394f7 100644 --- a/src/Core/Billing/Pricing/Models/Purchasable.cs +++ b/src/Core/Billing/Pricing/Organizations/Purchasable.cs @@ -2,7 +2,7 @@ using System.Text.Json.Serialization; using OneOf; -namespace Bit.Core.Billing.Pricing.Models; +namespace Bit.Core.Billing.Pricing.Organizations; [JsonConverter(typeof(PurchasableJsonConverter))] public class Purchasable(OneOf input) : OneOfBase(input) diff --git a/src/Core/Billing/Pricing/Premium/Plan.cs b/src/Core/Billing/Pricing/Premium/Plan.cs new file mode 100644 index 0000000000..f377157363 --- /dev/null +++ b/src/Core/Billing/Pricing/Premium/Plan.cs @@ -0,0 +1,10 @@ +namespace Bit.Core.Billing.Pricing.Premium; + +public class Plan +{ + public string Name { get; init; } = null!; + public int? LegacyYear { get; init; } + public bool Available { get; init; } + public Purchasable Seat { get; init; } = null!; + public Purchasable Storage { get; init; } = null!; +} diff --git a/src/Core/Billing/Pricing/Premium/Purchasable.cs b/src/Core/Billing/Pricing/Premium/Purchasable.cs new file mode 100644 index 0000000000..633eb2e8aa --- /dev/null +++ b/src/Core/Billing/Pricing/Premium/Purchasable.cs @@ -0,0 +1,7 @@ +namespace Bit.Core.Billing.Pricing.Premium; + +public class Purchasable +{ + public string StripePriceId { get; init; } = null!; + public decimal Price { get; init; } +} diff --git a/src/Core/Billing/Pricing/PricingClient.cs b/src/Core/Billing/Pricing/PricingClient.cs index a3db8ce07f..1ec44c6496 100644 --- a/src/Core/Billing/Pricing/PricingClient.cs +++ b/src/Core/Billing/Pricing/PricingClient.cs @@ -1,24 +1,27 @@ using System.Net; using System.Net.Http.Json; +using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Pricing.Organizations; using Bit.Core.Exceptions; using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Utilities; using Microsoft.Extensions.Logging; -using Plan = Bit.Core.Models.StaticStore.Plan; - -#nullable enable namespace Bit.Core.Billing.Pricing; +using OrganizationPlan = Bit.Core.Models.StaticStore.Plan; +using PremiumPlan = Premium.Plan; +using Purchasable = Premium.Purchasable; + public class PricingClient( IFeatureService featureService, GlobalSettings globalSettings, HttpClient httpClient, ILogger logger) : IPricingClient { - public async Task GetPlan(PlanType planType) + public async Task GetPlan(PlanType planType) { if (globalSettings.SelfHosted) { @@ -40,16 +43,14 @@ public class PricingClient( return null; } - var response = await httpClient.GetAsync($"plans/lookup/{lookupKey}"); + var response = await httpClient.GetAsync($"plans/organization/{lookupKey}"); if (response.IsSuccessStatusCode) { - var plan = await response.Content.ReadFromJsonAsync(); - if (plan == null) - { - throw new BillingException(message: "Deserialization of Pricing Service response resulted in null"); - } - return new PlanAdapter(plan); + var plan = await response.Content.ReadFromJsonAsync(); + return plan == null + ? throw new BillingException(message: "Deserialization of Pricing Service response resulted in null") + : new PlanAdapter(PreProcessFamiliesPreMigrationPlan(plan)); } if (response.StatusCode == HttpStatusCode.NotFound) @@ -62,19 +63,14 @@ public class PricingClient( message: $"Request to the Pricing Service failed with status code {response.StatusCode}"); } - public async Task GetPlanOrThrow(PlanType planType) + public async Task GetPlanOrThrow(PlanType planType) { var plan = await GetPlan(planType); - if (plan == null) - { - throw new NotFoundException(); - } - - return plan; + return plan ?? throw new NotFoundException($"Could not find plan for type {planType}"); } - public async Task> ListPlans() + public async Task> ListPlans() { if (globalSettings.SelfHosted) { @@ -88,23 +84,58 @@ public class PricingClient( return StaticStore.Plans.ToList(); } - var response = await httpClient.GetAsync("plans"); + var response = await httpClient.GetAsync("plans/organization"); if (response.IsSuccessStatusCode) { - var plans = await response.Content.ReadFromJsonAsync>(); - if (plans == null) - { - throw new BillingException(message: "Deserialization of Pricing Service response resulted in null"); - } - return plans.Select(Plan (plan) => new PlanAdapter(plan)).ToList(); + var plans = await response.Content.ReadFromJsonAsync>(); + return plans == null + ? throw new BillingException(message: "Deserialization of Pricing Service response resulted in null") + : plans.Select(OrganizationPlan (plan) => new PlanAdapter(PreProcessFamiliesPreMigrationPlan(plan))).ToList(); } throw new BillingException( message: $"Request to the Pricing Service failed with status {response.StatusCode}"); } - private static string? GetLookupKey(PlanType planType) + public async Task GetAvailablePremiumPlan() + { + var premiumPlans = await ListPremiumPlans(); + + var availablePlan = premiumPlans.FirstOrDefault(premiumPlan => premiumPlan.Available); + + return availablePlan ?? throw new NotFoundException("Could not find available premium plan"); + } + + public async Task> ListPremiumPlans() + { + if (globalSettings.SelfHosted) + { + return []; + } + + var usePricingService = featureService.IsEnabled(FeatureFlagKeys.UsePricingService); + var fetchPremiumPriceFromPricingService = + featureService.IsEnabled(FeatureFlagKeys.PM26793_FetchPremiumPriceFromPricingService); + + if (!usePricingService || !fetchPremiumPriceFromPricingService) + { + return [CurrentPremiumPlan]; + } + + var response = await httpClient.GetAsync("plans/premium"); + + if (response.IsSuccessStatusCode) + { + var plans = await response.Content.ReadFromJsonAsync>(); + return plans ?? throw new BillingException(message: "Deserialization of Pricing Service response resulted in null"); + } + + throw new BillingException( + message: $"Request to the Pricing Service failed with status {response.StatusCode}"); + } + + private string? GetLookupKey(PlanType planType) => planType switch { PlanType.EnterpriseAnnually => "enterprise-annually", @@ -116,6 +147,10 @@ public class PricingClient( PlanType.EnterpriseMonthly2020 => "enterprise-monthly-2020", PlanType.EnterpriseMonthly2023 => "enterprise-monthly-2023", PlanType.FamiliesAnnually => "families", + PlanType.FamiliesAnnually2025 => + featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3) + ? "families-2025" + : "families", PlanType.FamiliesAnnually2019 => "families-2019", PlanType.Free => "free", PlanType.TeamsAnnually => "teams-annually", @@ -130,4 +165,27 @@ public class PricingClient( PlanType.TeamsStarter2023 => "teams-starter-2023", _ => null }; + + /// + /// Safeguard used until the feature flag is enabled. Pricing service will return the + /// 2025PreMigration plan with "families" lookup key. When that is detected and the FF + /// is still disabled, set the lookup key to families-2025 so PlanAdapter will assign + /// the correct plan. + /// + /// The plan to preprocess + private Plan PreProcessFamiliesPreMigrationPlan(Plan plan) + { + if (plan.LookupKey == "families" && !featureService.IsEnabled(FeatureFlagKeys.PM26462_Milestone_3)) + plan.LookupKey = "families-2025"; + return plan; + } + + private static PremiumPlan CurrentPremiumPlan => new() + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new Purchasable { Price = 10M, StripePriceId = StripeConstants.Prices.PremiumAnnually }, + Storage = new Purchasable { Price = 4M, StripePriceId = StripeConstants.Prices.StoragePlanPersonal } + }; } diff --git a/src/Core/Billing/Providers/Migration/Models/ClientMigrationTracker.cs b/src/Core/Billing/Providers/Migration/Models/ClientMigrationTracker.cs deleted file mode 100644 index 65fd7726f8..0000000000 --- a/src/Core/Billing/Providers/Migration/Models/ClientMigrationTracker.cs +++ /dev/null @@ -1,26 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -namespace Bit.Core.Billing.Providers.Migration.Models; - -public enum ClientMigrationProgress -{ - Started = 1, - MigrationRecordCreated = 2, - SubscriptionEnded = 3, - Completed = 4, - - Reversing = 5, - ResetOrganization = 6, - RecreatedSubscription = 7, - RemovedMigrationRecord = 8, - Reversed = 9 -} - -public class ClientMigrationTracker -{ - public Guid ProviderId { get; set; } - public Guid OrganizationId { get; set; } - public string OrganizationName { get; set; } - public ClientMigrationProgress Progress { get; set; } = ClientMigrationProgress.Started; -} diff --git a/src/Core/Billing/Providers/Migration/Models/ProviderMigrationResult.cs b/src/Core/Billing/Providers/Migration/Models/ProviderMigrationResult.cs deleted file mode 100644 index 78a2631999..0000000000 --- a/src/Core/Billing/Providers/Migration/Models/ProviderMigrationResult.cs +++ /dev/null @@ -1,48 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.Billing.Providers.Entities; - -namespace Bit.Core.Billing.Providers.Migration.Models; - -public class ProviderMigrationResult -{ - public Guid ProviderId { get; set; } - public string ProviderName { get; set; } - public string Result { get; set; } - public List Clients { get; set; } -} - -public class ClientMigrationResult -{ - public Guid OrganizationId { get; set; } - public string OrganizationName { get; set; } - public string Result { get; set; } - public ClientPreviousState PreviousState { get; set; } -} - -public class ClientPreviousState -{ - public ClientPreviousState() { } - - public ClientPreviousState(ClientOrganizationMigrationRecord migrationRecord) - { - PlanType = migrationRecord.PlanType.ToString(); - Seats = migrationRecord.Seats; - MaxStorageGb = migrationRecord.MaxStorageGb; - GatewayCustomerId = migrationRecord.GatewayCustomerId; - GatewaySubscriptionId = migrationRecord.GatewaySubscriptionId; - ExpirationDate = migrationRecord.ExpirationDate; - MaxAutoscaleSeats = migrationRecord.MaxAutoscaleSeats; - Status = migrationRecord.Status.ToString(); - } - - public string PlanType { get; set; } - public int Seats { get; set; } - public short? MaxStorageGb { get; set; } - public string GatewayCustomerId { get; set; } = null!; - public string GatewaySubscriptionId { get; set; } = null!; - public DateTime? ExpirationDate { get; set; } - public int? MaxAutoscaleSeats { get; set; } - public string Status { get; set; } -} diff --git a/src/Core/Billing/Providers/Migration/Models/ProviderMigrationTracker.cs b/src/Core/Billing/Providers/Migration/Models/ProviderMigrationTracker.cs deleted file mode 100644 index ba39feab2d..0000000000 --- a/src/Core/Billing/Providers/Migration/Models/ProviderMigrationTracker.cs +++ /dev/null @@ -1,25 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -namespace Bit.Core.Billing.Providers.Migration.Models; - -public enum ProviderMigrationProgress -{ - Started = 1, - NoClients = 2, - ClientsMigrated = 3, - TeamsPlanConfigured = 4, - EnterprisePlanConfigured = 5, - CustomerSetup = 6, - SubscriptionSetup = 7, - CreditApplied = 8, - Completed = 9, -} - -public class ProviderMigrationTracker -{ - public Guid ProviderId { get; set; } - public string ProviderName { get; set; } - public List OrganizationIds { get; set; } - public ProviderMigrationProgress Progress { get; set; } = ProviderMigrationProgress.Started; -} diff --git a/src/Core/Billing/Providers/Migration/ServiceCollectionExtensions.cs b/src/Core/Billing/Providers/Migration/ServiceCollectionExtensions.cs deleted file mode 100644 index 1061c82888..0000000000 --- a/src/Core/Billing/Providers/Migration/ServiceCollectionExtensions.cs +++ /dev/null @@ -1,15 +0,0 @@ -using Bit.Core.Billing.Providers.Migration.Services; -using Bit.Core.Billing.Providers.Migration.Services.Implementations; -using Microsoft.Extensions.DependencyInjection; - -namespace Bit.Core.Billing.Providers.Migration; - -public static class ServiceCollectionExtensions -{ - public static void AddProviderMigration(this IServiceCollection services) - { - services.AddTransient(); - services.AddTransient(); - services.AddTransient(); - } -} diff --git a/src/Core/Billing/Providers/Migration/Services/IMigrationTrackerCache.cs b/src/Core/Billing/Providers/Migration/Services/IMigrationTrackerCache.cs deleted file mode 100644 index 70649590df..0000000000 --- a/src/Core/Billing/Providers/Migration/Services/IMigrationTrackerCache.cs +++ /dev/null @@ -1,17 +0,0 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.Billing.Providers.Migration.Models; - -namespace Bit.Core.Billing.Providers.Migration.Services; - -public interface IMigrationTrackerCache -{ - Task StartTracker(Provider provider); - Task SetOrganizationIds(Guid providerId, IEnumerable organizationIds); - Task GetTracker(Guid providerId); - Task UpdateTrackingStatus(Guid providerId, ProviderMigrationProgress status); - - Task StartTracker(Guid providerId, Organization organization); - Task GetTracker(Guid providerId, Guid organizationId); - Task UpdateTrackingStatus(Guid providerId, Guid organizationId, ClientMigrationProgress status); -} diff --git a/src/Core/Billing/Providers/Migration/Services/IOrganizationMigrator.cs b/src/Core/Billing/Providers/Migration/Services/IOrganizationMigrator.cs deleted file mode 100644 index a0548277b4..0000000000 --- a/src/Core/Billing/Providers/Migration/Services/IOrganizationMigrator.cs +++ /dev/null @@ -1,8 +0,0 @@ -using Bit.Core.AdminConsole.Entities; - -namespace Bit.Core.Billing.Providers.Migration.Services; - -public interface IOrganizationMigrator -{ - Task Migrate(Guid providerId, Organization organization); -} diff --git a/src/Core/Billing/Providers/Migration/Services/IProviderMigrator.cs b/src/Core/Billing/Providers/Migration/Services/IProviderMigrator.cs deleted file mode 100644 index 328c2419f4..0000000000 --- a/src/Core/Billing/Providers/Migration/Services/IProviderMigrator.cs +++ /dev/null @@ -1,10 +0,0 @@ -using Bit.Core.Billing.Providers.Migration.Models; - -namespace Bit.Core.Billing.Providers.Migration.Services; - -public interface IProviderMigrator -{ - Task Migrate(Guid providerId); - - Task GetResult(Guid providerId); -} diff --git a/src/Core/Billing/Providers/Migration/Services/Implementations/MigrationTrackerDistributedCache.cs b/src/Core/Billing/Providers/Migration/Services/Implementations/MigrationTrackerDistributedCache.cs deleted file mode 100644 index 1f38b0d111..0000000000 --- a/src/Core/Billing/Providers/Migration/Services/Implementations/MigrationTrackerDistributedCache.cs +++ /dev/null @@ -1,110 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Text.Json; -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.Billing.Providers.Migration.Models; -using Microsoft.Extensions.Caching.Distributed; -using Microsoft.Extensions.DependencyInjection; - -namespace Bit.Core.Billing.Providers.Migration.Services.Implementations; - -public class MigrationTrackerDistributedCache( - [FromKeyedServices("persistent")] - IDistributedCache distributedCache) : IMigrationTrackerCache -{ - public async Task StartTracker(Provider provider) => - await SetAsync(new ProviderMigrationTracker - { - ProviderId = provider.Id, - ProviderName = provider.Name - }); - - public async Task SetOrganizationIds(Guid providerId, IEnumerable organizationIds) - { - var tracker = await GetAsync(providerId); - - tracker.OrganizationIds = organizationIds.ToList(); - - await SetAsync(tracker); - } - - public Task GetTracker(Guid providerId) => GetAsync(providerId); - - public async Task UpdateTrackingStatus(Guid providerId, ProviderMigrationProgress status) - { - var tracker = await GetAsync(providerId); - - tracker.Progress = status; - - await SetAsync(tracker); - } - - public async Task StartTracker(Guid providerId, Organization organization) => - await SetAsync(new ClientMigrationTracker - { - ProviderId = providerId, - OrganizationId = organization.Id, - OrganizationName = organization.Name - }); - - public Task GetTracker(Guid providerId, Guid organizationId) => - GetAsync(providerId, organizationId); - - public async Task UpdateTrackingStatus(Guid providerId, Guid organizationId, ClientMigrationProgress status) - { - var tracker = await GetAsync(providerId, organizationId); - - tracker.Progress = status; - - await SetAsync(tracker); - } - - private static string GetProviderCacheKey(Guid providerId) => $"provider_{providerId}_migration"; - - private static string GetClientCacheKey(Guid providerId, Guid clientId) => - $"provider_{providerId}_client_{clientId}_migration"; - - private async Task GetAsync(Guid providerId) - { - var cacheKey = GetProviderCacheKey(providerId); - - var json = await distributedCache.GetStringAsync(cacheKey); - - return string.IsNullOrEmpty(json) ? null : JsonSerializer.Deserialize(json); - } - - private async Task GetAsync(Guid providerId, Guid organizationId) - { - var cacheKey = GetClientCacheKey(providerId, organizationId); - - var json = await distributedCache.GetStringAsync(cacheKey); - - return string.IsNullOrEmpty(json) ? null : JsonSerializer.Deserialize(json); - } - - private async Task SetAsync(ProviderMigrationTracker tracker) - { - var cacheKey = GetProviderCacheKey(tracker.ProviderId); - - var json = JsonSerializer.Serialize(tracker); - - await distributedCache.SetStringAsync(cacheKey, json, new DistributedCacheEntryOptions - { - SlidingExpiration = TimeSpan.FromMinutes(30) - }); - } - - private async Task SetAsync(ClientMigrationTracker tracker) - { - var cacheKey = GetClientCacheKey(tracker.ProviderId, tracker.OrganizationId); - - var json = JsonSerializer.Serialize(tracker); - - await distributedCache.SetStringAsync(cacheKey, json, new DistributedCacheEntryOptions - { - SlidingExpiration = TimeSpan.FromMinutes(30) - }); - } -} diff --git a/src/Core/Billing/Providers/Migration/Services/Implementations/OrganizationMigrator.cs b/src/Core/Billing/Providers/Migration/Services/Implementations/OrganizationMigrator.cs deleted file mode 100644 index 3de49838af..0000000000 --- a/src/Core/Billing/Providers/Migration/Services/Implementations/OrganizationMigrator.cs +++ /dev/null @@ -1,331 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.AdminConsole.Entities; -using Bit.Core.Billing.Constants; -using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Pricing; -using Bit.Core.Billing.Providers.Entities; -using Bit.Core.Billing.Providers.Migration.Models; -using Bit.Core.Billing.Providers.Repositories; -using Bit.Core.Enums; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Microsoft.Extensions.Logging; -using Stripe; -using Plan = Bit.Core.Models.StaticStore.Plan; - -namespace Bit.Core.Billing.Providers.Migration.Services.Implementations; - -public class OrganizationMigrator( - IClientOrganizationMigrationRecordRepository clientOrganizationMigrationRecordRepository, - ILogger logger, - IMigrationTrackerCache migrationTrackerCache, - IOrganizationRepository organizationRepository, - IPricingClient pricingClient, - IStripeAdapter stripeAdapter) : IOrganizationMigrator -{ - private const string _cancellationComment = "Cancelled as part of provider migration to Consolidated Billing"; - - public async Task Migrate(Guid providerId, Organization organization) - { - logger.LogInformation("CB: Starting migration for organization ({OrganizationID})", organization.Id); - - await migrationTrackerCache.StartTracker(providerId, organization); - - await CreateMigrationRecordAsync(providerId, organization); - - await CancelSubscriptionAsync(providerId, organization); - - await UpdateOrganizationAsync(providerId, organization); - } - - #region Steps - - private async Task CreateMigrationRecordAsync(Guid providerId, Organization organization) - { - logger.LogInformation("CB: Creating ClientOrganizationMigrationRecord for organization ({OrganizationID})", organization.Id); - - var migrationRecord = await clientOrganizationMigrationRecordRepository.GetByOrganizationId(organization.Id); - - if (migrationRecord != null) - { - logger.LogInformation( - "CB: ClientOrganizationMigrationRecord already exists for organization ({OrganizationID}), deleting record", - organization.Id); - - await clientOrganizationMigrationRecordRepository.DeleteAsync(migrationRecord); - } - - await clientOrganizationMigrationRecordRepository.CreateAsync(new ClientOrganizationMigrationRecord - { - OrganizationId = organization.Id, - ProviderId = providerId, - PlanType = organization.PlanType, - Seats = organization.Seats ?? 0, - MaxStorageGb = organization.MaxStorageGb, - GatewayCustomerId = organization.GatewayCustomerId!, - GatewaySubscriptionId = organization.GatewaySubscriptionId!, - ExpirationDate = organization.ExpirationDate, - MaxAutoscaleSeats = organization.MaxAutoscaleSeats, - Status = organization.Status - }); - - logger.LogInformation("CB: Created migration record for organization ({OrganizationID})", organization.Id); - - await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id, - ClientMigrationProgress.MigrationRecordCreated); - } - - private async Task CancelSubscriptionAsync(Guid providerId, Organization organization) - { - logger.LogInformation("CB: Cancelling subscription for organization ({OrganizationID})", organization.Id); - - var subscription = await stripeAdapter.SubscriptionGetAsync(organization.GatewaySubscriptionId); - - if (subscription is - { - Status: - StripeConstants.SubscriptionStatus.Active or - StripeConstants.SubscriptionStatus.PastDue or - StripeConstants.SubscriptionStatus.Trialing - }) - { - await stripeAdapter.SubscriptionUpdateAsync(organization.GatewaySubscriptionId, - new SubscriptionUpdateOptions { CancelAtPeriodEnd = false }); - - subscription = await stripeAdapter.SubscriptionCancelAsync(organization.GatewaySubscriptionId, - new SubscriptionCancelOptions - { - CancellationDetails = new SubscriptionCancellationDetailsOptions - { - Comment = _cancellationComment - }, - InvoiceNow = true, - Prorate = true, - Expand = ["latest_invoice", "test_clock"] - }); - - logger.LogInformation("CB: Cancelled subscription for organization ({OrganizationID})", organization.Id); - - var now = subscription.TestClock?.FrozenTime ?? DateTime.UtcNow; - - var trialing = subscription.TrialEnd.HasValue && subscription.TrialEnd.Value > now; - - if (!trialing && subscription is { Status: StripeConstants.SubscriptionStatus.Canceled, CancellationDetails.Comment: _cancellationComment }) - { - var latestInvoice = subscription.LatestInvoice; - - if (latestInvoice.Status == "draft") - { - await stripeAdapter.InvoiceFinalizeInvoiceAsync(latestInvoice.Id, - new InvoiceFinalizeOptions { AutoAdvance = true }); - - logger.LogInformation("CB: Finalized prorated invoice for organization ({OrganizationID})", organization.Id); - } - } - } - else - { - logger.LogInformation( - "CB: Did not need to cancel subscription for organization ({OrganizationID}) as it was inactive", - organization.Id); - } - - await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id, - ClientMigrationProgress.SubscriptionEnded); - } - - private async Task UpdateOrganizationAsync(Guid providerId, Organization organization) - { - logger.LogInformation("CB: Bringing organization ({OrganizationID}) under provider management", - organization.Id); - - var plan = await pricingClient.GetPlanOrThrow(organization.Plan.Contains("Teams") ? PlanType.TeamsMonthly : PlanType.EnterpriseMonthly); - - ResetOrganizationPlan(organization, plan); - organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb; - organization.GatewaySubscriptionId = null; - organization.ExpirationDate = null; - organization.MaxAutoscaleSeats = null; - organization.Status = OrganizationStatusType.Managed; - - await organizationRepository.ReplaceAsync(organization); - - logger.LogInformation("CB: Brought organization ({OrganizationID}) under provider management", - organization.Id); - - await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id, - ClientMigrationProgress.Completed); - } - - #endregion - - #region Reverse - - private async Task RemoveMigrationRecordAsync(Guid providerId, Organization organization) - { - logger.LogInformation("CB: Removing migration record for organization ({OrganizationID})", organization.Id); - - var migrationRecord = await clientOrganizationMigrationRecordRepository.GetByOrganizationId(organization.Id); - - if (migrationRecord != null) - { - await clientOrganizationMigrationRecordRepository.DeleteAsync(migrationRecord); - - logger.LogInformation( - "CB: Removed migration record for organization ({OrganizationID})", - organization.Id); - } - else - { - logger.LogInformation("CB: Did not remove migration record for organization ({OrganizationID}) as it does not exist", organization.Id); - } - - await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id, ClientMigrationProgress.Reversed); - } - - private async Task RecreateSubscriptionAsync(Guid providerId, Organization organization) - { - logger.LogInformation("CB: Recreating subscription for organization ({OrganizationID})", organization.Id); - - if (!string.IsNullOrEmpty(organization.GatewaySubscriptionId)) - { - if (string.IsNullOrEmpty(organization.GatewayCustomerId)) - { - logger.LogError( - "CB: Cannot recreate subscription for organization ({OrganizationID}) as it does not have a Stripe customer", - organization.Id); - - throw new Exception(); - } - - var customer = await stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId, - new CustomerGetOptions { Expand = ["default_source", "invoice_settings.default_payment_method"] }); - - var collectionMethod = - customer.DefaultSource != null || - customer.InvoiceSettings?.DefaultPaymentMethod != null || - customer.Metadata.ContainsKey(Utilities.BraintreeCustomerIdKey) - ? StripeConstants.CollectionMethod.ChargeAutomatically - : StripeConstants.CollectionMethod.SendInvoice; - - var plan = await pricingClient.GetPlanOrThrow(organization.PlanType); - - var items = new List - { - new () - { - Price = plan.PasswordManager.StripeSeatPlanId, - Quantity = organization.Seats - } - }; - - if (organization.MaxStorageGb.HasValue && plan.PasswordManager.BaseStorageGb.HasValue && organization.MaxStorageGb.Value > plan.PasswordManager.BaseStorageGb.Value) - { - var additionalStorage = organization.MaxStorageGb.Value - plan.PasswordManager.BaseStorageGb.Value; - - items.Add(new SubscriptionItemOptions - { - Price = plan.PasswordManager.StripeStoragePlanId, - Quantity = additionalStorage - }); - } - - var subscriptionCreateOptions = new SubscriptionCreateOptions - { - AutomaticTax = new SubscriptionAutomaticTaxOptions - { - Enabled = true - }, - Customer = customer.Id, - CollectionMethod = collectionMethod, - DaysUntilDue = collectionMethod == StripeConstants.CollectionMethod.SendInvoice ? 30 : null, - Items = items, - Metadata = new Dictionary - { - [organization.GatewayIdField()] = organization.Id.ToString() - }, - OffSession = true, - ProrationBehavior = StripeConstants.ProrationBehavior.CreateProrations, - TrialPeriodDays = plan.TrialPeriodDays - }; - - var subscription = await stripeAdapter.SubscriptionCreateAsync(subscriptionCreateOptions); - - organization.GatewaySubscriptionId = subscription.Id; - - await organizationRepository.ReplaceAsync(organization); - - logger.LogInformation("CB: Recreated subscription for organization ({OrganizationID})", organization.Id); - } - else - { - logger.LogInformation( - "CB: Did not recreate subscription for organization ({OrganizationID}) as it already exists", - organization.Id); - } - - await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id, - ClientMigrationProgress.RecreatedSubscription); - } - - private async Task ReverseOrganizationUpdateAsync(Guid providerId, Organization organization) - { - var migrationRecord = await clientOrganizationMigrationRecordRepository.GetByOrganizationId(organization.Id); - - if (migrationRecord == null) - { - logger.LogError( - "CB: Cannot reverse migration for organization ({OrganizationID}) as it does not have a migration record", - organization.Id); - - throw new Exception(); - } - - var plan = await pricingClient.GetPlanOrThrow(migrationRecord.PlanType); - - ResetOrganizationPlan(organization, plan); - organization.MaxStorageGb = migrationRecord.MaxStorageGb; - organization.ExpirationDate = migrationRecord.ExpirationDate; - organization.MaxAutoscaleSeats = migrationRecord.MaxAutoscaleSeats; - organization.Status = migrationRecord.Status; - - await organizationRepository.ReplaceAsync(organization); - - logger.LogInformation("CB: Reversed organization ({OrganizationID}) updates", - organization.Id); - - await migrationTrackerCache.UpdateTrackingStatus(providerId, organization.Id, - ClientMigrationProgress.ResetOrganization); - } - - #endregion - - #region Shared - - private static void ResetOrganizationPlan(Organization organization, Plan plan) - { - organization.Plan = plan.Name; - organization.PlanType = plan.Type; - organization.MaxCollections = plan.PasswordManager.MaxCollections; - organization.MaxStorageGb = plan.PasswordManager.BaseStorageGb; - organization.UsePolicies = plan.HasPolicies; - organization.UseSso = plan.HasSso; - organization.UseOrganizationDomains = plan.HasOrganizationDomains; - organization.UseGroups = plan.HasGroups; - organization.UseEvents = plan.HasEvents; - organization.UseDirectory = plan.HasDirectory; - organization.UseTotp = plan.HasTotp; - organization.Use2fa = plan.Has2fa; - organization.UseApi = plan.HasApi; - organization.UseResetPassword = plan.HasResetPassword; - organization.SelfHost = plan.HasSelfHost; - organization.UsersGetPremium = plan.UsersGetPremium; - organization.UseCustomPermissions = plan.HasCustomPermissions; - organization.UseScim = plan.HasScim; - organization.UseKeyConnector = plan.HasKeyConnector; - } - - #endregion -} diff --git a/src/Core/Billing/Providers/Migration/Services/Implementations/ProviderMigrator.cs b/src/Core/Billing/Providers/Migration/Services/Implementations/ProviderMigrator.cs deleted file mode 100644 index 07a057d40c..0000000000 --- a/src/Core/Billing/Providers/Migration/Services/Implementations/ProviderMigrator.cs +++ /dev/null @@ -1,436 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Entities.Provider; -using Bit.Core.AdminConsole.Enums.Provider; -using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Billing.Constants; -using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Models; -using Bit.Core.Billing.Providers.Entities; -using Bit.Core.Billing.Providers.Migration.Models; -using Bit.Core.Billing.Providers.Models; -using Bit.Core.Billing.Providers.Repositories; -using Bit.Core.Billing.Providers.Services; -using Bit.Core.Enums; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Microsoft.Extensions.Logging; -using Stripe; - -namespace Bit.Core.Billing.Providers.Migration.Services.Implementations; - -public class ProviderMigrator( - IClientOrganizationMigrationRecordRepository clientOrganizationMigrationRecordRepository, - IOrganizationMigrator organizationMigrator, - ILogger logger, - IMigrationTrackerCache migrationTrackerCache, - IOrganizationRepository organizationRepository, - IPaymentService paymentService, - IProviderBillingService providerBillingService, - IProviderOrganizationRepository providerOrganizationRepository, - IProviderRepository providerRepository, - IProviderPlanRepository providerPlanRepository, - IStripeAdapter stripeAdapter) : IProviderMigrator -{ - public async Task Migrate(Guid providerId) - { - var provider = await GetProviderAsync(providerId); - - if (provider == null) - { - return; - } - - logger.LogInformation("CB: Starting migration for provider ({ProviderID})", providerId); - - await migrationTrackerCache.StartTracker(provider); - - var organizations = await GetClientsAsync(provider.Id); - - if (organizations.Count == 0) - { - logger.LogInformation("CB: Skipping migration for provider ({ProviderID}) with no clients", providerId); - - await migrationTrackerCache.UpdateTrackingStatus(providerId, ProviderMigrationProgress.NoClients); - - return; - } - - await MigrateClientsAsync(providerId, organizations); - - await ConfigureTeamsPlanAsync(providerId); - - await ConfigureEnterprisePlanAsync(providerId); - - await SetupCustomerAsync(provider); - - await SetupSubscriptionAsync(provider); - - await ApplyCreditAsync(provider); - - await UpdateProviderAsync(provider); - } - - public async Task GetResult(Guid providerId) - { - var providerTracker = await migrationTrackerCache.GetTracker(providerId); - - if (providerTracker == null) - { - return null; - } - - if (providerTracker.Progress == ProviderMigrationProgress.NoClients) - { - return new ProviderMigrationResult - { - ProviderId = providerTracker.ProviderId, - ProviderName = providerTracker.ProviderName, - Result = providerTracker.Progress.ToString() - }; - } - - var clientTrackers = await Task.WhenAll(providerTracker.OrganizationIds.Select(organizationId => - migrationTrackerCache.GetTracker(providerId, organizationId))); - - var migrationRecordLookup = new Dictionary(); - - foreach (var clientTracker in clientTrackers) - { - var migrationRecord = - await clientOrganizationMigrationRecordRepository.GetByOrganizationId(clientTracker.OrganizationId); - - migrationRecordLookup.Add(clientTracker.OrganizationId, migrationRecord); - } - - return new ProviderMigrationResult - { - ProviderId = providerTracker.ProviderId, - ProviderName = providerTracker.ProviderName, - Result = providerTracker.Progress.ToString(), - Clients = clientTrackers.Select(tracker => - { - var foundMigrationRecord = migrationRecordLookup.TryGetValue(tracker.OrganizationId, out var migrationRecord); - return new ClientMigrationResult - { - OrganizationId = tracker.OrganizationId, - OrganizationName = tracker.OrganizationName, - Result = tracker.Progress.ToString(), - PreviousState = foundMigrationRecord ? new ClientPreviousState(migrationRecord) : null - }; - }).ToList(), - }; - } - - #region Steps - - private async Task MigrateClientsAsync(Guid providerId, List organizations) - { - logger.LogInformation("CB: Migrating clients for provider ({ProviderID})", providerId); - - var organizationIds = organizations.Select(organization => organization.Id); - - await migrationTrackerCache.SetOrganizationIds(providerId, organizationIds); - - foreach (var organization in organizations) - { - var tracker = await migrationTrackerCache.GetTracker(providerId, organization.Id); - - if (tracker is not { Progress: ClientMigrationProgress.Completed }) - { - await organizationMigrator.Migrate(providerId, organization); - } - } - - logger.LogInformation("CB: Migrated clients for provider ({ProviderID})", providerId); - - await migrationTrackerCache.UpdateTrackingStatus(providerId, - ProviderMigrationProgress.ClientsMigrated); - } - - private async Task ConfigureTeamsPlanAsync(Guid providerId) - { - logger.LogInformation("CB: Configuring Teams plan for provider ({ProviderID})", providerId); - - var organizations = await GetClientsAsync(providerId); - - var teamsSeats = organizations - .Where(IsTeams) - .Sum(client => client.Seats) ?? 0; - - var teamsProviderPlan = (await providerPlanRepository.GetByProviderId(providerId)) - .FirstOrDefault(providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly); - - if (teamsProviderPlan == null) - { - await providerPlanRepository.CreateAsync(new ProviderPlan - { - ProviderId = providerId, - PlanType = PlanType.TeamsMonthly, - SeatMinimum = teamsSeats, - PurchasedSeats = 0, - AllocatedSeats = teamsSeats - }); - - logger.LogInformation("CB: Created Teams plan for provider ({ProviderID}) with a seat minimum of {Seats}", - providerId, teamsSeats); - } - else - { - logger.LogInformation("CB: Teams plan already exists for provider ({ProviderID}), updating seat minimum", providerId); - - teamsProviderPlan.SeatMinimum = teamsSeats; - teamsProviderPlan.AllocatedSeats = teamsSeats; - - await providerPlanRepository.ReplaceAsync(teamsProviderPlan); - - logger.LogInformation("CB: Updated Teams plan for provider ({ProviderID}) to seat minimum of {Seats}", - providerId, teamsProviderPlan.SeatMinimum); - } - - await migrationTrackerCache.UpdateTrackingStatus(providerId, ProviderMigrationProgress.TeamsPlanConfigured); - } - - private async Task ConfigureEnterprisePlanAsync(Guid providerId) - { - logger.LogInformation("CB: Configuring Enterprise plan for provider ({ProviderID})", providerId); - - var organizations = await GetClientsAsync(providerId); - - var enterpriseSeats = organizations - .Where(IsEnterprise) - .Sum(client => client.Seats) ?? 0; - - var enterpriseProviderPlan = (await providerPlanRepository.GetByProviderId(providerId)) - .FirstOrDefault(providerPlan => providerPlan.PlanType == PlanType.EnterpriseMonthly); - - if (enterpriseProviderPlan == null) - { - await providerPlanRepository.CreateAsync(new ProviderPlan - { - ProviderId = providerId, - PlanType = PlanType.EnterpriseMonthly, - SeatMinimum = enterpriseSeats, - PurchasedSeats = 0, - AllocatedSeats = enterpriseSeats - }); - - logger.LogInformation("CB: Created Enterprise plan for provider ({ProviderID}) with a seat minimum of {Seats}", - providerId, enterpriseSeats); - } - else - { - logger.LogInformation("CB: Enterprise plan already exists for provider ({ProviderID}), updating seat minimum", providerId); - - enterpriseProviderPlan.SeatMinimum = enterpriseSeats; - enterpriseProviderPlan.AllocatedSeats = enterpriseSeats; - - await providerPlanRepository.ReplaceAsync(enterpriseProviderPlan); - - logger.LogInformation("CB: Updated Enterprise plan for provider ({ProviderID}) to seat minimum of {Seats}", - providerId, enterpriseProviderPlan.SeatMinimum); - } - - await migrationTrackerCache.UpdateTrackingStatus(providerId, ProviderMigrationProgress.EnterprisePlanConfigured); - } - - private async Task SetupCustomerAsync(Provider provider) - { - if (string.IsNullOrEmpty(provider.GatewayCustomerId)) - { - var organizations = await GetClientsAsync(provider.Id); - - var sampleOrganization = organizations.FirstOrDefault(organization => !string.IsNullOrEmpty(organization.GatewayCustomerId)); - - if (sampleOrganization == null) - { - logger.LogInformation( - "CB: Could not find sample organization for provider ({ProviderID}) that has a Stripe customer", - provider.Id); - - return; - } - - var taxInfo = await paymentService.GetTaxInfoAsync(sampleOrganization); - - // Create dummy payment source for legacy migration - this migrator is deprecated and will be removed - var dummyPaymentSource = new TokenizedPaymentSource(PaymentMethodType.Card, "migration_dummy_token"); - - var customer = await providerBillingService.SetupCustomer(provider, taxInfo, dummyPaymentSource); - - await stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions - { - Coupon = StripeConstants.CouponIDs.LegacyMSPDiscount - }); - - provider.GatewayCustomerId = customer.Id; - - await providerRepository.ReplaceAsync(provider); - - logger.LogInformation("CB: Setup Stripe customer for provider ({ProviderID})", provider.Id); - } - else - { - logger.LogInformation("CB: Stripe customer already exists for provider ({ProviderID})", provider.Id); - } - - await migrationTrackerCache.UpdateTrackingStatus(provider.Id, ProviderMigrationProgress.CustomerSetup); - } - - private async Task SetupSubscriptionAsync(Provider provider) - { - if (string.IsNullOrEmpty(provider.GatewaySubscriptionId)) - { - if (!string.IsNullOrEmpty(provider.GatewayCustomerId)) - { - var subscription = await providerBillingService.SetupSubscription(provider); - - provider.GatewaySubscriptionId = subscription.Id; - - await providerRepository.ReplaceAsync(provider); - - logger.LogInformation("CB: Setup Stripe subscription for provider ({ProviderID})", provider.Id); - } - else - { - logger.LogInformation( - "CB: Could not set up Stripe subscription for provider ({ProviderID}) with no Stripe customer", - provider.Id); - - return; - } - } - else - { - logger.LogInformation("CB: Stripe subscription already exists for provider ({ProviderID})", provider.Id); - - var providerPlans = await providerPlanRepository.GetByProviderId(provider.Id); - - var enterpriseSeatMinimum = providerPlans - .FirstOrDefault(providerPlan => providerPlan.PlanType == PlanType.EnterpriseMonthly)? - .SeatMinimum ?? 0; - - var teamsSeatMinimum = providerPlans - .FirstOrDefault(providerPlan => providerPlan.PlanType == PlanType.TeamsMonthly)? - .SeatMinimum ?? 0; - - var updateSeatMinimumsCommand = new UpdateProviderSeatMinimumsCommand( - provider, - [ - (Plan: PlanType.EnterpriseMonthly, SeatsMinimum: enterpriseSeatMinimum), - (Plan: PlanType.TeamsMonthly, SeatsMinimum: teamsSeatMinimum) - ]); - await providerBillingService.UpdateSeatMinimums(updateSeatMinimumsCommand); - - logger.LogInformation( - "CB: Updated Stripe subscription for provider ({ProviderID}) with current seat minimums", provider.Id); - } - - await migrationTrackerCache.UpdateTrackingStatus(provider.Id, ProviderMigrationProgress.SubscriptionSetup); - } - - private async Task ApplyCreditAsync(Provider provider) - { - var organizations = await GetClientsAsync(provider.Id); - - var organizationCustomers = - await Task.WhenAll(organizations.Select(organization => stripeAdapter.CustomerGetAsync(organization.GatewayCustomerId))); - - var organizationCancellationCredit = organizationCustomers.Sum(customer => customer.Balance); - - if (organizationCancellationCredit != 0) - { - await stripeAdapter.CustomerBalanceTransactionCreate(provider.GatewayCustomerId, - new CustomerBalanceTransactionCreateOptions - { - Amount = organizationCancellationCredit, - Currency = "USD", - Description = "Unused, prorated time for client organization subscriptions." - }); - } - - var migrationRecords = await Task.WhenAll(organizations.Select(organization => - clientOrganizationMigrationRecordRepository.GetByOrganizationId(organization.Id))); - - var legacyOrganizationMigrationRecords = migrationRecords.Where(migrationRecord => - migrationRecord.PlanType is - PlanType.EnterpriseAnnually2020 or - PlanType.TeamsAnnually2020); - - var legacyOrganizationCredit = legacyOrganizationMigrationRecords.Sum(migrationRecord => migrationRecord.Seats) * 12 * -100; - - if (legacyOrganizationCredit < 0) - { - await stripeAdapter.CustomerBalanceTransactionCreate(provider.GatewayCustomerId, - new CustomerBalanceTransactionCreateOptions - { - Amount = legacyOrganizationCredit, - Currency = "USD", - Description = "1 year rebate for legacy client organizations." - }); - } - - logger.LogInformation("CB: Applied {Credit} credit to provider ({ProviderID})", organizationCancellationCredit + legacyOrganizationCredit, provider.Id); - - await migrationTrackerCache.UpdateTrackingStatus(provider.Id, ProviderMigrationProgress.CreditApplied); - } - - private async Task UpdateProviderAsync(Provider provider) - { - provider.Status = ProviderStatusType.Billable; - - await providerRepository.ReplaceAsync(provider); - - logger.LogInformation("CB: Completed migration for provider ({ProviderID})", provider.Id); - - await migrationTrackerCache.UpdateTrackingStatus(provider.Id, ProviderMigrationProgress.Completed); - } - - #endregion - - #region Utilities - - private async Task> GetClientsAsync(Guid providerId) - { - var providerOrganizations = await providerOrganizationRepository.GetManyDetailsByProviderAsync(providerId); - - return (await Task.WhenAll(providerOrganizations.Select(providerOrganization => - organizationRepository.GetByIdAsync(providerOrganization.OrganizationId)))) - .ToList(); - } - - private async Task GetProviderAsync(Guid providerId) - { - var provider = await providerRepository.GetByIdAsync(providerId); - - if (provider == null) - { - logger.LogWarning("CB: Cannot migrate provider ({ProviderID}) as it does not exist", providerId); - - return null; - } - - if (provider.Type != ProviderType.Msp) - { - logger.LogWarning("CB: Cannot migrate provider ({ProviderID}) as it is not an MSP", providerId); - - return null; - } - - if (provider.Status == ProviderStatusType.Created) - { - return provider; - } - - logger.LogWarning("CB: Cannot migrate provider ({ProviderID}) as it is not in the 'Created' state", providerId); - - return null; - } - - private static bool IsEnterprise(Organization organization) => organization.Plan.Contains("Enterprise"); - private static bool IsTeams(Organization organization) => organization.Plan.Contains("Teams"); - - #endregion -} diff --git a/src/Core/Billing/Providers/Services/IProviderBillingService.cs b/src/Core/Billing/Providers/Services/IProviderBillingService.cs index 173249f79f..57d68db038 100644 --- a/src/Core/Billing/Providers/Services/IProviderBillingService.cs +++ b/src/Core/Billing/Providers/Services/IProviderBillingService.cs @@ -5,10 +5,10 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Models; +using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Providers.Entities; using Bit.Core.Billing.Providers.Models; using Bit.Core.Billing.Tax.Models; -using Bit.Core.Models.Business; using Stripe; namespace Bit.Core.Billing.Providers.Services; @@ -79,16 +79,16 @@ public interface IProviderBillingService int seatAdjustment); /// - /// For use during the provider setup process, this method creates a Stripe for the specified utilizing the provided . + /// For use during the provider setup process, this method creates a Stripe for the specified utilizing the provided and . /// /// The to create a Stripe customer for. - /// The to use for calculating the customer's automatic tax. - /// The (ex. Credit Card) to attach to the customer. + /// The (e.g., Credit Card, Bank Account, or PayPal) to attach to the customer. + /// The containing the customer's billing information including address and tax ID details. /// The newly created for the . Task SetupCustomer( Provider provider, - TaxInfo taxInfo, - TokenizedPaymentSource tokenizedPaymentSource); + TokenizedPaymentMethod paymentMethod, + BillingAddress billingAddress); /// /// For use during the provider setup process, this method starts a Stripe for the given . diff --git a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs index 9db18278b6..3170060de4 100644 --- a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs +++ b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs @@ -3,8 +3,10 @@ using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Sales; +using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Tax.Models; using Bit.Core.Entities; using Bit.Core.Enums; @@ -29,7 +31,8 @@ public class PremiumUserBillingService( ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, ISubscriberService subscriberService, - IUserRepository userRepository) : IPremiumUserBillingService + IUserRepository userRepository, + IPricingClient pricingClient) : IPremiumUserBillingService { public async Task Credit(User user, decimal amount) { @@ -108,7 +111,7 @@ public class PremiumUserBillingService( when subscription.Status == StripeConstants.SubscriptionStatus.Active: { user.Premium = true; - user.PremiumExpirationDate = subscription.CurrentPeriodEnd; + user.PremiumExpirationDate = subscription.GetCurrentPeriodEnd(); break; } } @@ -300,11 +303,13 @@ public class PremiumUserBillingService( Customer customer, int? storage) { + var premiumPlan = await pricingClient.GetAvailablePremiumPlan(); + var subscriptionItemOptionsList = new List { new () { - Price = StripeConstants.Prices.PremiumAnnually, + Price = premiumPlan.Seat.StripePriceId, Quantity = 1 } }; @@ -313,7 +318,7 @@ public class PremiumUserBillingService( { subscriptionItemOptionsList.Add(new SubscriptionItemOptions { - Price = StripeConstants.Prices.StoragePlanPersonal, + Price = premiumPlan.Storage.StripePriceId, Quantity = storage }); } diff --git a/src/Core/Billing/Subscriptions/Commands/RestartSubscriptionCommand.cs b/src/Core/Billing/Subscriptions/Commands/RestartSubscriptionCommand.cs new file mode 100644 index 0000000000..ee60597601 --- /dev/null +++ b/src/Core/Billing/Subscriptions/Commands/RestartSubscriptionCommand.cs @@ -0,0 +1,93 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Commands; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Extensions; +using Bit.Core.Billing.Services; +using Bit.Core.Entities; +using Bit.Core.Repositories; +using Bit.Core.Services; +using OneOf.Types; +using Stripe; + +namespace Bit.Core.Billing.Subscriptions.Commands; + +using static StripeConstants; + +public interface IRestartSubscriptionCommand +{ + Task> Run( + ISubscriber subscriber); +} + +public class RestartSubscriptionCommand( + IOrganizationRepository organizationRepository, + IProviderRepository providerRepository, + IStripeAdapter stripeAdapter, + ISubscriberService subscriberService, + IUserRepository userRepository) : IRestartSubscriptionCommand +{ + public async Task> Run( + ISubscriber subscriber) + { + var existingSubscription = await subscriberService.GetSubscription(subscriber); + + if (existingSubscription is not { Status: SubscriptionStatus.Canceled }) + { + return new BadRequest("Cannot restart a subscription that is not canceled."); + } + + var options = new SubscriptionCreateOptions + { + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }, + CollectionMethod = CollectionMethod.ChargeAutomatically, + Customer = existingSubscription.CustomerId, + Items = existingSubscription.Items.Select(subscriptionItem => new SubscriptionItemOptions + { + Price = subscriptionItem.Price.Id, + Quantity = subscriptionItem.Quantity + }).ToList(), + Metadata = existingSubscription.Metadata, + OffSession = true, + TrialPeriodDays = 0 + }; + + var subscription = await stripeAdapter.SubscriptionCreateAsync(options); + await EnableAsync(subscriber, subscription); + return new None(); + } + + private async Task EnableAsync(ISubscriber subscriber, Subscription subscription) + { + switch (subscriber) + { + case Organization organization: + { + organization.GatewaySubscriptionId = subscription.Id; + organization.Enabled = true; + organization.ExpirationDate = subscription.GetCurrentPeriodEnd(); + organization.RevisionDate = DateTime.UtcNow; + await organizationRepository.ReplaceAsync(organization); + break; + } + case Provider provider: + { + provider.GatewaySubscriptionId = subscription.Id; + provider.Enabled = true; + provider.RevisionDate = DateTime.UtcNow; + await providerRepository.ReplaceAsync(provider); + break; + } + case User user: + { + user.GatewaySubscriptionId = subscription.Id; + user.Premium = true; + user.PremiumExpirationDate = subscription.GetCurrentPeriodEnd(); + user.RevisionDate = DateTime.UtcNow; + await userRepository.ReplaceAsync(user); + break; + } + } + } +} diff --git a/src/Core/Billing/Tax/Commands/PreviewTaxAmountCommand.cs b/src/Core/Billing/Tax/Commands/PreviewTaxAmountCommand.cs deleted file mode 100644 index 94d3724d73..0000000000 --- a/src/Core/Billing/Tax/Commands/PreviewTaxAmountCommand.cs +++ /dev/null @@ -1,136 +0,0 @@ -using Bit.Core.Billing.Commands; -using Bit.Core.Billing.Constants; -using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Extensions; -using Bit.Core.Billing.Pricing; -using Bit.Core.Billing.Tax.Services; -using Bit.Core.Services; -using Microsoft.Extensions.Logging; -using Stripe; - -namespace Bit.Core.Billing.Tax.Commands; - -public interface IPreviewTaxAmountCommand -{ - Task> Run(OrganizationTrialParameters parameters); -} - -public class PreviewTaxAmountCommand( - ILogger logger, - IPricingClient pricingClient, - IStripeAdapter stripeAdapter, - ITaxService taxService) : BaseBillingCommand(logger), IPreviewTaxAmountCommand -{ - protected override Conflict DefaultConflict - => new("We had a problem calculating your tax obligation. Please contact support for assistance."); - - public Task> Run(OrganizationTrialParameters parameters) - => HandleAsync(async () => - { - var (planType, productType, taxInformation) = parameters; - - var plan = await pricingClient.GetPlanOrThrow(planType); - - var options = new InvoiceCreatePreviewOptions - { - Currency = "usd", - CustomerDetails = new InvoiceCustomerDetailsOptions - { - Address = new AddressOptions - { - Country = taxInformation.Country, - PostalCode = taxInformation.PostalCode - } - }, - SubscriptionDetails = new InvoiceSubscriptionDetailsOptions - { - Items = - [ - new InvoiceSubscriptionDetailsItemOptions - { - Price = plan.HasNonSeatBasedPasswordManagerPlan() - ? plan.PasswordManager.StripePlanId - : plan.PasswordManager.StripeSeatPlanId, - Quantity = 1 - } - ] - } - }; - - if (productType == ProductType.SecretsManager) - { - options.SubscriptionDetails.Items.Add(new InvoiceSubscriptionDetailsItemOptions - { - Price = plan.SecretsManager.StripeSeatPlanId, - Quantity = 1 - }); - - options.Coupon = StripeConstants.CouponIDs.SecretsManagerStandalone; - } - - if (!string.IsNullOrEmpty(taxInformation.TaxId)) - { - var taxIdType = taxService.GetStripeTaxCode( - taxInformation.Country, - taxInformation.TaxId); - - if (string.IsNullOrEmpty(taxIdType)) - { - return new BadRequest( - "We couldn't find a corresponding tax ID type for the tax ID you provided. Please try again or contact support for assistance."); - } - - options.CustomerDetails.TaxIds = - [ - new InvoiceCustomerDetailsTaxIdOptions { Type = taxIdType, Value = taxInformation.TaxId } - ]; - - if (taxIdType == StripeConstants.TaxIdType.SpanishNIF) - { - options.CustomerDetails.TaxIds.Add(new InvoiceCustomerDetailsTaxIdOptions - { - Type = StripeConstants.TaxIdType.EUVAT, - Value = $"ES{parameters.TaxInformation.TaxId}" - }); - } - } - - options.AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true }; - if (parameters.PlanType.IsBusinessProductTierType() && - parameters.TaxInformation.Country != Core.Constants.CountryAbbreviations.UnitedStates) - { - options.CustomerDetails.TaxExempt = StripeConstants.TaxExempt.Reverse; - } - - var invoice = await stripeAdapter.InvoiceCreatePreviewAsync(options); - return Convert.ToDecimal(invoice.Tax) / 100; - }); -} - -#region Command Parameters - -public record OrganizationTrialParameters -{ - public required PlanType PlanType { get; set; } - public required ProductType ProductType { get; set; } - public required TaxInformationDTO TaxInformation { get; set; } - - public void Deconstruct( - out PlanType planType, - out ProductType productType, - out TaxInformationDTO taxInformation) - { - planType = PlanType; - productType = ProductType; - taxInformation = TaxInformation; - } - - public record TaxInformationDTO - { - public required string Country { get; set; } - public required string PostalCode { get; set; } - public string? TaxId { get; set; } - } -} - -#endregion diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index abe489de0f..d41548b5d8 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -70,6 +70,17 @@ public static class Constants /// public const string UnitedStates = "US"; } + + + /// + /// Constants for our browser extensions IDs + /// + public static class BrowserExtensions + { + public const string ChromeId = "chrome-extension://nngceckbapebfimnlniiiahkandclblb/"; + public const string EdgeId = "chrome-extension://jbkfoedolllekgbhcbcoahefnbanhhlh/"; + public const string OperaId = "chrome-extension://ccnckbpmaceehanjmeomladnmlffdjgn/"; + } } public static class AuthConstants @@ -124,16 +135,18 @@ public static class AuthenticationSchemes public static class FeatureFlagKeys { /* Admin Console Team */ - public const string VerifiedSsoDomainEndpoint = "pm-12337-refactor-sso-details-endpoint"; - public const string LimitItemDeletion = "pm-15493-restrict-item-deletion-to-can-manage-permission"; public const string PolicyRequirements = "pm-14439-policy-requirements"; public const string ScimInviteUserOptimization = "pm-16811-optimize-invite-user-flow-to-fail-fast"; - public const string EventBasedOrganizationIntegrations = "event-based-organization-integrations"; - public const string SeparateCustomRolePermissions = "pm-19917-separate-custom-role-permissions"; public const string CreateDefaultLocation = "pm-19467-create-default-location"; + public const string AutomaticConfirmUsers = "pm-19934-auto-confirm-organization-users"; public const string PM23845_VNextApplicationCache = "pm-24957-refactor-memory-application-cache"; - public const string CipherRepositoryBulkResourceCreation = "pm-24951-cipher-repository-bulk-resource-creation-service"; - public const string CollectionVaultRefactor = "pm-25030-resolve-ts-upgrade-errors"; + public const string AccountRecoveryCommand = "pm-25581-prevent-provider-account-recovery"; + public const string PolicyValidatorsRefactor = "pm-26423-refactor-policy-side-effects"; + + /* Architecture */ + public const string DesktopMigrationMilestone1 = "desktop-ui-migration-milestone-1"; + public const string DesktopMigrationMilestone2 = "desktop-ui-migration-milestone-2"; + public const string DesktopMigrationMilestone3 = "desktop-ui-migration-milestone-3"; /* Auth Team */ public const string TwoFactorExtensionDataPersistence = "pm-9115-two-factor-extension-data-persistence"; @@ -143,9 +156,12 @@ public static class FeatureFlagKeys public const string ChangeExistingPasswordRefactor = "pm-16117-change-existing-password-refactor"; public const string Otp6Digits = "pm-18612-otp-6-digits"; public const string FailedTwoFactorEmail = "pm-24425-send-2fa-failed-email"; + public const string PM24579_PreventSsoOnExistingNonCompliantUsers = "pm-24579-prevent-sso-on-existing-non-compliant-users"; public const string DisableAlternateLoginMethods = "pm-22110-disable-alternate-login-methods"; public const string PM23174ManageAccountRecoveryPermissionDrivesTheNeedToSetMasterPassword = "pm-23174-manage-account-recovery-permission-drives-the-need-to-set-master-password"; + public const string RecoveryCodeSupportForSsoRequiredUsers = "pm-21153-recovery-code-support-for-sso-required"; + public const string MJMLBasedEmailTemplates = "mjml-based-email-templates"; /* Autofill Team */ public const string IdpAutoSubmitLogin = "idp-auto-submit-login"; @@ -153,6 +169,7 @@ public static class FeatureFlagKeys public const string InlineMenuFieldQualification = "inline-menu-field-qualification"; public const string InlineMenuPositioningImprovements = "inline-menu-positioning-improvements"; public const string SSHAgent = "ssh-agent"; + public const string SSHAgentV2 = "ssh-agent-v2"; public const string SSHVersionCheckQAOverride = "ssh-version-check-qa-override"; public const string GenerateIdentityFillScriptRefactor = "generate-identity-fill-script-refactor"; public const string DelayFido2PageScriptInitWithinMv2 = "delay-fido2-page-script-init-within-mv2"; @@ -165,15 +182,19 @@ public static class FeatureFlagKeys public const string WindowsDesktopAutotype = "windows-desktop-autotype"; /* Billing Team */ - public const string AC2101UpdateTrialInitiationEmail = "AC-2101-update-trial-initiation-email"; public const string TrialPayment = "PM-8163-trial-payment"; - public const string PM17772_AdminInitiatedSponsorships = "pm-17772-admin-initiated-sponsorships"; public const string UsePricingService = "use-pricing-service"; public const string PM19422_AllowAutomaticTaxUpdates = "pm-19422-allow-automatic-tax-updates"; - public const string PM21881_ManagePaymentDetailsOutsideCheckout = "pm-21881-manage-payment-details-outside-checkout"; public const string PM21821_ProviderPortalTakeover = "pm-21821-provider-portal-takeover"; public const string PM22415_TaxIDWarnings = "pm-22415-tax-id-warnings"; - public const string PM23385_UseNewPremiumFlow = "pm-23385-use-new-premium-flow"; + public const string PM25379_UseNewOrganizationMetadataStructure = "pm-25379-use-new-organization-metadata-structure"; + public const string PM24996ImplementUpgradeFromFreeDialog = "pm-24996-implement-upgrade-from-free-dialog"; + public const string PM24032_NewNavigationPremiumUpgradeButton = "pm-24032-new-navigation-premium-upgrade-button"; + public const string PM23713_PremiumBadgeOpensNewPremiumUpgradeDialog = "pm-23713-premium-badge-opens-new-premium-upgrade-dialog"; + public const string PremiumUpgradeNewDesign = "pm-24033-updat-premium-subscription-page"; + public const string PM26793_FetchPremiumPriceFromPricingService = "pm-26793-fetch-premium-price-from-pricing-service"; + public const string PM23341_Milestone_2 = "pm-23341-milestone-2"; + public const string PM26462_Milestone_3 = "pm-26462-milestone-3"; /* Key Management Team */ public const string ReturnErrorOnExistingKeypair = "return-error-on-existing-keypair"; @@ -183,28 +204,26 @@ public static class FeatureFlagKeys public const string UserkeyRotationV2 = "userkey-rotation-v2"; public const string SSHKeyItemVaultItem = "ssh-key-vault-item"; public const string UserSdkForDecryption = "use-sdk-for-decryption"; + public const string EnrollAeadOnKeyRotation = "enroll-aead-on-key-rotation"; public const string PM17987_BlockType0 = "pm-17987-block-type-0"; public const string ForceUpdateKDFSettings = "pm-18021-force-update-kdf-settings"; public const string UnlockWithMasterPasswordUnlockData = "pm-23246-unlock-with-master-password-unlock-data"; public const string WindowsBiometricsV2 = "pm-25373-windows-biometrics-v2"; + public const string LinuxBiometricsV2 = "pm-26340-linux-biometrics-v2"; + public const string NoLogoutOnKdfChange = "pm-23995-no-logout-on-kdf-change"; + public const string DisableType0Decryption = "pm-25174-disable-type-0-decryption"; + public const string ConsolidatedSessionTimeoutComponent = "pm-26056-consolidated-session-timeout-component"; /* Mobile Team */ - public const string NativeCarouselFlow = "native-carousel-flow"; - public const string NativeCreateAccountFlow = "native-create-account-flow"; public const string AndroidImportLoginsFlow = "import-logins-flow"; - public const string AppReviewPrompt = "app-review-prompt"; public const string AndroidMutualTls = "mutual-tls"; public const string SingleTapPasskeyCreation = "single-tap-passkey-creation"; public const string SingleTapPasskeyAuthentication = "single-tap-passkey-authentication"; - public const string EnablePMAuthenticatorSync = "enable-pm-bwa-sync"; public const string PM3503_MobileAnonAddySelfHostAlias = "anon-addy-self-host-alias"; public const string PM3553_MobileSimpleLoginSelfHostAlias = "simple-login-self-host-alias"; - public const string EnablePMFlightRecorder = "enable-pm-flight-recorder"; public const string MobileErrorReporting = "mobile-error-reporting"; public const string AndroidChromeAutofill = "android-chrome-autofill"; public const string UserManagedPrivilegedApps = "pm-18970-user-managed-privileged-apps"; - public const string EnablePMPreloginSettings = "enable-pm-prelogin-settings"; - public const string AppIntents = "app-intents"; public const string SendAccess = "pm-19394-send-access-control"; public const string CxpImportMobile = "cxp-import-mobile"; public const string CxpExportMobile = "cxp-export-mobile"; @@ -222,6 +241,7 @@ public static class FeatureFlagKeys public const string DesktopSendUIRefresh = "desktop-send-ui-refresh"; public const string UseSdkPasswordGenerators = "pm-19976-use-sdk-password-generators"; public const string UseChromiumImporter = "pm-23982-chromium-importer"; + public const string ChromiumImporterWithABE = "pm-25855-chromium-importer-abe"; /// /// Enable this flag to output email/OTP authenticated sends from the `GET sends` endpoint. When @@ -236,9 +256,7 @@ public static class FeatureFlagKeys /* Vault Team */ public const string PM8851_BrowserOnboardingNudge = "pm-8851-browser-onboarding-nudge"; public const string PM9111ExtensionPersistAddEditForm = "pm-9111-extension-persist-add-edit-form"; - public const string SecurityTasks = "security-tasks"; public const string CipherKeyEncryption = "cipher-key-encryption"; - public const string DesktopCipherForms = "pm-18520-desktop-cipher-forms"; public const string PM19941MigrateCipherDomainToSdk = "pm-19941-migrate-cipher-domain-to-sdk"; public const string EndUserNotifications = "pm-10609-end-user-notifications"; public const string PhishingDetection = "phishing-detection"; @@ -246,12 +264,17 @@ public static class FeatureFlagKeys public const string PM22134SdkCipherListView = "pm-22134-sdk-cipher-list-view"; public const string PM19315EndUserActivationMvp = "pm-19315-end-user-activation-mvp"; public const string PM22136_SdkCipherEncryption = "pm-22136-sdk-cipher-encryption"; + public const string PM23904_RiskInsightsForPremium = "pm-23904-risk-insights-for-premium"; + public const string PM25083_AutofillConfirmFromSearch = "pm-25083-autofill-confirm-from-search"; + public const string VaultLoadingSkeletons = "pm-25081-vault-skeleton-loaders"; /* Innovation Team */ public const string ArchiveVaultItems = "pm-19148-innovation-archive"; /* DIRT Team */ public const string PM22887_RiskInsightsActivityTab = "pm-22887-risk-insights-activity-tab"; + public const string EventManagementForDataDogAndCrowdStrike = "event-management-for-datadog-and-crowdstrike"; + public const string EventDiagnosticLogging = "pm-27666-siem-event-log-debugging"; public static List GetAllKeys() { diff --git a/src/Core/Context/ICurrentContext.cs b/src/Core/Context/ICurrentContext.cs index 417e220ba2..f62a048070 100644 --- a/src/Core/Context/ICurrentContext.cs +++ b/src/Core/Context/ICurrentContext.cs @@ -1,6 +1,4 @@ -#nullable enable - -using System.Security.Claims; +using System.Security.Claims; using Bit.Core.AdminConsole.Context; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.Identity; @@ -12,6 +10,14 @@ using Microsoft.AspNetCore.Http; namespace Bit.Core.Context; +/// +/// Provides information about the current HTTP request and the currently authenticated user (if any). +/// This is often (but not exclusively) parsed from the JWT in the current request. +/// +/// +/// This interface suffers from having too much responsibility; consider whether any new code can go in a more +/// specific class rather than adding it here. +/// public interface ICurrentContext { HttpContext HttpContext { get; set; } @@ -59,8 +65,20 @@ public interface ICurrentContext Task EditSubscription(Guid orgId); Task EditPaymentMethods(Guid orgId); Task ViewBillingHistory(Guid orgId); + /// + /// Returns true if the current user is a member of a provider that manages the specified organization. + /// This generally gives the user administrative privileges for the organization. + /// + /// + /// Task ProviderUserForOrgAsync(Guid orgId); + /// + /// Returns true if the current user is a Provider Admin of the specified provider. + /// bool ProviderProviderAdmin(Guid providerId); + /// + /// Returns true if the current user is a member of the specified provider (with any role). + /// bool ProviderUser(Guid providerId); bool ProviderManageUsers(Guid providerId); bool ProviderAccessEventLogs(Guid providerId); diff --git a/src/Core/Core.csproj b/src/Core/Core.csproj index e76af0f8ef..4901c5b43c 100644 --- a/src/Core/Core.csproj +++ b/src/Core/Core.csproj @@ -16,13 +16,15 @@ - + + + - - + + @@ -34,10 +36,12 @@ - + + + @@ -55,7 +59,7 @@ - + @@ -70,7 +74,7 @@ - + diff --git a/src/Core/Dirt/Entities/OrganizationReport.cs b/src/Core/Dirt/Entities/OrganizationReport.cs index a776648b35..9d04180c8d 100644 --- a/src/Core/Dirt/Entities/OrganizationReport.cs +++ b/src/Core/Dirt/Entities/OrganizationReport.cs @@ -11,12 +11,24 @@ public class OrganizationReport : ITableObject public Guid OrganizationId { get; set; } public string ReportData { get; set; } = string.Empty; public DateTime CreationDate { get; set; } = DateTime.UtcNow; - public string ContentEncryptionKey { get; set; } = string.Empty; - - public string? SummaryData { get; set; } = null; - public string? ApplicationData { get; set; } = null; + public string? SummaryData { get; set; } + public string? ApplicationData { get; set; } public DateTime RevisionDate { get; set; } = DateTime.UtcNow; + public int? ApplicationCount { get; set; } + public int? ApplicationAtRiskCount { get; set; } + public int? CriticalApplicationCount { get; set; } + public int? CriticalApplicationAtRiskCount { get; set; } + public int? MemberCount { get; set; } + public int? MemberAtRiskCount { get; set; } + public int? CriticalMemberCount { get; set; } + public int? CriticalMemberAtRiskCount { get; set; } + public int? PasswordCount { get; set; } + public int? PasswordAtRiskCount { get; set; } + public int? CriticalPasswordCount { get; set; } + public int? CriticalPasswordAtRiskCount { get; set; } + + public void SetNewId() { diff --git a/src/Core/Dirt/Models/Data/OrganizationReportMetricsData.cs b/src/Core/Dirt/Models/Data/OrganizationReportMetricsData.cs new file mode 100644 index 0000000000..ffef91275a --- /dev/null +++ b/src/Core/Dirt/Models/Data/OrganizationReportMetricsData.cs @@ -0,0 +1,48 @@ +using Bit.Core.Dirt.Reports.ReportFeatures.Requests; + +namespace Bit.Core.Dirt.Reports.Models.Data; + +public class OrganizationReportMetricsData +{ + public Guid OrganizationId { get; set; } + public int? ApplicationCount { get; set; } + public int? ApplicationAtRiskCount { get; set; } + public int? CriticalApplicationCount { get; set; } + public int? CriticalApplicationAtRiskCount { get; set; } + public int? MemberCount { get; set; } + public int? MemberAtRiskCount { get; set; } + public int? CriticalMemberCount { get; set; } + public int? CriticalMemberAtRiskCount { get; set; } + public int? PasswordCount { get; set; } + public int? PasswordAtRiskCount { get; set; } + public int? CriticalPasswordCount { get; set; } + public int? CriticalPasswordAtRiskCount { get; set; } + + public static OrganizationReportMetricsData From(Guid organizationId, OrganizationReportMetricsRequest? request) + { + if (request == null) + { + return new OrganizationReportMetricsData + { + OrganizationId = organizationId + }; + } + + return new OrganizationReportMetricsData + { + OrganizationId = organizationId, + ApplicationCount = request.ApplicationCount, + ApplicationAtRiskCount = request.ApplicationAtRiskCount, + CriticalApplicationCount = request.CriticalApplicationCount, + CriticalApplicationAtRiskCount = request.CriticalApplicationAtRiskCount, + MemberCount = request.MemberCount, + MemberAtRiskCount = request.MemberAtRiskCount, + CriticalMemberCount = request.CriticalMemberCount, + CriticalMemberAtRiskCount = request.CriticalMemberAtRiskCount, + PasswordCount = request.PasswordCount, + PasswordAtRiskCount = request.PasswordAtRiskCount, + CriticalPasswordCount = request.CriticalPasswordCount, + CriticalPasswordAtRiskCount = request.CriticalPasswordAtRiskCount + }; + } +} diff --git a/src/Core/Dirt/Reports/ReportFeatures/AddOrganizationReportCommand.cs b/src/Core/Dirt/Reports/ReportFeatures/AddOrganizationReportCommand.cs index f0477806d8..236560487e 100644 --- a/src/Core/Dirt/Reports/ReportFeatures/AddOrganizationReportCommand.cs +++ b/src/Core/Dirt/Reports/ReportFeatures/AddOrganizationReportCommand.cs @@ -35,14 +35,28 @@ public class AddOrganizationReportCommand : IAddOrganizationReportCommand throw new BadRequestException(errorMessage); } + var requestMetrics = request.Metrics ?? new OrganizationReportMetricsRequest(); + var organizationReport = new OrganizationReport { OrganizationId = request.OrganizationId, - ReportData = request.ReportData, + ReportData = request.ReportData ?? string.Empty, CreationDate = DateTime.UtcNow, - ContentEncryptionKey = request.ContentEncryptionKey, + ContentEncryptionKey = request.ContentEncryptionKey ?? string.Empty, SummaryData = request.SummaryData, ApplicationData = request.ApplicationData, + ApplicationCount = requestMetrics.ApplicationCount, + ApplicationAtRiskCount = requestMetrics.ApplicationAtRiskCount, + CriticalApplicationCount = requestMetrics.CriticalApplicationCount, + CriticalApplicationAtRiskCount = requestMetrics.CriticalApplicationAtRiskCount, + MemberCount = requestMetrics.MemberCount, + MemberAtRiskCount = requestMetrics.MemberAtRiskCount, + CriticalMemberCount = requestMetrics.CriticalMemberCount, + CriticalMemberAtRiskCount = requestMetrics.CriticalMemberAtRiskCount, + PasswordCount = requestMetrics.PasswordCount, + PasswordAtRiskCount = requestMetrics.PasswordAtRiskCount, + CriticalPasswordCount = requestMetrics.CriticalPasswordCount, + CriticalPasswordAtRiskCount = requestMetrics.CriticalPasswordAtRiskCount, RevisionDate = DateTime.UtcNow }; diff --git a/src/Core/Dirt/Reports/ReportFeatures/Requests/AddOrganizationReportRequest.cs b/src/Core/Dirt/Reports/ReportFeatures/Requests/AddOrganizationReportRequest.cs index 2a8c0203f9..eecc84c522 100644 --- a/src/Core/Dirt/Reports/ReportFeatures/Requests/AddOrganizationReportRequest.cs +++ b/src/Core/Dirt/Reports/ReportFeatures/Requests/AddOrganizationReportRequest.cs @@ -1,16 +1,15 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -namespace Bit.Core.Dirt.Reports.ReportFeatures.Requests; +namespace Bit.Core.Dirt.Reports.ReportFeatures.Requests; public class AddOrganizationReportRequest { public Guid OrganizationId { get; set; } - public string ReportData { get; set; } + public string? ReportData { get; set; } - public string ContentEncryptionKey { get; set; } + public string? ContentEncryptionKey { get; set; } - public string SummaryData { get; set; } + public string? SummaryData { get; set; } - public string ApplicationData { get; set; } + public string? ApplicationData { get; set; } + + public OrganizationReportMetricsRequest? Metrics { get; set; } } diff --git a/src/Core/Dirt/Reports/ReportFeatures/Requests/OrganizationReportMetricsRequest.cs b/src/Core/Dirt/Reports/ReportFeatures/Requests/OrganizationReportMetricsRequest.cs new file mode 100644 index 0000000000..9403a5f1c2 --- /dev/null +++ b/src/Core/Dirt/Reports/ReportFeatures/Requests/OrganizationReportMetricsRequest.cs @@ -0,0 +1,31 @@ +using System.Text.Json.Serialization; + +namespace Bit.Core.Dirt.Reports.ReportFeatures.Requests; + +public class OrganizationReportMetricsRequest +{ + [JsonPropertyName("totalApplicationCount")] + public int? ApplicationCount { get; set; } = null; + [JsonPropertyName("totalAtRiskApplicationCount")] + public int? ApplicationAtRiskCount { get; set; } = null; + [JsonPropertyName("totalCriticalApplicationCount")] + public int? CriticalApplicationCount { get; set; } = null; + [JsonPropertyName("totalCriticalAtRiskApplicationCount")] + public int? CriticalApplicationAtRiskCount { get; set; } = null; + [JsonPropertyName("totalMemberCount")] + public int? MemberCount { get; set; } = null; + [JsonPropertyName("totalAtRiskMemberCount")] + public int? MemberAtRiskCount { get; set; } = null; + [JsonPropertyName("totalCriticalMemberCount")] + public int? CriticalMemberCount { get; set; } = null; + [JsonPropertyName("totalCriticalAtRiskMemberCount")] + public int? CriticalMemberAtRiskCount { get; set; } = null; + [JsonPropertyName("totalPasswordCount")] + public int? PasswordCount { get; set; } = null; + [JsonPropertyName("totalAtRiskPasswordCount")] + public int? PasswordAtRiskCount { get; set; } = null; + [JsonPropertyName("totalCriticalPasswordCount")] + public int? CriticalPasswordCount { get; set; } = null; + [JsonPropertyName("totalCriticalAtRiskPasswordCount")] + public int? CriticalPasswordAtRiskCount { get; set; } = null; +} diff --git a/src/Core/Dirt/Reports/ReportFeatures/Requests/UpdateOrganizationReportApplicationDataRequest.cs b/src/Core/Dirt/Reports/ReportFeatures/Requests/UpdateOrganizationReportApplicationDataRequest.cs index ab4fcc5921..e549a3f120 100644 --- a/src/Core/Dirt/Reports/ReportFeatures/Requests/UpdateOrganizationReportApplicationDataRequest.cs +++ b/src/Core/Dirt/Reports/ReportFeatures/Requests/UpdateOrganizationReportApplicationDataRequest.cs @@ -1,11 +1,8 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -namespace Bit.Core.Dirt.Reports.ReportFeatures.Requests; +namespace Bit.Core.Dirt.Reports.ReportFeatures.Requests; public class UpdateOrganizationReportApplicationDataRequest { public Guid Id { get; set; } public Guid OrganizationId { get; set; } - public string ApplicationData { get; set; } + public string? ApplicationData { get; set; } } diff --git a/src/Core/Dirt/Reports/ReportFeatures/Requests/UpdateOrganizationReportSummaryRequest.cs b/src/Core/Dirt/Reports/ReportFeatures/Requests/UpdateOrganizationReportSummaryRequest.cs index b0e555fcef..27358537c2 100644 --- a/src/Core/Dirt/Reports/ReportFeatures/Requests/UpdateOrganizationReportSummaryRequest.cs +++ b/src/Core/Dirt/Reports/ReportFeatures/Requests/UpdateOrganizationReportSummaryRequest.cs @@ -1,11 +1,9 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -namespace Bit.Core.Dirt.Reports.ReportFeatures.Requests; +namespace Bit.Core.Dirt.Reports.ReportFeatures.Requests; public class UpdateOrganizationReportSummaryRequest { public Guid OrganizationId { get; set; } public Guid ReportId { get; set; } - public string SummaryData { get; set; } + public string? SummaryData { get; set; } + public OrganizationReportMetricsRequest? Metrics { get; set; } } diff --git a/src/Core/Dirt/Reports/ReportFeatures/UpdateOrganizationReportApplicationDataCommand.cs b/src/Core/Dirt/Reports/ReportFeatures/UpdateOrganizationReportApplicationDataCommand.cs index 67ec49d004..375b766a0e 100644 --- a/src/Core/Dirt/Reports/ReportFeatures/UpdateOrganizationReportApplicationDataCommand.cs +++ b/src/Core/Dirt/Reports/ReportFeatures/UpdateOrganizationReportApplicationDataCommand.cs @@ -53,7 +53,7 @@ public class UpdateOrganizationReportApplicationDataCommand : IUpdateOrganizatio throw new BadRequestException("Organization report does not belong to the specified organization"); } - var updatedReport = await _organizationReportRepo.UpdateApplicationDataAsync(request.OrganizationId, request.Id, request.ApplicationData); + var updatedReport = await _organizationReportRepo.UpdateApplicationDataAsync(request.OrganizationId, request.Id, request.ApplicationData ?? string.Empty); _logger.LogInformation(Constants.BypassFiltersEventId, "Successfully updated organization report application data {reportId} for organization {organizationId}", request.Id, request.OrganizationId); diff --git a/src/Core/Dirt/Reports/ReportFeatures/UpdateOrganizationReportSummaryCommand.cs b/src/Core/Dirt/Reports/ReportFeatures/UpdateOrganizationReportSummaryCommand.cs index 6859814d65..5d0f2670ca 100644 --- a/src/Core/Dirt/Reports/ReportFeatures/UpdateOrganizationReportSummaryCommand.cs +++ b/src/Core/Dirt/Reports/ReportFeatures/UpdateOrganizationReportSummaryCommand.cs @@ -1,4 +1,5 @@ using Bit.Core.Dirt.Entities; +using Bit.Core.Dirt.Reports.Models.Data; using Bit.Core.Dirt.Reports.ReportFeatures.Interfaces; using Bit.Core.Dirt.Reports.ReportFeatures.Requests; using Bit.Core.Dirt.Repositories; @@ -53,7 +54,8 @@ public class UpdateOrganizationReportSummaryCommand : IUpdateOrganizationReportS throw new BadRequestException("Organization report does not belong to the specified organization"); } - var updatedReport = await _organizationReportRepo.UpdateSummaryDataAsync(request.OrganizationId, request.ReportId, request.SummaryData); + await _organizationReportRepo.UpdateMetricsAsync(request.ReportId, OrganizationReportMetricsData.From(request.OrganizationId, request.Metrics)); + var updatedReport = await _organizationReportRepo.UpdateSummaryDataAsync(request.OrganizationId, request.ReportId, request.SummaryData ?? string.Empty); _logger.LogInformation(Constants.BypassFiltersEventId, "Successfully updated organization report summary {reportId} for organization {organizationId}", request.ReportId, request.OrganizationId); diff --git a/src/Core/Dirt/Repositories/IOrganizationReportRepository.cs b/src/Core/Dirt/Repositories/IOrganizationReportRepository.cs index 9687173716..b4c2f90566 100644 --- a/src/Core/Dirt/Repositories/IOrganizationReportRepository.cs +++ b/src/Core/Dirt/Repositories/IOrganizationReportRepository.cs @@ -1,5 +1,6 @@ using Bit.Core.Dirt.Entities; using Bit.Core.Dirt.Models.Data; +using Bit.Core.Dirt.Reports.Models.Data; using Bit.Core.Repositories; namespace Bit.Core.Dirt.Repositories; @@ -21,5 +22,8 @@ public interface IOrganizationReportRepository : IRepository GetApplicationDataAsync(Guid reportId); Task UpdateApplicationDataAsync(Guid orgId, Guid reportId, string applicationData); + + // Metrics methods + Task UpdateMetricsAsync(Guid reportId, OrganizationReportMetricsData metrics); } diff --git a/src/Core/Entities/User.cs b/src/Core/Entities/User.cs index 12c527ed78..fec9b80d8e 100644 --- a/src/Core/Entities/User.cs +++ b/src/Core/Entities/User.cs @@ -3,6 +3,7 @@ using System.Text.Json; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models; using Bit.Core.Enums; +using Bit.Core.KeyManagement.Models.Data; using Bit.Core.Utilities; using Microsoft.AspNetCore.Identity; @@ -21,6 +22,9 @@ public class User : ITableObject, IStorableSubscriber, IRevisable, ITwoFac [MaxLength(256)] public string Email { get; set; } = null!; public bool EmailVerified { get; set; } + /// + /// The server-side master-password hash + /// [MaxLength(300)] public string? MasterPassword { get; set; } [MaxLength(50)] @@ -41,9 +45,30 @@ public class User : ITableObject, IStorableSubscriber, IRevisable, ITwoFac /// organization membership. /// public DateTime AccountRevisionDate { get; set; } = DateTime.UtcNow; + /// + /// The master-password-sealed user key. + /// public string? Key { get; set; } + /// + /// The raw public key, without a signature from the user's signature key. + /// public string? PublicKey { get; set; } + /// + /// User key wrapped private key. + /// public string? PrivateKey { get; set; } + /// + /// The public key, signed by the user's signature key. + /// + public string? SignedPublicKey { get; set; } + /// + /// The security version is included in the security state, but needs COSE parsing + /// + public int? SecurityVersion { get; set; } + /// + /// The security state is a signed object attesting to the version of the user's account. + /// + public string? SecurityState { get; set; } public bool Premium { get; set; } public DateTime? PremiumExpirationDate { get; set; } public DateTime? RenewalReminderDate { get; set; } @@ -180,6 +205,12 @@ public class User : ITableObject, IStorableSubscriber, IRevisable, ITwoFac return Premium; } + public int GetSecurityVersion() + { + // If no security version is set, it is version 1. The minimum initialized version is 2. + return SecurityVersion ?? 1; + } + /// /// Serializes the C# object to the User.TwoFactorProviders property in JSON format. /// @@ -243,4 +274,14 @@ public class User : ITableObject, IStorableSubscriber, IRevisable, ITwoFac { return MasterPassword != null; } + + public PublicKeyEncryptionKeyPairData GetPublicKeyEncryptionKeyPair() + { + if (string.IsNullOrWhiteSpace(PrivateKey) || string.IsNullOrWhiteSpace(PublicKey)) + { + throw new InvalidOperationException("User public key encryption key pair is not fully initialized."); + } + + return new PublicKeyEncryptionKeyPairData(PrivateKey, PublicKey, SignedPublicKey); + } } diff --git a/src/Core/Enums/PushNotificationLogOutReason.cs b/src/Core/Enums/PushNotificationLogOutReason.cs new file mode 100644 index 0000000000..a24f790305 --- /dev/null +++ b/src/Core/Enums/PushNotificationLogOutReason.cs @@ -0,0 +1,6 @@ +namespace Bit.Core.Enums; + +public enum PushNotificationLogOutReason : byte +{ + KdfChange = 0 +} diff --git a/src/Core/Jobs/BaseJobsHostedService.cs b/src/Core/Jobs/BaseJobsHostedService.cs index 3e7bce7e0f..8b74052f8f 100644 --- a/src/Core/Jobs/BaseJobsHostedService.cs +++ b/src/Core/Jobs/BaseJobsHostedService.cs @@ -107,7 +107,7 @@ public abstract class BaseJobsHostedService : IHostedService, IDisposable throw new Exception("Job failed to start after 10 retries."); } - _logger.LogWarning($"Exception while trying to schedule job: {job.FullName}, {e}"); + _logger.LogWarning(e, "Exception while trying to schedule job: {JobName}", job.FullName); var random = new Random(); await Task.Delay(random.Next(50, 250)); } @@ -125,7 +125,7 @@ public abstract class BaseJobsHostedService : IHostedService, IDisposable continue; } - _logger.LogInformation($"Deleting old job with key {key}"); + _logger.LogInformation("Deleting old job with key {Key}", key); await _scheduler.DeleteJob(key); } @@ -138,7 +138,7 @@ public abstract class BaseJobsHostedService : IHostedService, IDisposable continue; } - _logger.LogInformation($"Unscheduling old trigger with key {key}"); + _logger.LogInformation("Unscheduling old trigger with key {Key}", key); await _scheduler.UnscheduleJob(key); } } diff --git a/src/Core/KeyManagement/Entities/UserSignatureKeyPair.cs b/src/Core/KeyManagement/Entities/UserSignatureKeyPair.cs new file mode 100644 index 0000000000..dada9e0d7a --- /dev/null +++ b/src/Core/KeyManagement/Entities/UserSignatureKeyPair.cs @@ -0,0 +1,30 @@ +using Bit.Core.Entities; +using Bit.Core.KeyManagement.Enums; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.Utilities; + + +namespace Bit.Core.KeyManagement.Entities; + +public class UserSignatureKeyPair : ITableObject, IRevisable +{ + public Guid Id { get; set; } + public Guid UserId { get; set; } + public SignatureAlgorithm SignatureAlgorithm { get; set; } + + public required string VerifyingKey { get; set; } + public required string SigningKey { get; set; } + + public DateTime CreationDate { get; set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; set; } = DateTime.UtcNow; + + public void SetNewId() + { + Id = CoreHelpers.GenerateComb(); + } + + public SignatureKeyPairData ToSignatureKeyPairData() + { + return new SignatureKeyPairData(SignatureAlgorithm, SigningKey, VerifyingKey); + } +} diff --git a/src/Core/KeyManagement/Enums/SignatureAlgorithm.cs b/src/Core/KeyManagement/Enums/SignatureAlgorithm.cs new file mode 100644 index 0000000000..9216c3f489 --- /dev/null +++ b/src/Core/KeyManagement/Enums/SignatureAlgorithm.cs @@ -0,0 +1,9 @@ +namespace Bit.Core.KeyManagement.Enums; + +// +// Represents the algorithm / digital signature scheme used for a signature key pair. +// +public enum SignatureAlgorithm : byte +{ + Ed25519 = 0 +} diff --git a/src/Core/KeyManagement/Kdf/Implementations/ChangeKdfCommand.cs b/src/Core/KeyManagement/Kdf/Implementations/ChangeKdfCommand.cs index fe736f9ac6..83e47c4931 100644 --- a/src/Core/KeyManagement/Kdf/Implementations/ChangeKdfCommand.cs +++ b/src/Core/KeyManagement/Kdf/Implementations/ChangeKdfCommand.cs @@ -1,4 +1,5 @@ using Bit.Core.Entities; +using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.KeyManagement.Models.Data; using Bit.Core.Platform.Push; @@ -18,17 +19,22 @@ public class ChangeKdfCommand : IChangeKdfCommand private readonly IUserRepository _userRepository; private readonly IdentityErrorDescriber _identityErrorDescriber; private readonly ILogger _logger; + private readonly IFeatureService _featureService; - public ChangeKdfCommand(IUserService userService, IPushNotificationService pushService, IUserRepository userRepository, IdentityErrorDescriber describer, ILogger logger) + public ChangeKdfCommand(IUserService userService, IPushNotificationService pushService, + IUserRepository userRepository, IdentityErrorDescriber describer, ILogger logger, + IFeatureService featureService) { _userService = userService; _pushService = pushService; _userRepository = userRepository; _identityErrorDescriber = describer; _logger = logger; + _featureService = featureService; } - public async Task ChangeKdfAsync(User user, string masterPasswordAuthenticationHash, MasterPasswordAuthenticationData authenticationData, MasterPasswordUnlockData unlockData) + public async Task ChangeKdfAsync(User user, string masterPasswordAuthenticationHash, + MasterPasswordAuthenticationData authenticationData, MasterPasswordUnlockData unlockData) { ArgumentNullException.ThrowIfNull(user); if (!await _userService.CheckPasswordAsync(user, masterPasswordAuthenticationHash)) @@ -37,8 +43,8 @@ public class ChangeKdfCommand : IChangeKdfCommand } // Validate to prevent user account from becoming un-decryptable from invalid parameters - // - // Prevent a de-synced salt value from creating an un-decryptable unlock method + // + // Prevent a de-synced salt value from creating an un-decryptable unlock method authenticationData.ValidateSaltUnchangedForUser(user); unlockData.ValidateSaltUnchangedForUser(user); @@ -47,12 +53,15 @@ public class ChangeKdfCommand : IChangeKdfCommand { throw new BadRequestException("KDF settings must be equal for authentication and unlock."); } + var validationErrors = KdfSettingsValidator.Validate(unlockData.Kdf); if (validationErrors.Any()) { throw new BadRequestException("KDF settings are invalid."); } + var logoutOnKdfChange = !_featureService.IsEnabled(FeatureFlagKeys.NoLogoutOnKdfChange); + // Update the user with the new KDF settings // This updates the authentication data and unlock data for the user separately. Currently these still // use shared values for KDF settings and salt. @@ -68,7 +77,8 @@ public class ChangeKdfCommand : IChangeKdfCommand // This entire operation MUST be atomic to prevent a user from being locked out of their account. // Salt is ensured to be the same as unlock data, and the value stored in the account and not updated. // KDF is ensured to be the same as unlock data above and updated below. - var result = await _userService.UpdatePasswordHash(user, authenticationData.MasterPasswordAuthenticationHash); + var result = await _userService.UpdatePasswordHash(user, authenticationData.MasterPasswordAuthenticationHash, + refreshStamp: logoutOnKdfChange); if (!result.Succeeded) { _logger.LogWarning("Change KDF failed for user {userId}.", user.Id); @@ -88,7 +98,17 @@ public class ChangeKdfCommand : IChangeKdfCommand user.LastKdfChangeDate = now; await _userRepository.ReplaceAsync(user); - await _pushService.PushLogOutAsync(user.Id); + if (logoutOnKdfChange) + { + await _pushService.PushLogOutAsync(user.Id); + } + else + { + // Clients that support the new feature flag will ignore the logout when it matches the reason and the feature flag is enabled. + await _pushService.PushLogOutAsync(user.Id, reason: PushNotificationLogOutReason.KdfChange); + await _pushService.PushSyncSettingsAsync(user.Id); + } + return IdentityResult.Success; } } diff --git a/src/Core/KeyManagement/KeyManagementServiceCollectionExtensions.cs b/src/Core/KeyManagement/KeyManagementServiceCollectionExtensions.cs index e4ebdb4860..0e551c5d0e 100644 --- a/src/Core/KeyManagement/KeyManagementServiceCollectionExtensions.cs +++ b/src/Core/KeyManagement/KeyManagementServiceCollectionExtensions.cs @@ -2,6 +2,8 @@ using Bit.Core.KeyManagement.Commands.Interfaces; using Bit.Core.KeyManagement.Kdf; using Bit.Core.KeyManagement.Kdf.Implementations; +using Bit.Core.KeyManagement.Queries; +using Bit.Core.KeyManagement.Queries.Interfaces; using Microsoft.Extensions.DependencyInjection; namespace Bit.Core.KeyManagement; @@ -11,6 +13,7 @@ public static class KeyManagementServiceCollectionExtensions public static void AddKeyManagementServices(this IServiceCollection services) { services.AddKeyManagementCommands(); + services.AddKeyManagementQueries(); services.AddSendPasswordServices(); } @@ -19,4 +22,9 @@ public static class KeyManagementServiceCollectionExtensions services.AddScoped(); services.AddScoped(); } + + private static void AddKeyManagementQueries(this IServiceCollection services) + { + services.AddScoped(); + } } diff --git a/src/Core/KeyManagement/Models/Api/Request/SecurityStateModel.cs b/src/Core/KeyManagement/Models/Api/Request/SecurityStateModel.cs new file mode 100644 index 0000000000..1acb52146e --- /dev/null +++ b/src/Core/KeyManagement/Models/Api/Request/SecurityStateModel.cs @@ -0,0 +1,32 @@ +using System.ComponentModel.DataAnnotations; +using System.Text.Json.Serialization; +using Bit.Core.KeyManagement.Models.Data; + +namespace Bit.Core.KeyManagement.Models.Api.Request; + +public class SecurityStateModel +{ + [StringLength(1000)] + [JsonPropertyName("securityState")] + public required string SecurityState { get; set; } + [JsonPropertyName("securityVersion")] + public required int SecurityVersion { get; set; } + + public SecurityStateData ToSecurityState() + { + return new SecurityStateData + { + SecurityState = SecurityState, + SecurityVersion = SecurityVersion + }; + } + + public static SecurityStateModel FromSecurityStateData(SecurityStateData data) + { + return new SecurityStateModel + { + SecurityState = data.SecurityState, + SecurityVersion = data.SecurityVersion + }; + } +} diff --git a/src/Core/KeyManagement/Models/Response/MasterPasswordUnlockResponseModel.cs b/src/Core/KeyManagement/Models/Api/Response/MasterPasswordUnlockResponseModel.cs similarity index 75% rename from src/Core/KeyManagement/Models/Response/MasterPasswordUnlockResponseModel.cs rename to src/Core/KeyManagement/Models/Api/Response/MasterPasswordUnlockResponseModel.cs index f7d5dee852..eebed83485 100644 --- a/src/Core/KeyManagement/Models/Response/MasterPasswordUnlockResponseModel.cs +++ b/src/Core/KeyManagement/Models/Api/Response/MasterPasswordUnlockResponseModel.cs @@ -2,11 +2,15 @@ using Bit.Core.Enums; using Bit.Core.Utilities; -namespace Bit.Core.KeyManagement.Models.Response; +namespace Bit.Core.KeyManagement.Models.Api.Response; public class MasterPasswordUnlockResponseModel { public required MasterPasswordUnlockKdfResponseModel Kdf { get; init; } + /// + /// The user's symmetric key encrypted with their master key. + /// Also known as "MasterKeyWrappedUserKey" + /// [EncryptedString] public required string MasterKeyEncryptedUserKey { get; init; } [StringLength(256)] public required string Salt { get; init; } } diff --git a/src/Core/KeyManagement/Models/Api/Response/PrivateKeysResponseModel.cs b/src/Core/KeyManagement/Models/Api/Response/PrivateKeysResponseModel.cs new file mode 100644 index 0000000000..bcee4c0ada --- /dev/null +++ b/src/Core/KeyManagement/Models/Api/Response/PrivateKeysResponseModel.cs @@ -0,0 +1,48 @@ +using System.Text.Json.Serialization; +using Bit.Core.KeyManagement.Models.Api.Request; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.Models.Api; + +namespace Bit.Core.KeyManagement.Models.Api.Response; + + +/// +/// This response model is used to return the asymmetric encryption keys, +/// and signature keys of an entity. This includes the private keys of the key pairs, +/// (private key, signing key), and the public keys of the key pairs (unsigned public key, +/// signed public key, verification key). +/// +public class PrivateKeysResponseModel : ResponseModel +{ + // Not all accounts have signature keys, but all accounts have public encryption keys. + [JsonPropertyName("signatureKeyPair")] + public SignatureKeyPairResponseModel? SignatureKeyPair { get; set; } + + [JsonPropertyName("publicKeyEncryptionKeyPair")] + public required PublicKeyEncryptionKeyPairResponseModel PublicKeyEncryptionKeyPair { get; set; } + + [JsonPropertyName("securityState")] + public SecurityStateModel? SecurityState { get; set; } + + [System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute] + public PrivateKeysResponseModel(UserAccountKeysData accountKeys) : base("privateKeys") + { + ArgumentNullException.ThrowIfNull(accountKeys); + PublicKeyEncryptionKeyPair = new PublicKeyEncryptionKeyPairResponseModel(accountKeys.PublicKeyEncryptionKeyPairData); + + if (accountKeys.SignatureKeyPairData != null && accountKeys.SecurityStateData != null) + { + SignatureKeyPair = new SignatureKeyPairResponseModel(accountKeys.SignatureKeyPairData); + SecurityState = SecurityStateModel.FromSecurityStateData(accountKeys.SecurityStateData!); + } + } + + [JsonConstructor] + public PrivateKeysResponseModel(SignatureKeyPairResponseModel? signatureKeyPair, PublicKeyEncryptionKeyPairResponseModel publicKeyEncryptionKeyPair, SecurityStateModel? securityState) + : base("privateKeys") + { + SignatureKeyPair = signatureKeyPair; + PublicKeyEncryptionKeyPair = publicKeyEncryptionKeyPair ?? throw new ArgumentNullException(nameof(publicKeyEncryptionKeyPair)); + SecurityState = securityState; + } +} diff --git a/src/Core/KeyManagement/Models/Api/Response/PublicKeyEncryptionKeyPairResponseModel.cs b/src/Core/KeyManagement/Models/Api/Response/PublicKeyEncryptionKeyPairResponseModel.cs new file mode 100644 index 0000000000..e5436b6131 --- /dev/null +++ b/src/Core/KeyManagement/Models/Api/Response/PublicKeyEncryptionKeyPairResponseModel.cs @@ -0,0 +1,34 @@ +using System.Text.Json.Serialization; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.Models.Api; + +namespace Bit.Core.KeyManagement.Models.Api.Response; + + +public class PublicKeyEncryptionKeyPairResponseModel : ResponseModel +{ + [JsonPropertyName("wrappedPrivateKey")] + public required string WrappedPrivateKey { get; set; } + [JsonPropertyName("publicKey")] + public required string PublicKey { get; set; } + [JsonPropertyName("signedPublicKey")] + public string? SignedPublicKey { get; set; } + + [System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute] + public PublicKeyEncryptionKeyPairResponseModel(PublicKeyEncryptionKeyPairData keyPair) + : base("publicKeyEncryptionKeyPair") + { + WrappedPrivateKey = keyPair.WrappedPrivateKey; + PublicKey = keyPair.PublicKey; + SignedPublicKey = keyPair.SignedPublicKey; + } + + [JsonConstructor] + public PublicKeyEncryptionKeyPairResponseModel(string wrappedPrivateKey, string publicKey, string? signedPublicKey) + : base("publicKeyEncryptionKeyPair") + { + WrappedPrivateKey = wrappedPrivateKey ?? throw new ArgumentNullException(nameof(wrappedPrivateKey)); + PublicKey = publicKey ?? throw new ArgumentNullException(nameof(publicKey)); + SignedPublicKey = signedPublicKey; + } +} diff --git a/src/Core/KeyManagement/Models/Api/Response/PublicKeysResponseModel.cs b/src/Core/KeyManagement/Models/Api/Response/PublicKeysResponseModel.cs new file mode 100644 index 0000000000..b341a87e3e --- /dev/null +++ b/src/Core/KeyManagement/Models/Api/Response/PublicKeysResponseModel.cs @@ -0,0 +1,30 @@ +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.Models.Api; + +namespace Bit.Core.KeyManagement.Models.Api.Response; + + +/// +/// This response model is used to return the public keys of a user, to any other registered user or entity on the server. +/// It can contain public keys (signature/encryption), and proofs between the two. It does not contain (encrypted) private keys. +/// +public class PublicKeysResponseModel : ResponseModel +{ + [System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute] + public PublicKeysResponseModel(UserAccountKeysData accountKeys) + : base("publicKeys") + { + ArgumentNullException.ThrowIfNull(accountKeys); + PublicKey = accountKeys.PublicKeyEncryptionKeyPairData.PublicKey; + + if (accountKeys.SignatureKeyPairData != null) + { + SignedPublicKey = accountKeys.PublicKeyEncryptionKeyPairData.SignedPublicKey; + VerifyingKey = accountKeys.SignatureKeyPairData.VerifyingKey; + } + } + + public string? VerifyingKey { get; set; } + public string? SignedPublicKey { get; set; } + public required string PublicKey { get; set; } +} diff --git a/src/Core/KeyManagement/Models/Api/Response/SignatureKeyPairResponseModel.cs b/src/Core/KeyManagement/Models/Api/Response/SignatureKeyPairResponseModel.cs new file mode 100644 index 0000000000..34d51f8bd4 --- /dev/null +++ b/src/Core/KeyManagement/Models/Api/Response/SignatureKeyPairResponseModel.cs @@ -0,0 +1,32 @@ +using System.Text.Json.Serialization; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.Models.Api; + +namespace Bit.Core.KeyManagement.Models.Api.Response; + + +public class SignatureKeyPairResponseModel : ResponseModel +{ + [JsonPropertyName("wrappedSigningKey")] + public required string WrappedSigningKey { get; set; } + [JsonPropertyName("verifyingKey")] + public required string VerifyingKey { get; set; } + + [System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute] + public SignatureKeyPairResponseModel(SignatureKeyPairData signatureKeyPair) + : base("signatureKeyPair") + { + ArgumentNullException.ThrowIfNull(signatureKeyPair); + WrappedSigningKey = signatureKeyPair.WrappedSigningKey; + VerifyingKey = signatureKeyPair.VerifyingKey; + } + + + [JsonConstructor] + public SignatureKeyPairResponseModel(string wrappedSigningKey, string verifyingKey) + : base("signatureKeyPair") + { + WrappedSigningKey = wrappedSigningKey ?? throw new ArgumentNullException(nameof(wrappedSigningKey)); + VerifyingKey = verifyingKey ?? throw new ArgumentNullException(nameof(verifyingKey)); + } +} diff --git a/src/Core/KeyManagement/Models/Response/UserDecryptionResponseModel.cs b/src/Core/KeyManagement/Models/Api/Response/UserDecryptionResponseModel.cs similarity index 82% rename from src/Core/KeyManagement/Models/Response/UserDecryptionResponseModel.cs rename to src/Core/KeyManagement/Models/Api/Response/UserDecryptionResponseModel.cs index a4d259a00a..536347cea9 100644 --- a/src/Core/KeyManagement/Models/Response/UserDecryptionResponseModel.cs +++ b/src/Core/KeyManagement/Models/Api/Response/UserDecryptionResponseModel.cs @@ -1,4 +1,4 @@ -namespace Bit.Core.KeyManagement.Models.Response; +namespace Bit.Core.KeyManagement.Models.Api.Response; public class UserDecryptionResponseModel { diff --git a/src/Core/KeyManagement/Models/Data/MasterPasswordAuthenticationData.cs b/src/Core/KeyManagement/Models/Data/MasterPasswordAuthenticationData.cs index c0ae949a3f..1bc7006cef 100644 --- a/src/Core/KeyManagement/Models/Data/MasterPasswordAuthenticationData.cs +++ b/src/Core/KeyManagement/Models/Data/MasterPasswordAuthenticationData.cs @@ -1,4 +1,5 @@ using Bit.Core.Entities; +using Bit.Core.Exceptions; namespace Bit.Core.KeyManagement.Models.Data; @@ -12,7 +13,7 @@ public class MasterPasswordAuthenticationData { if (user.GetMasterPasswordSalt() != Salt) { - throw new ArgumentException("Invalid master password salt."); + throw new BadRequestException("Invalid master password salt."); } } } diff --git a/src/Core/KeyManagement/Models/Data/MasterPasswordUnlockAndAuthenticationData.cs b/src/Core/KeyManagement/Models/Data/MasterPasswordUnlockAndAuthenticationData.cs index e305d92fec..ad3a0b692b 100644 --- a/src/Core/KeyManagement/Models/Data/MasterPasswordUnlockAndAuthenticationData.cs +++ b/src/Core/KeyManagement/Models/Data/MasterPasswordUnlockAndAuthenticationData.cs @@ -13,6 +13,10 @@ public class MasterPasswordUnlockAndAuthenticationData public required string Email { get; set; } public required string MasterKeyAuthenticationHash { get; set; } + /// + /// The user's symmetric key encrypted with their master key. + /// Also known as "MasterKeyWrappedUserKey" + /// public required string MasterKeyEncryptedUserKey { get; set; } public string? MasterPasswordHint { get; set; } diff --git a/src/Core/KeyManagement/Models/Data/MasterPasswordUnlockData.cs b/src/Core/KeyManagement/Models/Data/MasterPasswordUnlockData.cs index d1ab6f645b..cb18ed2a78 100644 --- a/src/Core/KeyManagement/Models/Data/MasterPasswordUnlockData.cs +++ b/src/Core/KeyManagement/Models/Data/MasterPasswordUnlockData.cs @@ -1,6 +1,5 @@ -#nullable enable - -using Bit.Core.Entities; +using Bit.Core.Entities; +using Bit.Core.Exceptions; namespace Bit.Core.KeyManagement.Models.Data; @@ -14,7 +13,7 @@ public class MasterPasswordUnlockData { if (user.GetMasterPasswordSalt() != Salt) { - throw new ArgumentException("Invalid master password salt."); + throw new BadRequestException("Invalid master password salt."); } } } diff --git a/src/Core/KeyManagement/Models/Data/PublicKeyEncryptionKeyPairData.cs b/src/Core/KeyManagement/Models/Data/PublicKeyEncryptionKeyPairData.cs new file mode 100644 index 0000000000..fb8b09d390 --- /dev/null +++ b/src/Core/KeyManagement/Models/Data/PublicKeyEncryptionKeyPairData.cs @@ -0,0 +1,20 @@ +using System.Text.Json.Serialization; + +namespace Bit.Core.KeyManagement.Models.Data; + + +public class PublicKeyEncryptionKeyPairData +{ + public required string WrappedPrivateKey { get; set; } + public string? SignedPublicKey { get; set; } + public required string PublicKey { get; set; } + + [JsonConstructor] + [System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute] + public PublicKeyEncryptionKeyPairData(string wrappedPrivateKey, string publicKey, string? signedPublicKey = null) + { + WrappedPrivateKey = wrappedPrivateKey ?? throw new ArgumentNullException(nameof(wrappedPrivateKey)); + PublicKey = publicKey ?? throw new ArgumentNullException(nameof(publicKey)); + SignedPublicKey = signedPublicKey; + } +} diff --git a/src/Core/KeyManagement/Models/Data/RotateUserAccountKeysData.cs b/src/Core/KeyManagement/Models/Data/RotateUserAccountKeysData.cs index 557fb56ff3..19d14b273f 100644 --- a/src/Core/KeyManagement/Models/Data/RotateUserAccountKeysData.cs +++ b/src/Core/KeyManagement/Models/Data/RotateUserAccountKeysData.cs @@ -1,6 +1,4 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - + using Bit.Core.Auth.Entities; using Bit.Core.Auth.Models.Data; using Bit.Core.Entities; @@ -12,21 +10,19 @@ namespace Bit.Core.KeyManagement.Models.Data; public class RotateUserAccountKeysData { // Authentication for this requests - public string OldMasterKeyAuthenticationHash { get; set; } + public required string OldMasterKeyAuthenticationHash { get; set; } - // Other keys encrypted by the userkey - public string UserKeyEncryptedAccountPrivateKey { get; set; } - public string AccountPublicKey { get; set; } + public required UserAccountKeysData AccountKeys { get; set; } // All methods to get to the userkey - public MasterPasswordUnlockAndAuthenticationData MasterPasswordUnlockData { get; set; } - public IEnumerable EmergencyAccesses { get; set; } - public IReadOnlyList OrganizationUsers { get; set; } - public IEnumerable WebAuthnKeys { get; set; } - public IEnumerable DeviceKeys { get; set; } + public required MasterPasswordUnlockAndAuthenticationData MasterPasswordUnlockData { get; set; } + public required IEnumerable EmergencyAccesses { get; set; } + public required IReadOnlyList OrganizationUsers { get; set; } + public required IEnumerable WebAuthnKeys { get; set; } + public required IEnumerable DeviceKeys { get; set; } // User vault data encrypted by the userkey - public IEnumerable Ciphers { get; set; } - public IEnumerable Folders { get; set; } - public IReadOnlyList Sends { get; set; } + public required IEnumerable Ciphers { get; set; } + public required IEnumerable Folders { get; set; } + public required IReadOnlyList Sends { get; set; } } diff --git a/src/Core/KeyManagement/Models/Data/SecurityStateData.cs b/src/Core/KeyManagement/Models/Data/SecurityStateData.cs new file mode 100644 index 0000000000..c9a4610387 --- /dev/null +++ b/src/Core/KeyManagement/Models/Data/SecurityStateData.cs @@ -0,0 +1,10 @@ + +namespace Bit.Core.KeyManagement.Models.Data; + +public class SecurityStateData +{ + public required string SecurityState { get; set; } + // The security version is included in the security state, but needs COSE parsing, + // so this is a separate copy that can be used directly. + public required int SecurityVersion { get; set; } +} diff --git a/src/Core/KeyManagement/Models/Data/SignatureKeyPairData.cs b/src/Core/KeyManagement/Models/Data/SignatureKeyPairData.cs new file mode 100644 index 0000000000..32ae3eef8f --- /dev/null +++ b/src/Core/KeyManagement/Models/Data/SignatureKeyPairData.cs @@ -0,0 +1,21 @@ + +using System.Text.Json.Serialization; +using Bit.Core.KeyManagement.Enums; + +namespace Bit.Core.KeyManagement.Models.Data; + +public class SignatureKeyPairData +{ + public required SignatureAlgorithm SignatureAlgorithm { get; set; } + public required string WrappedSigningKey { get; set; } + public required string VerifyingKey { get; set; } + + [JsonConstructor] + [System.Diagnostics.CodeAnalysis.SetsRequiredMembersAttribute] + public SignatureKeyPairData(SignatureAlgorithm signatureAlgorithm, string wrappedSigningKey, string verifyingKey) + { + SignatureAlgorithm = signatureAlgorithm; + WrappedSigningKey = wrappedSigningKey ?? throw new ArgumentNullException(nameof(wrappedSigningKey)); + VerifyingKey = verifyingKey ?? throw new ArgumentNullException(nameof(verifyingKey)); + } +} diff --git a/src/Core/KeyManagement/Models/Data/UserAccountKeysData.cs b/src/Core/KeyManagement/Models/Data/UserAccountKeysData.cs new file mode 100644 index 0000000000..cabdca59ea --- /dev/null +++ b/src/Core/KeyManagement/Models/Data/UserAccountKeysData.cs @@ -0,0 +1,9 @@ +namespace Bit.Core.KeyManagement.Models.Data; + + +public class UserAccountKeysData +{ + public required PublicKeyEncryptionKeyPairData PublicKeyEncryptionKeyPairData { get; set; } + public SignatureKeyPairData? SignatureKeyPairData { get; set; } + public SecurityStateData? SecurityStateData { get; set; } +} diff --git a/src/Core/KeyManagement/Queries/Interfaces/IUserAcountKeysQuery.cs b/src/Core/KeyManagement/Queries/Interfaces/IUserAcountKeysQuery.cs new file mode 100644 index 0000000000..4ea9b7582b --- /dev/null +++ b/src/Core/KeyManagement/Queries/Interfaces/IUserAcountKeysQuery.cs @@ -0,0 +1,10 @@ + +using Bit.Core.Entities; +using Bit.Core.KeyManagement.Models.Data; + +namespace Bit.Core.KeyManagement.Queries.Interfaces; + +public interface IUserAccountKeysQuery +{ + Task Run(User user); +} diff --git a/src/Core/KeyManagement/Queries/UserAccountKeysQuery.cs b/src/Core/KeyManagement/Queries/UserAccountKeysQuery.cs new file mode 100644 index 0000000000..7aafd2cf1e --- /dev/null +++ b/src/Core/KeyManagement/Queries/UserAccountKeysQuery.cs @@ -0,0 +1,35 @@ + +using Bit.Core.Entities; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Queries.Interfaces; +using Bit.Core.KeyManagement.Repositories; + +namespace Bit.Core.KeyManagement.Queries; + + +public class UserAccountKeysQuery(IUserSignatureKeyPairRepository signatureKeyPairRepository) : IUserAccountKeysQuery +{ + public async Task Run(User user) + { + if (user.GetSecurityVersion() < 2) + { + return new UserAccountKeysData + { + PublicKeyEncryptionKeyPairData = user.GetPublicKeyEncryptionKeyPair(), + }; + } + else + { + return new UserAccountKeysData + { + PublicKeyEncryptionKeyPairData = user.GetPublicKeyEncryptionKeyPair(), + SignatureKeyPairData = await signatureKeyPairRepository.GetByUserIdAsync(user.Id), + SecurityStateData = new SecurityStateData + { + SecurityState = user.SecurityState!, + SecurityVersion = user.GetSecurityVersion(), + } + }; + } + } +} diff --git a/src/Core/KeyManagement/Repositories/IUserSignatureKeyPairRepository.cs b/src/Core/KeyManagement/Repositories/IUserSignatureKeyPairRepository.cs new file mode 100644 index 0000000000..ce8979620f --- /dev/null +++ b/src/Core/KeyManagement/Repositories/IUserSignatureKeyPairRepository.cs @@ -0,0 +1,14 @@ + +using Bit.Core.KeyManagement.Entities; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.UserKey; +using Bit.Core.Repositories; + +namespace Bit.Core.KeyManagement.Repositories; + +public interface IUserSignatureKeyPairRepository : IRepository +{ + public Task GetByUserIdAsync(Guid userId); + public UpdateEncryptedDataForKeyRotation UpdateForKeyRotation(Guid grantorId, SignatureKeyPairData signatureKeyPair); + public UpdateEncryptedDataForKeyRotation SetUserSignatureKeyPair(Guid userId, SignatureKeyPairData signatureKeyPair); +} diff --git a/src/Core/KeyManagement/UserKey/Implementations/RotateUserAccountkeysCommand.cs b/src/Core/KeyManagement/UserKey/Implementations/RotateUserAccountkeysCommand.cs index 011fc2932f..c1e7905d78 100644 --- a/src/Core/KeyManagement/UserKey/Implementations/RotateUserAccountkeysCommand.cs +++ b/src/Core/KeyManagement/UserKey/Implementations/RotateUserAccountkeysCommand.cs @@ -1,6 +1,11 @@ -using Bit.Core.Auth.Repositories; +// FIXME: Update this file to be null safe and then delete the line below +#nullable disable + +using Bit.Core.Auth.Repositories; using Bit.Core.Entities; +using Bit.Core.Enums; using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Repositories; using Bit.Core.Platform.Push; using Bit.Core.Repositories; using Bit.Core.Services; @@ -25,6 +30,7 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand private readonly IdentityErrorDescriber _identityErrorDescriber; private readonly IWebAuthnCredentialRepository _credentialRepository; private readonly IPasswordHasher _passwordHasher; + private readonly IUserSignatureKeyPairRepository _userSignatureKeyPairRepository; private readonly IFeatureService _featureService; /// @@ -37,16 +43,19 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand /// Provides a method to update re-encrypted send data /// Provides a method to update re-encrypted emergency access data /// Provides a method to update re-encrypted organization user data + /// Provides a method to update re-encrypted device keys /// Hashes the new master password /// Logs out user from other devices after successful rotation /// Provides a password mismatch error if master password hash validation fails /// Provides a method to update re-encrypted WebAuthn keys + /// Provides a method to update re-encrypted signature keys public RotateUserAccountKeysCommand(IUserService userService, IUserRepository userRepository, ICipherRepository cipherRepository, IFolderRepository folderRepository, ISendRepository sendRepository, IEmergencyAccessRepository emergencyAccessRepository, IOrganizationUserRepository organizationUserRepository, IDeviceRepository deviceRepository, IPasswordHasher passwordHasher, IPushNotificationService pushService, IdentityErrorDescriber errors, IWebAuthnCredentialRepository credentialRepository, + IUserSignatureKeyPairRepository userSignatureKeyPairRepository, IFeatureService featureService) { _userService = userService; @@ -61,6 +70,7 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand _identityErrorDescriber = errors; _credentialRepository = credentialRepository; _passwordHasher = passwordHasher; + _userSignatureKeyPairRepository = userSignatureKeyPairRepository; _featureService = featureService; } @@ -82,58 +92,106 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand user.LastKeyRotationDate = now; user.SecurityStamp = Guid.NewGuid().ToString(); - if ( - !model.MasterPasswordUnlockData.ValidateForUser(user) - ) + List saveEncryptedDataActions = []; + + await UpdateAccountKeysAsync(model, user, saveEncryptedDataActions); + UpdateUnlockMethods(model, user, saveEncryptedDataActions); + UpdateUserData(model, user, saveEncryptedDataActions); + + await _userRepository.UpdateUserKeyAndEncryptedDataV2Async(user, saveEncryptedDataActions); + await _pushService.PushLogOutAsync(user.Id); + return IdentityResult.Success; + } + + public async Task RotateV2AccountKeysAsync(RotateUserAccountKeysData model, User user, List saveEncryptedDataActions) + { + ValidateV2Encryption(model); + await ValidateVerifyingKeyUnchangedAsync(model, user); + + saveEncryptedDataActions.Add(_userSignatureKeyPairRepository.UpdateForKeyRotation(user.Id, model.AccountKeys.SignatureKeyPairData)); + user.SignedPublicKey = model.AccountKeys.PublicKeyEncryptionKeyPairData.SignedPublicKey; + user.SecurityState = model.AccountKeys.SecurityStateData!.SecurityState; + user.SecurityVersion = model.AccountKeys.SecurityStateData.SecurityVersion; + } + + public void UpgradeV1ToV2Keys(RotateUserAccountKeysData model, User user, List saveEncryptedDataActions) + { + ValidateV2Encryption(model); + saveEncryptedDataActions.Add(_userSignatureKeyPairRepository.SetUserSignatureKeyPair(user.Id, model.AccountKeys.SignatureKeyPairData)); + user.SignedPublicKey = model.AccountKeys.PublicKeyEncryptionKeyPairData.SignedPublicKey; + user.SecurityState = model.AccountKeys.SecurityStateData!.SecurityState; + user.SecurityVersion = model.AccountKeys.SecurityStateData.SecurityVersion; + } + + public async Task UpdateAccountKeysAsync(RotateUserAccountKeysData model, User user, List saveEncryptedDataActions) + { + ValidatePublicKeyEncryptionKeyPairUnchanged(model, user); + + if (IsV2EncryptionUserAsync(user)) { - throw new InvalidOperationException("The provided master password unlock data is not valid for this user."); + await RotateV2AccountKeysAsync(model, user, saveEncryptedDataActions); } - if ( - model.AccountPublicKey != user.PublicKey - ) + else if (model.AccountKeys.SignatureKeyPairData != null) { - throw new InvalidOperationException("The provided account public key does not match the user's current public key, and changing the account asymmetric keypair is currently not supported during key rotation."); + UpgradeV1ToV2Keys(model, user, saveEncryptedDataActions); + } + else + { + if (GetEncryptionType(model.AccountKeys.PublicKeyEncryptionKeyPairData.WrappedPrivateKey) != EncryptionType.AesCbc256_HmacSha256_B64) + { + throw new InvalidOperationException("The provided account private key was not wrapped with AES-256-CBC-HMAC"); + } + // V1 user to V1 user rotation needs to further changes, the private key was re-encrypted. } - user.Key = model.MasterPasswordUnlockData.MasterKeyEncryptedUserKey; - user.PrivateKey = model.UserKeyEncryptedAccountPrivateKey; - user.MasterPassword = _passwordHasher.HashPassword(user, model.MasterPasswordUnlockData.MasterKeyAuthenticationHash); - user.MasterPasswordHint = model.MasterPasswordUnlockData.MasterPasswordHint; + // Private key is re-wrapped with new user key by client + user.PrivateKey = model.AccountKeys.PublicKeyEncryptionKeyPairData.WrappedPrivateKey; + } + + public void UpdateUserData(RotateUserAccountKeysData model, User user, List saveEncryptedDataActions) + { + // The revision date has to be updated so that de-synced clients don't accidentally post over the re-encrypted data + // with an old-user key-encrypted copy + var now = DateTime.UtcNow; - List saveEncryptedDataActions = new(); if (model.Ciphers.Any()) { - var useBulkResourceCreationService = _featureService.IsEnabled(FeatureFlagKeys.CipherRepositoryBulkResourceCreation); - if (useBulkResourceCreationService) - { - saveEncryptedDataActions.Add(_cipherRepository.UpdateForKeyRotation_vNext(user.Id, model.Ciphers)); - } - else - { - saveEncryptedDataActions.Add(_cipherRepository.UpdateForKeyRotation(user.Id, model.Ciphers)); - } + var ciphersWithUpdatedDate = model.Ciphers.ToList().Select(c => { c.RevisionDate = now; return c; }); + saveEncryptedDataActions.Add(_cipherRepository.UpdateForKeyRotation(user.Id, ciphersWithUpdatedDate)); } if (model.Folders.Any()) { - saveEncryptedDataActions.Add(_folderRepository.UpdateForKeyRotation(user.Id, model.Folders)); + var foldersWithUpdatedDate = model.Folders.ToList().Select(f => { f.RevisionDate = now; return f; }); + saveEncryptedDataActions.Add(_folderRepository.UpdateForKeyRotation(user.Id, foldersWithUpdatedDate)); } if (model.Sends.Any()) { - saveEncryptedDataActions.Add(_sendRepository.UpdateForKeyRotation(user.Id, model.Sends)); + var sendsWithUpdatedDate = model.Sends.ToList().Select(s => { s.RevisionDate = now; return s; }); + saveEncryptedDataActions.Add(_sendRepository.UpdateForKeyRotation(user.Id, sendsWithUpdatedDate)); } + } + + void UpdateUnlockMethods(RotateUserAccountKeysData model, User user, List saveEncryptedDataActions) + { + if (!model.MasterPasswordUnlockData.ValidateForUser(user)) + { + throw new InvalidOperationException("The provided master password unlock data is not valid for this user."); + } + // Update master password authentication & unlock + user.Key = model.MasterPasswordUnlockData.MasterKeyEncryptedUserKey; + user.MasterPassword = _passwordHasher.HashPassword(user, model.MasterPasswordUnlockData.MasterKeyAuthenticationHash); + user.MasterPasswordHint = model.MasterPasswordUnlockData.MasterPasswordHint; if (model.EmergencyAccesses.Any()) { - saveEncryptedDataActions.Add( - _emergencyAccessRepository.UpdateForKeyRotation(user.Id, model.EmergencyAccesses)); + saveEncryptedDataActions.Add(_emergencyAccessRepository.UpdateForKeyRotation(user.Id, model.EmergencyAccesses)); } if (model.OrganizationUsers.Any()) { - saveEncryptedDataActions.Add( - _organizationUserRepository.UpdateForKeyRotation(user.Id, model.OrganizationUsers)); + saveEncryptedDataActions.Add(_organizationUserRepository.UpdateForKeyRotation(user.Id, model.OrganizationUsers)); } if (model.WebAuthnKeys.Any()) @@ -145,9 +203,80 @@ public class RotateUserAccountKeysCommand : IRotateUserAccountKeysCommand { saveEncryptedDataActions.Add(_deviceRepository.UpdateKeysForRotationAsync(user.Id, model.DeviceKeys)); } + } - await _userRepository.UpdateUserKeyAndEncryptedDataV2Async(user, saveEncryptedDataActions); - await _pushService.PushLogOutAsync(user.Id); - return IdentityResult.Success; + private bool IsV2EncryptionUserAsync(User user) + { + // Returns whether the user is a V2 user based on the private key's encryption type. + ArgumentNullException.ThrowIfNull(user); + var isPrivateKeyEncryptionV2 = GetEncryptionType(user.PrivateKey) == EncryptionType.XChaCha20Poly1305_B64; + return isPrivateKeyEncryptionV2; + } + + private async Task ValidateVerifyingKeyUnchangedAsync(RotateUserAccountKeysData model, User user) + { + var currentSignatureKeyPair = await _userSignatureKeyPairRepository.GetByUserIdAsync(user.Id) ?? throw new InvalidOperationException("User does not have a signature key pair."); + if (model.AccountKeys.SignatureKeyPairData.VerifyingKey != currentSignatureKeyPair!.VerifyingKey) + { + throw new InvalidOperationException("The provided verifying key does not match the user's current verifying key."); + } + } + + private static void ValidatePublicKeyEncryptionKeyPairUnchanged(RotateUserAccountKeysData model, User user) + { + var publicKey = model.AccountKeys.PublicKeyEncryptionKeyPairData.PublicKey; + if (publicKey != user.PublicKey) + { + throw new InvalidOperationException("The provided account public key does not match the user's current public key, and changing the account asymmetric key pair is currently not supported during key rotation."); + } + } + + private static void ValidateV2Encryption(RotateUserAccountKeysData model) + { + if (model.AccountKeys.SignatureKeyPairData == null) + { + throw new InvalidOperationException("Signature key pair data is required for V2 encryption."); + } + if (GetEncryptionType(model.AccountKeys.SignatureKeyPairData.WrappedSigningKey) != EncryptionType.XChaCha20Poly1305_B64) + { + throw new InvalidOperationException("The provided signing key data is not wrapped with XChaCha20-Poly1305."); + } + if (string.IsNullOrEmpty(model.AccountKeys.SignatureKeyPairData.VerifyingKey)) + { + throw new InvalidOperationException("The provided signature key pair data does not contain a valid verifying key."); + } + + if (GetEncryptionType(model.AccountKeys.PublicKeyEncryptionKeyPairData.WrappedPrivateKey) != EncryptionType.XChaCha20Poly1305_B64) + { + throw new InvalidOperationException("The provided private key encryption key is not wrapped with XChaCha20-Poly1305."); + } + if (string.IsNullOrEmpty(model.AccountKeys.PublicKeyEncryptionKeyPairData.SignedPublicKey)) + { + throw new InvalidOperationException("No signed public key provided, but the user already has a signature key pair."); + } + if (model.AccountKeys.SecurityStateData == null || string.IsNullOrEmpty(model.AccountKeys.SecurityStateData.SecurityState)) + { + throw new InvalidOperationException("No signed security state provider for V2 user"); + } + } + + /// + /// Helper method to convert an encryption type string to an enum value. + /// + private static EncryptionType GetEncryptionType(string encString) + { + var parts = encString.Split('.'); + if (parts.Length == 1) + { + throw new ArgumentException("Invalid encryption type string."); + } + if (byte.TryParse(parts[0], out var encryptionTypeNumber)) + { + if (Enum.IsDefined(typeof(EncryptionType), encryptionTypeNumber)) + { + return (EncryptionType)encryptionTypeNumber; + } + } + throw new ArgumentException("Invalid encryption type string."); } } diff --git a/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmailv2.html.hbs b/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmailv2.html.hbs new file mode 100644 index 0000000000..fad0af840d --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmailv2.html.hbs @@ -0,0 +1,682 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + + +
+ + + + + + + +
+ + + + + + + + +
+ + + + + +
+ + + + + + + +
+ + +
+ + + + + + + + + + + + + +
+ + + + + + + +
+ + + +
+ +
+ +

+ Verify your email to access this Bitwarden Send +

+ +
+ +
+ + + +
+ + + + + + + + + +
+ + + + + + + +
+ + + +
+ +
+ +
+ + +
+ +
+ + + + + +
+ + +
+ +
+ + + + + + + + + +
+ + + + + + + +
+ + + +
+ + + + + + + +
+ + +
+ + + + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
+ +
Your verification code is:
+ +
+ +
{{Token}}
+ +
+ +
+ +
+ +
This code expires in {{Expiry}} minutes. After that, you'll need to + verify your email again.
+ +
+ +
+ +
+ + +
+ +
+ + + + + +
+ + + + + + + +
+ + +
+ + + + + + + +
+ + + + + + + + + +
+ +

+ Bitwarden Send transmits sensitive, temporary information to + others easily and securely. Learn more about + Bitwarden Send + or + sign up + to try it today. +

+ +
+ +
+ +
+ + +
+ +
+ + + +
+ +
+ + + + + + + + + +
+ + + + + + + +
+ + + +
+ + + + + + + +
+ + +
+ + + + + + + + + +
+ +

+ Learn more about Bitwarden +

+ Find user guides, product documentation, and videos on the + Bitwarden Help Center.
+ +
+ +
+ + + +
+ + + + + + + + + +
+ + + + + + + +
+ + + +
+ +
+ +
+ + +
+ +
+ + + +
+ +
+ + + + + + + + + +
+ + + + + + + +
+ + +
+ + + + + + + + + + + + + +
+ + + + + + + + + + + + +
+ + + + + + +
+ + + +
+
+ + + + + + + + + + +
+ + + + + + +
+ + + +
+
+ + + + + + + + + + +
+ + + + + + +
+ + + +
+
+ + + + + + + + + + +
+ + + + + + +
+ + + +
+
+ + + + + + + + + + +
+ + + + + + +
+ + + +
+
+ + + + + + + + + + +
+ + + + + + +
+ + + +
+
+ + + + + + + + + + +
+ + + + + + +
+ + + +
+
+ + + +
+ +

+ © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + Barbara, CA, USA +

+

+ Always confirm you are on a trusted Bitwarden domain before logging + in:
+ bitwarden.com | + Learn why we include this +

+ +
+ +
+ + +
+ +
+ + + + + +
+ + + diff --git a/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmailv2.text.hbs b/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmailv2.text.hbs new file mode 100644 index 0000000000..7c9c1db527 --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmailv2.text.hbs @@ -0,0 +1,9 @@ +{{#>BasicTextLayout}} +Verify your email to access this Bitwarden Send. + +Your verification code is: {{Token}} + +This code can only be used once and expires in {{Expiry}} minutes. After that you'll need to verify your email again. + +Bitwarden Send transmits sensitive, temporary information to others easily and securely. Learn more about Bitwarden Send or sign up to try it today. +{{/BasicTextLayout}} diff --git a/src/Core/MailTemplates/Mjml/.mjmlconfig b/src/Core/MailTemplates/Mjml/.mjmlconfig index 7560e0fb96..92734a5f71 100644 --- a/src/Core/MailTemplates/Mjml/.mjmlconfig +++ b/src/Core/MailTemplates/Mjml/.mjmlconfig @@ -1,5 +1,7 @@ { "packages": [ - "components/hero" + "components/mj-bw-hero", + "components/mj-bw-icon-row", + "components/mj-bw-learn-more-footer" ] } diff --git a/src/Core/MailTemplates/Mjml/README.md b/src/Core/MailTemplates/Mjml/README.md index b60655140a..b9041c94f6 100644 --- a/src/Core/MailTemplates/Mjml/README.md +++ b/src/Core/MailTemplates/Mjml/README.md @@ -1,19 +1,123 @@ -# Email templates +# MJML email templating -This directory contains MJML templates for emails sent by the application. MJML is a markup language designed to reduce the pain of coding responsive email templates. +This directory contains MJML templates for emails. MJML is a markup language designed to reduce the pain of coding responsive email templates. Component based development features in MJML improve code quality and reusability. -## Usage +MJML stands for MailJet Markup Language. -```bash +## Implementation considerations + +These `MJML` templates are compiled into HTML which will then be further consumed by our Handlebars mail service. We can continue to use this service to assign values from our View Models. This leverages the existing infrastructure. It also means we can continue to use the double brace (`{{}}`) syntax within MJML since Handlebars can be used to assign values to those `{{variables}}`. + +There is no change on how we interact with our view models. + +There is an added step where we compile `*.mjml` to `*.html.hbs`. `*.html.hbs` is the format we use so the handlebars service can apply the variables. This build pipeline process is in progress and may need to be manually done at times. + +### `*.txt.hbs` + +There is no change to how we create the `txt.hbs`. MJML does not impact how we create these artifacts. + +## Building `MJML` files + +```shell npm ci -# Build once +# Build *.html to ./out directory npm run build -# To build on changes -npm run watch +# To build on changes to *.mjml and *.js files, new files will not be tracked, you will need to run again +npm run build:watch + +# Build *.html.hbs to ./out directory +npm run build:hbs + +# Build minified *.html.hbs to ./out directory +npm run build:minify + +# apply prettier formatting +npm run prettier ``` ## Development -MJML supports components and you can create your own components by adding them to `.mjmlconfig`. +MJML supports components and you can create your own components by adding them to `.mjmlconfig`. Components are simple JavaScript that return MJML markup based on the attributes assigned, see components/mj-bw-hero.js. The markup is not a proper object, but contained in a string. + +When using MJML templating you can use the above [commands](#building-mjml-files) to compile the template and view it in a web browser. + +Not all MJML tags have the same attributes, it is highly recommended to review the documentation on the official MJML website to understand the usages of each of the tags. + +### Recommended development - IMailService + +#### Mjml email template development + +1. create `cool-email.mjml` in appropriate team directory +2. run `npm run build:watch` +3. view compiled `HTML` output in a web browser +4. iterate -> while `build:watch`'ing you should be able to refresh the browser page after the mjml/js re-compile to see the changes + +#### Testing with `IMailService` + +After the email is developed from the [initial step](#mjml-email-template-development) make sure the email `{{variables}}` are populated properly by running it through an `IMailService` implementation. + +1. run `npm run build:hbs` +2. copy built `*.html.hbs` files from the build directory to a location the mail service can consume them + 1. all files in the `Core/MailTemplates/Mjml/out` directory can be copied to the `src/Core/MailTemplates/Handlebars/MJML` directory. If a shared component is modified it is important to copy and overwrite all files in that directory to capture + changes in the `*.html.hbs`. +3. run code that will send the email + +The minified `html.hbs` artifacts are deliverables and must be placed into the correct `src/Core/MailTemplates/Handlebars/` directories in order to be used by `IMailService` implementations, see 2.1 above. + +### Recommended development - IMailer + +TBD - PM-26475 + +### Custom tags + +There is currently a `mj-bw-hero` tag you can use within your `*.mjml` templates. This is a good example of how to create a component that takes in attribute values allowing us to be more DRY in our development of emails. Since the attribute's input is a string we are able to define whatever we need into the component, in this case `mj-bw-hero`. + +In order to view the custom component you have written you will need to include it in the `.mjmlconfig` and reference it in an `mjml` template file. + +```html + + +``` + +Attributes in Custom Components are defined by the developer. They can be required or optional depending on implementation. See the official MJML documentation for more information. + +```js +static allowedAttributes = { + "img-src": "string", // REQUIRED: Source for the image displayed in the right-hand side of the blue header area + title: "string", // REQUIRED: large text stating primary purpose of the email + "button-text": "string", // OPTIONAL: text to display in the button + "button-url": "string", // OPTIONAL: URL to navigate to when the button is clicked + "sub-title": "string", // OPTIONAL: smaller text providing additional context for the title +}; + +static defaultAttributes = {}; +``` + +Custom components, such as `mj-bw-hero`, must be defined in the `.mjmlconfig` in order for them to be compiled and rendered properly in the templates. + +```json +{ + "packages": ["components/mj-bw-hero"] +} +``` + +### `mj-include` + +You are also able to reference other more static MJML templates in your MJML file simply by referencing the file within the MJML template. + +```html + + + + +``` + +#### `head.mjml` +Currently we include the `head.mjml` file in all MJML templates as it contains shared styling and formatting that ensures consistency across all email implementations. + +In the future we may deviate from this practice to support different layouts. At that time we will modify the docs with direction. diff --git a/src/Core/MailTemplates/Mjml/build.js b/src/Core/MailTemplates/Mjml/build.js new file mode 100644 index 0000000000..db8a7fe433 --- /dev/null +++ b/src/Core/MailTemplates/Mjml/build.js @@ -0,0 +1,128 @@ +const mjml2html = require("mjml"); +const { registerComponent } = require("mjml-core"); +const fs = require("fs"); +const path = require("path"); +const glob = require("glob"); + +// Parse command line arguments +const args = process.argv.slice(2); // Remove 'node' and script path + +// Parse flags +const flags = { + minify: args.includes("--minify") || args.includes("-m"), + watch: args.includes("--watch") || args.includes("-w"), + hbs: args.includes("--hbs") || args.includes("-h"), + trace: args.includes("--trace") || args.includes("-t"), + clean: args.includes("--clean") || args.includes("-c"), + help: args.includes("--help"), +}; + +// Use __dirname to get absolute paths relative to the script location +const config = { + inputDir: path.join(__dirname, "emails"), + outputDir: path.join(__dirname, "out"), + minify: flags.minify, + validationLevel: "strict", + hbsOutput: flags.hbs, +}; + +// Debug output +if (flags.trace) { + console.log("[DEBUG] Script location:", __dirname); + console.log("[DEBUG] Input directory:", config.inputDir); + console.log("[DEBUG] Output directory:", config.outputDir); +} + +// Ensure output directory exists +if (!fs.existsSync(config.outputDir)) { + fs.mkdirSync(config.outputDir, { recursive: true }); + if (flags.trace) { + console.log("[INFO] Created output directory:", config.outputDir); + } +} + +// Find all MJML files with absolute path +const mjmlFiles = glob.sync(`${config.inputDir}/**/*.mjml`); + +console.log(`\n[INFO] Found ${mjmlFiles.length} MJML file(s) to compile...`); + +if (mjmlFiles.length === 0) { + console.error("[ERROR] No MJML files found!"); + console.error("[ERROR] Looked in:", config.inputDir); + console.error( + "[ERROR] Does this directory exist?", + fs.existsSync(config.inputDir), + ); + process.exit(1); +} + +// Compile each MJML file +let successCount = 0; +let errorCount = 0; + +mjmlFiles.forEach((filePath) => { + try { + const mjmlContent = fs.readFileSync(filePath, "utf8"); + const fileName = path.basename(filePath, ".mjml"); + const relativePath = path.relative(config.inputDir, filePath); + + console.log(`\n[BUILD] Compiling: ${relativePath}`); + + // Compile MJML to HTML + const result = mjml2html(mjmlContent, { + minify: config.minify, + validationLevel: config.validationLevel, + filePath: filePath, // Important: tells MJML where the file is for resolving includes + mjmlConfigPath: __dirname, // Point to the directory with .mjmlconfig + }); + + // Check for errors + if (result.errors.length > 0) { + console.error(`[ERROR] Failed to compile ${fileName}.mjml:`); + result.errors.forEach((err) => + console.error(` ${err.formattedMessage}`), + ); + errorCount++; + return; + } + + // Calculate output path preserving directory structure + const relativeDir = path.dirname(relativePath); + const outputDir = path.join(config.outputDir, relativeDir); + + // Ensure subdirectory exists + if (!fs.existsSync(outputDir)) { + fs.mkdirSync(outputDir, { recursive: true }); + } + + const outputExtension = config.hbsOutput ? ".html.hbs" : ".html"; + const outputPath = path.join(outputDir, `${fileName}${outputExtension}`); + fs.writeFileSync(outputPath, result.html); + + console.log( + `[OK] Built: ${fileName}.mjml → ${path.relative(__dirname, outputPath)}`, + ); + successCount++; + + // Log warnings if any + if (result.warnings && result.warnings.length > 0) { + console.warn(`[WARN] Warnings for ${fileName}.mjml:`); + result.warnings.forEach((warn) => + console.warn(` ${warn.formattedMessage}`), + ); + } + } catch (error) { + console.error(`[ERROR] Exception processing ${path.basename(filePath)}:`); + console.error(` ${error.message}`); + errorCount++; + } +}); + +console.log(`\n[SUMMARY] Compilation complete!`); +console.log(` Success: ${successCount}`); +console.log(` Failed: ${errorCount}`); +console.log(` Output: ${config.outputDir}`); + +if (errorCount > 0) { + process.exit(1); +} diff --git a/src/Core/MailTemplates/Mjml/build.sh b/src/Core/MailTemplates/Mjml/build.sh deleted file mode 100755 index c76bdd8f61..0000000000 --- a/src/Core/MailTemplates/Mjml/build.sh +++ /dev/null @@ -1,4 +0,0 @@ -# TODO: This should probably be replaced with a node script building every file in `emails/` - -npx mjml emails/invite.mjml -o out/invite.html -npx mjml emails/two-factor.mjml -o out/two-factor.html diff --git a/src/Core/MailTemplates/Mjml/components/footer.mjml b/src/Core/MailTemplates/Mjml/components/footer.mjml index 0634033618..2b2268f33b 100644 --- a/src/Core/MailTemplates/Mjml/components/footer.mjml +++ b/src/Core/MailTemplates/Mjml/components/footer.mjml @@ -2,38 +2,38 @@ @@ -45,8 +45,8 @@

Always confirm you are on a trusted Bitwarden domain before logging in:
- bitwarden.com | - Learn why we include this + bitwarden.com | + Learn why we include this

diff --git a/src/Core/MailTemplates/Mjml/components/head.mjml b/src/Core/MailTemplates/Mjml/components/head.mjml index 929057fb70..4cb27889eb 100644 --- a/src/Core/MailTemplates/Mjml/components/head.mjml +++ b/src/Core/MailTemplates/Mjml/components/head.mjml @@ -4,13 +4,21 @@ font-size="16px" /> - + - .link { text-decoration: none; color: #175ddc; font-weight: 600 } + .link { + text-decoration: none; + color: #175ddc; + font-weight: 600; + } - .border-fix > table { border-collapse:separate !important; } .border-fix > - table > tbody > tr > td { border-radius: 3px; } + .border-fix > table { + border-collapse: separate !important; + } + .border-fix > table > tbody > tr > td { + border-radius: 3px; + } diff --git a/src/Core/MailTemplates/Mjml/components/hero.js b/src/Core/MailTemplates/Mjml/components/mj-bw-hero.js similarity index 51% rename from src/Core/MailTemplates/Mjml/components/hero.js rename to src/Core/MailTemplates/Mjml/components/mj-bw-hero.js index 6c5bd9bc99..c7a3b7e7ff 100644 --- a/src/Core/MailTemplates/Mjml/components/hero.js +++ b/src/Core/MailTemplates/Mjml/components/mj-bw-hero.js @@ -9,20 +9,50 @@ class MjBwHero extends BodyComponent { }; static allowedAttributes = { - "img-src": "string", - title: "string", - "button-text": "string", - "button-url": "string", + "img-src": "string", // REQUIRED: Source for the image displayed in the right-hand side of the blue header area + title: "string", // REQUIRED: large text stating primary purpose of the email + "button-text": "string", // OPTIONAL: text to display in the button + "button-url": "string", // OPTIONAL: URL to navigate to when the button is clicked + "sub-title": "string", // OPTIONAL: smaller text providing additional context for the title }; static defaultAttributes = {}; + componentHeadStyle = breakpoint => { + return ` + @media only screen and (max-width:${breakpoint}) { + .mj-bw-hero-responsive-img { + display: none !important; + } + } + ` + } + render() { - return this.renderMJML(` + const buttonElement = this.getAttribute("button-text") && this.getAttribute("button-url") ? + ` + ${this.getAttribute("button-text")} + ` : ""; + const subTitleElement = this.getAttribute("sub-title") ? + ` +

+ ${this.getAttribute("sub-title")} +

+
` : ""; + + return this.renderMJML( + ` ${this.getAttribute("title")} -
- - ${this.getAttribute("button-text")} - + ` + + subTitleElement + + ` +
` + + buttonElement + + ` + width="155px" + padding="0px" + css-class="mj-bw-hero-responsive-img" + /> - `); + `, + ); } } diff --git a/src/Core/MailTemplates/Mjml/components/mj-bw-icon-row.js b/src/Core/MailTemplates/Mjml/components/mj-bw-icon-row.js new file mode 100644 index 0000000000..f7f402c96e --- /dev/null +++ b/src/Core/MailTemplates/Mjml/components/mj-bw-icon-row.js @@ -0,0 +1,100 @@ +const { BodyComponent } = require("mjml-core"); +class MjBwIconRow extends BodyComponent { + static dependencies = { + "mj-column": ["mj-bw-icon-row"], + "mj-wrapper": ["mj-bw-icon-row"], + "mj-bw-icon-row": [], + }; + + static allowedAttributes = { + "icon-src": "string", + "icon-alt": "string", + "head-url-text": "string", + "head-url": "string", + text: "string", + "foot-url-text": "string", + "foot-url": "string", + }; + + static defaultAttributes = {}; + + componentHeadStyle = (breakpoint) => { + return ` + @media only screen and (max-width:${breakpoint}): { + ".mj-bw-icon-row-text": { + padding-left: "5px !important", + line-height: "20px", + }, + ".mj-bw-icon-row": { + padding: "10px 15px", + width: "fit-content !important", + } + } + `; + }; + + render() { + const headAnchorElement = + this.getAttribute("head-url-text") && this.getAttribute("head-url") + ? ` + ${this.getAttribute("head-url-text")} + + External Link Icon + + ` + : ""; + + const footAnchorElement = + this.getAttribute("foot-url-text") && this.getAttribute("foot-url") + ? ` + ${this.getAttribute("foot-url-text")} + + External Link Icon + + ` + : ""; + + return this.renderMJML( + ` + + + + + + + + ` + + headAnchorElement + + ` + + + ${this.getAttribute("text")} + + + ` + + footAnchorElement + + ` + + + + + `, + ); + } +} + +module.exports = MjBwIconRow; diff --git a/src/Core/MailTemplates/Mjml/components/mj-bw-learn-more-footer.js b/src/Core/MailTemplates/Mjml/components/mj-bw-learn-more-footer.js new file mode 100644 index 0000000000..7dc2185995 --- /dev/null +++ b/src/Core/MailTemplates/Mjml/components/mj-bw-learn-more-footer.js @@ -0,0 +1,51 @@ +const { BodyComponent } = require("mjml-core"); +class MjBwLearnMoreFooter extends BodyComponent { + static dependencies = { + // Tell the validator which tags are allowed as our component's parent + "mj-column": ["mj-bw-learn-more-footer"], + "mj-wrapper": ["mj-bw-learn-more-footer"], + // Tell the validator which tags are allowed as our component's children + "mj-bw-learn-more-footer": [], + }; + + static allowedAttributes = {}; + + static defaultAttributes = {}; + + componentHeadStyle = (breakpoint) => { + return ` + @media only screen and (max-width:${breakpoint}) { + .mj-bw-learn-more-footer-responsive-img { + display: none !important; + } + } + `; + }; + + render() { + return this.renderMJML( + ` + + + +

+ Learn more about Bitwarden +

+ Find user guides, product documentation, and videos on the + Bitwarden Help Center. +
+
+ + + +
+ `, + ); + } +} + +module.exports = MjBwLearnMoreFooter; diff --git a/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-family-user.mjml b/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-family-user.mjml new file mode 100644 index 0000000000..86de49016d --- /dev/null +++ b/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-family-user.mjml @@ -0,0 +1,60 @@ + + + + + + + + + + + + + + + + + An administrator from {{OrganizationName}} will approve you + before you can share passwords. While you wait for approval, get + started with Bitwarden Password Manager: + + + + + + + + + + + + + + + + + + + diff --git a/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-free-user.mjml b/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-free-user.mjml new file mode 100644 index 0000000000..e071cd26cc --- /dev/null +++ b/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-free-user.mjml @@ -0,0 +1,59 @@ + + + + + + + + + + + + + + + + + Follow these simple steps to get up and running with Bitwarden + Password Manager: + + + + + + + + + + + + + + + + + + + diff --git a/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-org-user.mjml b/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-org-user.mjml new file mode 100644 index 0000000000..39f18fce66 --- /dev/null +++ b/src/Core/MailTemplates/Mjml/emails/Auth/Onboarding/welcome-org-user.mjml @@ -0,0 +1,60 @@ + + + + + + + + + + + + + + + + + An administrator from {{OrganizationName}} will need to confirm + you before you can share passwords. Get started with Bitwarden + Password Manager: + + + + + + + + + + + + + + + + + + + diff --git a/src/Core/MailTemplates/Mjml/emails/Auth/send-email-otp.mjml b/src/Core/MailTemplates/Mjml/emails/Auth/send-email-otp.mjml new file mode 100644 index 0000000000..d3d4eb9891 --- /dev/null +++ b/src/Core/MailTemplates/Mjml/emails/Auth/send-email-otp.mjml @@ -0,0 +1,64 @@ + + + + + + + + + + + + + + + + + Your verification code is: + {{Token}} + + + This code expires in {{Expiry}} minutes. After that, you'll need to + verify your email again. + + + + + + +

+ Bitwarden Send transmits sensitive, temporary information to + others easily and securely. Learn more about + Bitwarden Send + or + sign up + to try it today. +

+
+
+
+
+ + + + + + + + +
+
diff --git a/src/Core/MailTemplates/Mjml/emails/two-factor.mjml b/src/Core/MailTemplates/Mjml/emails/Auth/two-factor.mjml similarity index 61% rename from src/Core/MailTemplates/Mjml/emails/two-factor.mjml rename to src/Core/MailTemplates/Mjml/emails/Auth/two-factor.mjml index b959ec1c8a..73d205ba57 100644 --- a/src/Core/MailTemplates/Mjml/emails/two-factor.mjml +++ b/src/Core/MailTemplates/Mjml/emails/Auth/two-factor.mjml @@ -1,10 +1,10 @@ - + - + -

Your two-step verification code is: {{Token}}

+

+ Your two-step verification code is: {{ Token }} +

Use this code to complete logging in with Bitwarden.

- + + +
diff --git a/src/Core/MailTemplates/Mjml/emails/invite.mjml b/src/Core/MailTemplates/Mjml/emails/invite.mjml index 4eae12d0dc..8e08a6753a 100644 --- a/src/Core/MailTemplates/Mjml/emails/invite.mjml +++ b/src/Core/MailTemplates/Mjml/emails/invite.mjml @@ -22,26 +22,7 @@ - - - -

- We’re here for you! -

- If you have any questions, search the Bitwarden - Help - site or - contact us. -
-
- - - -
+ diff --git a/src/Core/MailTemplates/Mjml/emails/invoice-upcoming.mjml b/src/Core/MailTemplates/Mjml/emails/invoice-upcoming.mjml new file mode 100644 index 0000000000..c50a5d1292 --- /dev/null +++ b/src/Core/MailTemplates/Mjml/emails/invoice-upcoming.mjml @@ -0,0 +1,27 @@ + + + + + + + + + + + + + Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nunc semper sapien non sem tincidunt pretium ut vitae tortor. Mauris mattis id arcu in dictum. Vivamus tempor maximus elit id convallis. Pellentesque ligula nisl, bibendum eu maximus sit amet, rutrum efficitur tortor. Cras non dignissim leo, eget gravida odio. Nullam tincidunt porta fermentum. Fusce sit amet sagittis nunc. + + + + + + + + + diff --git a/src/Core/MailTemplates/Mjml/package-lock.json b/src/Core/MailTemplates/Mjml/package-lock.json index a78405676f..df85185af9 100644 --- a/src/Core/MailTemplates/Mjml/package-lock.json +++ b/src/Core/MailTemplates/Mjml/package-lock.json @@ -14,7 +14,7 @@ }, "devDependencies": { "nodemon": "3.1.10", - "prettier": "3.5.3" + "prettier": "3.6.2" } }, "node_modules/@babel/runtime": { @@ -1564,9 +1564,9 @@ } }, "node_modules/prettier": { - "version": "3.5.3", - "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.5.3.tgz", - "integrity": "sha512-QQtaxnoDJeAkDvDKWCLiwIXkTgRhwYDEQCghU9Z6q03iyek/rxRh/2lC3HB7P8sWT2xC/y5JDctPLBIGzHKbhw==", + "version": "3.6.2", + "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.6.2.tgz", + "integrity": "sha512-I7AIg5boAr5R0FFtJ6rCfD+LFsWHp81dolrFD8S79U9tb8Az2nGrJncnMSnys+bpQJfRUzqs9hnA81OAA3hCuQ==", "dev": true, "license": "MIT", "bin": { diff --git a/src/Core/MailTemplates/Mjml/package.json b/src/Core/MailTemplates/Mjml/package.json index c3690a2d73..f74279da7b 100644 --- a/src/Core/MailTemplates/Mjml/package.json +++ b/src/Core/MailTemplates/Mjml/package.json @@ -15,8 +15,10 @@ }, "homepage": "https://bitwarden.com", "scripts": { - "build": "./build.sh", - "watch": "nodemon --exec ./build.sh --watch ./components --watch ./emails --ext js,mjml", + "build": "node ./build.js", + "build:hbs": "node ./build.js --hbs", + "build:minify": "node ./build.js --hbs --minify", + "build:watch": "nodemon ./build.js --watch emails --watch components --ext mjml,js", "prettier": "prettier --cache --write ." }, "dependencies": { @@ -25,6 +27,6 @@ }, "devDependencies": { "nodemon": "3.1.10", - "prettier": "3.5.3" + "prettier": "3.6.2" } } diff --git a/src/Core/MailTemplates/README.md b/src/Core/MailTemplates/README.md new file mode 100644 index 0000000000..bd42b2a10f --- /dev/null +++ b/src/Core/MailTemplates/README.md @@ -0,0 +1,78 @@ +Email templating +================ + +We use MJML to generate the HTML that our mail services use to send emails to users. To accomplish this, we use different file types depending on which part of the email generation process we're working with. + +# File Types + +## `*.html.hbs` +These are the compiled HTML email templates that serve as the foundation for all HTML emails sent by the Bitwarden platform. They are generated from MJML source files and enhanced with Handlebars templating capabilities. + +### Generation Process +- **Source**: Built from `*.mjml` files in the `./mjml` directory. + - The MJML source acts as a toolkit for developers to generate HTML. It is the developers responsibility to generate the HTML and then ensure it is accessible to `IMailService` implementations. +- **Build Tool**: Generated via node build scripts: `npm run build`. + - The build script definitions can be viewed in the `Mjml/package.json` as well as in `Mjml/build.js`. +- **Output**: Cross-client compatible HTML with embedded CSS for maximum email client support +- **Template Engine**: Enhanced with Handlebars syntax for dynamic content injection + +### Handlebars Integration +The templates use Handlebars templating syntax for dynamic content replacement: + +```html + +

Welcome {{userName}}!

+

Your organization {{organizationName}} has invited you to join.

+Accept Invitation +``` + +**Variable Types:** +- **Simple Variables**: `{{userName}}`, `{{email}}`, `{{organizationName}}` + +### Email Service Integration +The `IMailService` consumes these templates through the following process: + +1. **Template Selection**: Service selects appropriate `.html.hbs` template based on email type +2. **Model Binding**: View model properties are mapped to Handlebars variables +3. **Compilation**: Handlebars engine processes variables and generates final HTML + +### Development Guidelines + +**Variable Naming:** +- Use camelCase for consistency: `{{userName}}`, `{{organizationName}}` +- Prefix URLs with descriptive names: `{{actionUrl}}`, `{{logoUrl}}` + +**Testing Considerations:** +- Verify Handlebars variable replacement with actual view model data +- Ensure graceful degradation when variables are missing or null, if necessary +- Validate HTML structure and accessibility compliance + +## `*.txt.hbs` +These files provide plain text versions of emails and are essential for email accessibility and deliverability. They serve several important purposes: + +### Purpose and Usage +- **Accessibility**: Screen readers and assistive technologies often work better with plain text versions +- **Email Client Compatibility**: Some email clients prefer or only display plain text versions +- **Fallback Content**: When HTML rendering fails, the plain text version ensures the message is still readable + +### Structure +Plain text email templates use the same Handlebars syntax (`{{variable}}`) as HTML templates for dynamic content replacement. They should: + +- Contain the core message content without HTML formatting +- Use line breaks and spacing for readability +- Include all important links as full URLs +- Maintain logical content hierarchy using spacing and simple text formatting + +### Email Service Integration +The `IMailService` automatically uses both versions when sending emails: +- The HTML version (from `*.html.hbs`) provides rich formatting and styling +- The plain text version (from `*.txt.hbs`) serves as the text alternative +- Email clients can choose which version to display based on user preferences and capabilities + +### Development Guidelines +- Always create a corresponding `*.txt.hbs` file for each `*.html.hbs` template +- Keep the content concise but complete - include all essential information from the HTML version +- Test plain text templates to ensure they're readable and convey the same message + +## `*.mjml` +This is a templating language we use to increase efficiency when creating email content. See the readme within the `./mjml` directory for more comprehensive information. diff --git a/src/Core/Models/Business/SubscriptionInfo.cs b/src/Core/Models/Business/SubscriptionInfo.cs index a016ac54f3..be514cb39f 100644 --- a/src/Core/Models/Business/SubscriptionInfo.cs +++ b/src/Core/Models/Business/SubscriptionInfo.cs @@ -1,52 +1,118 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - +using Bit.Core.Billing.Extensions; using Stripe; +#nullable enable + namespace Bit.Core.Models.Business; public class SubscriptionInfo { - public BillingCustomerDiscount CustomerDiscount { get; set; } - public BillingSubscription Subscription { get; set; } - public BillingUpcomingInvoice UpcomingInvoice { get; set; } + /// + /// Converts Stripe's minor currency units (cents) to major currency units (dollars). + /// IMPORTANT: Only supports USD. All Bitwarden subscriptions are USD-only. + /// + private const decimal StripeMinorUnitDivisor = 100M; + /// + /// Converts Stripe's minor currency units (cents) to major currency units (dollars). + /// Preserves null semantics to distinguish between "no amount" (null) and "zero amount" (0.00m). + /// + /// The amount in Stripe's minor currency units (e.g., cents for USD). + /// The amount in major currency units (e.g., dollars for USD), or null if the input is null. + private static decimal? ConvertFromStripeMinorUnits(long? amountInCents) + { + return amountInCents.HasValue ? amountInCents.Value / StripeMinorUnitDivisor : null; + } + + public BillingCustomerDiscount? CustomerDiscount { get; set; } + public BillingSubscription? Subscription { get; set; } + public BillingUpcomingInvoice? UpcomingInvoice { get; set; } + + /// + /// Represents customer discount information from Stripe billing. + /// public class BillingCustomerDiscount { public BillingCustomerDiscount() { } + /// + /// Creates a BillingCustomerDiscount from a Stripe Discount object. + /// + /// The Stripe discount containing coupon and expiration information. public BillingCustomerDiscount(Discount discount) { Id = discount.Coupon?.Id; + // Active = true only for perpetual/recurring discounts (no end date) + // This is intentional for Milestone 2 - only perpetual discounts are shown in UI Active = discount.End == null; PercentOff = discount.Coupon?.PercentOff; - AppliesTo = discount.Coupon?.AppliesTo?.Products ?? []; + AmountOff = ConvertFromStripeMinorUnits(discount.Coupon?.AmountOff); + // Stripe's CouponAppliesTo.Products is already IReadOnlyList, so no conversion needed + AppliesTo = discount.Coupon?.AppliesTo?.Products; } - public string Id { get; set; } + /// + /// The Stripe coupon ID (e.g., "cm3nHfO1"). + /// Note: Only specific coupon IDs are displayed in the UI based on feature flag configuration, + /// though Stripe may apply additional discounts that are not shown. + /// + public string? Id { get; set; } + + /// + /// True only for perpetual/recurring discounts (End == null). + /// False for any discount with an expiration date, even if not yet expired. + /// Product decision for Milestone 2: only show perpetual discounts in UI. + /// public bool Active { get; set; } + + /// + /// Percentage discount applied to the subscription (e.g., 20.0 for 20% off). + /// Null if this is an amount-based discount. + /// public decimal? PercentOff { get; set; } - public List AppliesTo { get; set; } + + /// + /// Fixed amount discount in USD (e.g., 14.00 for $14 off). + /// Converted from Stripe's cent-based values (1400 cents → $14.00). + /// Null if this is a percentage-based discount. + /// + public decimal? AmountOff { get; set; } + + /// + /// List of Stripe product IDs that this discount applies to (e.g., ["prod_premium", "prod_families"]). + /// + /// Null: discount applies to all products with no restrictions (AppliesTo not specified in Stripe). + /// Empty list: discount restricted to zero products (edge case - AppliesTo.Products = [] in Stripe). + /// Non-empty list: discount applies only to the specified product IDs. + /// + /// + public IReadOnlyList? AppliesTo { get; set; } } public class BillingSubscription { public BillingSubscription(Subscription sub) { - Status = sub.Status; - TrialStartDate = sub.TrialStart; - TrialEndDate = sub.TrialEnd; - PeriodStartDate = sub.CurrentPeriodStart; - PeriodEndDate = sub.CurrentPeriodEnd; - CancelledDate = sub.CanceledAt; - CancelAtEndDate = sub.CancelAtPeriodEnd; - Cancelled = sub.Status == "canceled" || sub.Status == "unpaid" || sub.Status == "incomplete_expired"; - if (sub.Items?.Data != null) + Status = sub?.Status; + TrialStartDate = sub?.TrialStart; + TrialEndDate = sub?.TrialEnd; + var currentPeriod = sub?.GetCurrentPeriod(); + if (currentPeriod != null) + { + var (start, end) = currentPeriod.Value; + PeriodStartDate = start; + PeriodEndDate = end; + } + CancelledDate = sub?.CanceledAt; + CancelAtEndDate = sub?.CancelAtPeriodEnd ?? false; + var status = sub?.Status; + Cancelled = status == "canceled" || status == "unpaid" || status == "incomplete_expired"; + if (sub?.Items?.Data != null) { Items = sub.Items.Data.Select(i => new BillingSubscriptionItem(i)); } - CollectionMethod = sub.CollectionMethod; - GracePeriod = sub.CollectionMethod == "charge_automatically" + CollectionMethod = sub?.CollectionMethod; + GracePeriod = sub?.CollectionMethod == "charge_automatically" ? 14 : 30; } @@ -58,10 +124,10 @@ public class SubscriptionInfo public TimeSpan? PeriodDuration => PeriodEndDate - PeriodStartDate; public DateTime? CancelledDate { get; set; } public bool CancelAtEndDate { get; set; } - public string Status { get; set; } + public string? Status { get; set; } public bool Cancelled { get; set; } public IEnumerable Items { get; set; } = new List(); - public string CollectionMethod { get; set; } + public string? CollectionMethod { get; set; } public DateTime? SuspensionDate { get; set; } public DateTime? UnpaidPeriodEndDate { get; set; } public int GracePeriod { get; set; } @@ -74,7 +140,7 @@ public class SubscriptionInfo { ProductId = item.Plan.ProductId; Name = item.Plan.Nickname; - Amount = item.Plan.Amount.GetValueOrDefault() / 100M; + Amount = ConvertFromStripeMinorUnits(item.Plan.Amount) ?? 0; Interval = item.Plan.Interval; if (item.Metadata != null) @@ -84,15 +150,15 @@ public class SubscriptionInfo } Quantity = (int)item.Quantity; - SponsoredSubscriptionItem = Utilities.StaticStore.SponsoredPlans.Any(p => p.StripePlanId == item.Plan.Id); + SponsoredSubscriptionItem = item.Plan != null && Utilities.StaticStore.SponsoredPlans.Any(p => p.StripePlanId == item.Plan.Id); } public bool AddonSubscriptionItem { get; set; } - public string ProductId { get; set; } - public string Name { get; set; } + public string? ProductId { get; set; } + public string? Name { get; set; } public decimal Amount { get; set; } public int Quantity { get; set; } - public string Interval { get; set; } + public string? Interval { get; set; } public bool SponsoredSubscriptionItem { get; set; } } } @@ -103,7 +169,7 @@ public class SubscriptionInfo public BillingUpcomingInvoice(Invoice inv) { - Amount = inv.AmountDue / 100M; + Amount = ConvertFromStripeMinorUnits(inv.AmountDue) ?? 0; Date = inv.Created; } diff --git a/src/Core/Models/Business/SubscriptionUpdate.cs b/src/Core/Models/Business/SubscriptionUpdate.cs index 028fcad80b..7c23c9b73c 100644 --- a/src/Core/Models/Business/SubscriptionUpdate.cs +++ b/src/Core/Models/Business/SubscriptionUpdate.cs @@ -50,6 +50,7 @@ public abstract class SubscriptionUpdate protected static bool IsNonSeatBasedPlan(StaticStore.Plan plan) => plan.Type is >= PlanType.FamiliesAnnually2019 and <= PlanType.EnterpriseAnnually2019 + or PlanType.FamiliesAnnually2025 or PlanType.FamiliesAnnually or PlanType.TeamsStarter2023 or PlanType.TeamsStarter; diff --git a/src/Core/Models/Mail/Auth/DefaultEmailOtpViewModel.cs b/src/Core/Models/Mail/Auth/DefaultEmailOtpViewModel.cs index 5faf550e60..5eabd5ba2c 100644 --- a/src/Core/Models/Mail/Auth/DefaultEmailOtpViewModel.cs +++ b/src/Core/Models/Mail/Auth/DefaultEmailOtpViewModel.cs @@ -9,4 +9,5 @@ public class DefaultEmailOtpViewModel : BaseMailModel public string? TheDate { get; set; } public string? TheTime { get; set; } public string? TimeZone { get; set; } + public string? Expiry { get; set; } } diff --git a/src/Core/Models/Mail/UpdatedInvoiceIncoming/UpdatedInvoiceUpcomingView.cs b/src/Core/Models/Mail/UpdatedInvoiceIncoming/UpdatedInvoiceUpcomingView.cs new file mode 100644 index 0000000000..aeca436dbb --- /dev/null +++ b/src/Core/Models/Mail/UpdatedInvoiceIncoming/UpdatedInvoiceUpcomingView.cs @@ -0,0 +1,10 @@ +using Bit.Core.Platform.Mail.Mailer; + +namespace Bit.Core.Models.Mail.UpdatedInvoiceIncoming; + +public class UpdatedInvoiceUpcomingView : BaseMailView; + +public class UpdatedInvoiceUpcomingMail : BaseMail +{ + public override string Subject { get => "Your Subscription Will Renew Soon"; } +} diff --git a/src/Core/Models/Mail/UpdatedInvoiceIncoming/UpdatedInvoiceUpcomingView.html.hbs b/src/Core/Models/Mail/UpdatedInvoiceIncoming/UpdatedInvoiceUpcomingView.html.hbs new file mode 100644 index 0000000000..a044171fe5 --- /dev/null +++ b/src/Core/Models/Mail/UpdatedInvoiceIncoming/UpdatedInvoiceUpcomingView.html.hbs @@ -0,0 +1,30 @@ +
Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nunc semper sapien non sem tincidunt pretium ut vitae tortor. Mauris mattis id arcu in dictum. Vivamus tempor maximus elit id convallis. Pellentesque ligula nisl, bibendum eu maximus sit amet, rutrum efficitur tortor. Cras non dignissim leo, eget gravida odio. Nullam tincidunt porta fermentum. Fusce sit amet sagittis nunc.

© 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa Barbara, CA, USA

Always confirm you are on a trusted Bitwarden domain before logging in:
bitwarden.com | Learn why we include this

\ No newline at end of file diff --git a/src/Core/Models/Mail/UpdatedInvoiceIncoming/UpdatedInvoiceUpcomingView.text.hbs b/src/Core/Models/Mail/UpdatedInvoiceIncoming/UpdatedInvoiceUpcomingView.text.hbs new file mode 100644 index 0000000000..a2db92bac2 --- /dev/null +++ b/src/Core/Models/Mail/UpdatedInvoiceIncoming/UpdatedInvoiceUpcomingView.text.hbs @@ -0,0 +1,3 @@ +{{#>BasicTextLayout}} + Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nunc semper sapien non sem tincidunt pretium ut vitae tortor. Mauris mattis id arcu in dictum. Vivamus tempor maximus elit id convallis. Pellentesque ligula nisl, bibendum eu maximus sit amet, rutrum efficitur tortor. Cras non dignissim leo, eget gravida odio. Nullam tincidunt porta fermentum. Fusce sit amet sagittis nunc. +{{/BasicTextLayout}} diff --git a/src/Core/Models/PushNotification.cs b/src/Core/Models/PushNotification.cs index c4ae1e2858..a622b98e05 100644 --- a/src/Core/Models/PushNotification.cs +++ b/src/Core/Models/PushNotification.cs @@ -97,3 +97,9 @@ public class ProviderBankAccountVerifiedPushNotification public Guid ProviderId { get; set; } public Guid AdminId { get; set; } } + +public class LogOutPushNotification +{ + public Guid UserId { get; set; } + public PushNotificationLogOutReason? Reason { get; set; } +} diff --git a/src/Core/Models/Stripe/StripeSubscriptionListOptions.cs b/src/Core/Models/Stripe/StripeSubscriptionListOptions.cs deleted file mode 100644 index 34662ecdbb..0000000000 --- a/src/Core/Models/Stripe/StripeSubscriptionListOptions.cs +++ /dev/null @@ -1,51 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -namespace Bit.Core.Models.BitStripe; - -// Stripe's SubscriptionListOptions model has a complex input for date filters. -// It expects a dictionary, and has lots of validation rules around what can have a value and what can't. -// To simplify this a bit we are extending Stripe's model and using our own date inputs, and building the dictionary they expect JiT. -// ___ -// Our model also facilitates selecting all elements in a list, which is unsupported by Stripe's model. -public class StripeSubscriptionListOptions : Stripe.SubscriptionListOptions -{ - public DateTime? CurrentPeriodEndDate { get; set; } - public string CurrentPeriodEndRange { get; set; } = "lt"; - public bool SelectAll { get; set; } - public new Stripe.DateRangeOptions CurrentPeriodEnd - { - get - { - return CurrentPeriodEndDate.HasValue ? - new Stripe.DateRangeOptions() - { - LessThan = CurrentPeriodEndRange == "lt" ? CurrentPeriodEndDate : null, - GreaterThan = CurrentPeriodEndRange == "gt" ? CurrentPeriodEndDate : null - } : - null; - } - } - - public Stripe.SubscriptionListOptions ToStripeApiOptions() - { - var stripeApiOptions = (Stripe.SubscriptionListOptions)this; - - if (SelectAll) - { - stripeApiOptions.EndingBefore = null; - stripeApiOptions.StartingAfter = null; - } - - if (CurrentPeriodEndDate.HasValue) - { - stripeApiOptions.CurrentPeriodEnd = new Stripe.DateRangeOptions() - { - LessThan = CurrentPeriodEndRange == "lt" ? CurrentPeriodEndDate : null, - GreaterThan = CurrentPeriodEndRange == "gt" ? CurrentPeriodEndDate : null - }; - } - - return stripeApiOptions; - } -} diff --git a/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs b/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs index da05bc929c..8cfd0a8df1 100644 --- a/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs +++ b/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs @@ -1,5 +1,6 @@ using Bit.Core.AdminConsole.OrganizationAuth; using Bit.Core.AdminConsole.OrganizationAuth.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery; using Bit.Core.AdminConsole.OrganizationFeatures.Groups; using Bit.Core.AdminConsole.OrganizationFeatures.Groups.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.Import; @@ -133,6 +134,7 @@ public static class OrganizationServiceCollectionExtensions services.AddScoped(); services.AddScoped(); services.AddScoped(); + services.AddScoped(); services.AddScoped(); services.AddScoped(); diff --git a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedSyncSponsorshipsCommand.cs b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedSyncSponsorshipsCommand.cs index 9a995a9cf0..965e0cf2a9 100644 --- a/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedSyncSponsorshipsCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSponsorships/FamiliesForEnterprise/SelfHosted/SelfHostedSyncSponsorshipsCommand.cs @@ -62,7 +62,7 @@ public class SelfHostedSyncSponsorshipsCommand : BaseIdentityClientService, ISel .ToDictionary(i => i.SponsoringOrganizationUserId); if (!organizationSponsorshipsDict.Any()) { - _logger.LogInformation($"No existing sponsorships to sync for organization {organizationId}"); + _logger.LogInformation("No existing sponsorships to sync for organization {organizationId}", organizationId); return; } var syncedSponsorships = new List(); diff --git a/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpdateSecretsManagerSubscriptionCommand.cs b/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpdateSecretsManagerSubscriptionCommand.cs index f7d6f0e5a2..d4e1b3cd8d 100644 --- a/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpdateSecretsManagerSubscriptionCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpdateSecretsManagerSubscriptionCommand.cs @@ -226,7 +226,11 @@ public class UpdateSecretsManagerSubscriptionCommand : IUpdateSecretsManagerSubs // Check minimum seats currently in use by the organization if (organization.SmSeats.Value > update.SmSeats.Value) { + // Retrieve the number of currently occupied Secrets Manager seats for the organization. var occupiedSeats = await _organizationUserRepository.GetOccupiedSmSeatCountByOrganizationIdAsync(organization.Id); + + // Check if the occupied number of seats exceeds the updated seat count. + // If so, throw an exception indicating that the subscription cannot be decreased below the current usage. if (occupiedSeats > update.SmSeats.Value) { throw new BadRequestException($"{occupiedSeats} users are currently occupying Secrets Manager seats. " + @@ -315,7 +319,7 @@ public class UpdateSecretsManagerSubscriptionCommand : IUpdateSecretsManagerSubs throw new BadRequestException($"Cannot set max Secrets Manager seat autoscaling below current Secrets Manager seat count."); } - if (plan.SecretsManager.MaxSeats.HasValue && update.MaxAutoscaleSmSeats.Value > plan.SecretsManager.MaxSeats) + if (plan.SecretsManager.MaxSeats.HasValue && plan.SecretsManager.MaxSeats.Value > 0 && update.MaxAutoscaleSmSeats.Value > plan.SecretsManager.MaxSeats) { throw new BadRequestException(string.Concat( $"Your plan has a Secrets Manager seat limit of {plan.SecretsManager.MaxSeats}, ", @@ -412,7 +416,7 @@ public class UpdateSecretsManagerSubscriptionCommand : IUpdateSecretsManagerSubs } /// - /// Requests the number of Secret Manager seats and service accounts are currently used by the organization + /// Requests the number of Secret Manager seats and service accounts currently used by the organization /// /// The id of the organization /// A tuple containing the occupied seats and the occupied service account counts diff --git a/src/Core/Services/Implementations/AmazonSesMailDeliveryService.cs b/src/Core/Platform/Mail/Delivery/AmazonSesMailDeliveryService.cs similarity index 99% rename from src/Core/Services/Implementations/AmazonSesMailDeliveryService.cs rename to src/Core/Platform/Mail/Delivery/AmazonSesMailDeliveryService.cs index 344c2e712d..ade289be8f 100644 --- a/src/Core/Services/Implementations/AmazonSesMailDeliveryService.cs +++ b/src/Core/Platform/Mail/Delivery/AmazonSesMailDeliveryService.cs @@ -9,7 +9,7 @@ using Bit.Core.Utilities; using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; +namespace Bit.Core.Platform.Mail.Delivery; public class AmazonSesMailDeliveryService : IMailDeliveryService, IDisposable { diff --git a/src/Core/Services/IMailDeliveryService.cs b/src/Core/Platform/Mail/Delivery/IMailDeliveryService.cs similarity index 73% rename from src/Core/Services/IMailDeliveryService.cs rename to src/Core/Platform/Mail/Delivery/IMailDeliveryService.cs index 9247367221..1f2a024c34 100644 --- a/src/Core/Services/IMailDeliveryService.cs +++ b/src/Core/Platform/Mail/Delivery/IMailDeliveryService.cs @@ -1,6 +1,6 @@ using Bit.Core.Models.Mail; -namespace Bit.Core.Services; +namespace Bit.Core.Platform.Mail.Delivery; public interface IMailDeliveryService { diff --git a/src/Core/Services/Implementations/MailKitSmtpMailDeliveryService.cs b/src/Core/Platform/Mail/Delivery/MailKitSmtpMailDeliveryService.cs similarity index 99% rename from src/Core/Services/Implementations/MailKitSmtpMailDeliveryService.cs rename to src/Core/Platform/Mail/Delivery/MailKitSmtpMailDeliveryService.cs index 04eda42d22..c78b107084 100644 --- a/src/Core/Services/Implementations/MailKitSmtpMailDeliveryService.cs +++ b/src/Core/Platform/Mail/Delivery/MailKitSmtpMailDeliveryService.cs @@ -7,7 +7,7 @@ using MailKit.Net.Smtp; using Microsoft.Extensions.Logging; using MimeKit; -namespace Bit.Core.Services; +namespace Bit.Core.Platform.Mail.Delivery; public class MailKitSmtpMailDeliveryService : IMailDeliveryService { diff --git a/src/Core/Services/Implementations/MultiServiceMailDeliveryService.cs b/src/Core/Platform/Mail/Delivery/MultiServiceMailDeliveryService.cs similarity index 96% rename from src/Core/Services/Implementations/MultiServiceMailDeliveryService.cs rename to src/Core/Platform/Mail/Delivery/MultiServiceMailDeliveryService.cs index e088410967..1e34e1f842 100644 --- a/src/Core/Services/Implementations/MultiServiceMailDeliveryService.cs +++ b/src/Core/Platform/Mail/Delivery/MultiServiceMailDeliveryService.cs @@ -3,7 +3,7 @@ using Bit.Core.Settings; using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; +namespace Bit.Core.Platform.Mail.Delivery; public class MultiServiceMailDeliveryService : IMailDeliveryService { diff --git a/src/Core/Services/NoopImplementations/NoopMailDeliveryService.cs b/src/Core/Platform/Mail/Delivery/NoopMailDeliveryService.cs similarity index 82% rename from src/Core/Services/NoopImplementations/NoopMailDeliveryService.cs rename to src/Core/Platform/Mail/Delivery/NoopMailDeliveryService.cs index 96b97b14f5..d8194ffb18 100644 --- a/src/Core/Services/NoopImplementations/NoopMailDeliveryService.cs +++ b/src/Core/Platform/Mail/Delivery/NoopMailDeliveryService.cs @@ -1,6 +1,6 @@ using Bit.Core.Models.Mail; -namespace Bit.Core.Services; +namespace Bit.Core.Platform.Mail.Delivery; public class NoopMailDeliveryService : IMailDeliveryService { diff --git a/src/Core/Services/Implementations/SendGridMailDeliveryService.cs b/src/Core/Platform/Mail/Delivery/SendGridMailDeliveryService.cs similarity index 98% rename from src/Core/Services/Implementations/SendGridMailDeliveryService.cs rename to src/Core/Platform/Mail/Delivery/SendGridMailDeliveryService.cs index 773f87931d..10afcc539a 100644 --- a/src/Core/Services/Implementations/SendGridMailDeliveryService.cs +++ b/src/Core/Platform/Mail/Delivery/SendGridMailDeliveryService.cs @@ -6,7 +6,7 @@ using Microsoft.Extensions.Logging; using SendGrid; using SendGrid.Helpers.Mail; -namespace Bit.Core.Services; +namespace Bit.Core.Platform.Mail.Delivery; public class SendGridMailDeliveryService : IMailDeliveryService, IDisposable { diff --git a/src/Core/Services/Implementations/AzureQueueMailService.cs b/src/Core/Platform/Mail/Enqueuing/AzureQueueMailService.cs similarity index 91% rename from src/Core/Services/Implementations/AzureQueueMailService.cs rename to src/Core/Platform/Mail/Enqueuing/AzureQueueMailService.cs index 92d6fd17bb..c88090a954 100644 --- a/src/Core/Services/Implementations/AzureQueueMailService.cs +++ b/src/Core/Platform/Mail/Enqueuing/AzureQueueMailService.cs @@ -1,10 +1,10 @@ using Azure.Storage.Queues; using Bit.Core.Models.Mail; +using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Utilities; -namespace Bit.Core.Services; - +namespace Bit.Core.Platform.Mail.Enqueuing; public class AzureQueueMailService : AzureQueueService, IMailEnqueuingService { public AzureQueueMailService(GlobalSettings globalSettings) : base( diff --git a/src/Core/Services/Implementations/BlockingMailQueueService.cs b/src/Core/Platform/Mail/Enqueuing/BlockingMailQueueService.cs similarity index 91% rename from src/Core/Services/Implementations/BlockingMailQueueService.cs rename to src/Core/Platform/Mail/Enqueuing/BlockingMailQueueService.cs index 0323b09af7..e75874af16 100644 --- a/src/Core/Services/Implementations/BlockingMailQueueService.cs +++ b/src/Core/Platform/Mail/Enqueuing/BlockingMailQueueService.cs @@ -1,7 +1,6 @@ using Bit.Core.Models.Mail; -namespace Bit.Core.Services; - +namespace Bit.Core.Platform.Mail.Enqueuing; public class BlockingMailEnqueuingService : IMailEnqueuingService { public async Task EnqueueAsync(IMailQueueMessage message, Func fallback) diff --git a/src/Core/Services/IMailEnqueuingService.cs b/src/Core/Platform/Mail/Enqueuing/IMailEnqueuingService.cs similarity index 86% rename from src/Core/Services/IMailEnqueuingService.cs rename to src/Core/Platform/Mail/Enqueuing/IMailEnqueuingService.cs index 19dc33f19e..d74f9160e4 100644 --- a/src/Core/Services/IMailEnqueuingService.cs +++ b/src/Core/Platform/Mail/Enqueuing/IMailEnqueuingService.cs @@ -1,6 +1,6 @@ using Bit.Core.Models.Mail; -namespace Bit.Core.Services; +namespace Bit.Core.Platform.Mail.Enqueuing; public interface IMailEnqueuingService { diff --git a/src/Core/Services/Implementations/HandlebarsMailService.cs b/src/Core/Platform/Mail/HandlebarsMailService.cs similarity index 95% rename from src/Core/Services/Implementations/HandlebarsMailService.cs rename to src/Core/Platform/Mail/HandlebarsMailService.cs index 9728c2e727..072fe79e71 100644 --- a/src/Core/Services/Implementations/HandlebarsMailService.cs +++ b/src/Core/Platform/Mail/HandlebarsMailService.cs @@ -19,6 +19,8 @@ using Bit.Core.Models.Mail.Auth; using Bit.Core.Models.Mail.Billing; using Bit.Core.Models.Mail.FamiliesForEnterprise; using Bit.Core.Models.Mail.Provider; +using Bit.Core.Platform.Mail.Delivery; +using Bit.Core.Platform.Mail.Enqueuing; using Bit.Core.SecretsManager.Models.Mail; using Bit.Core.Settings; using Bit.Core.Utilities; @@ -26,9 +28,11 @@ using Bit.Core.Vault.Models.Data; using Core.Auth.Enums; using HandlebarsDotNet; using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Logging; -namespace Bit.Core.Services; +namespace Bit.Core.Services.Mail; +[Obsolete("The IMailService has been deprecated in favor of the IMailer. All new emails should be sent with an IMailer implementation.")] public class HandlebarsMailService : IMailService { private const string Namespace = "Bit.Core.MailTemplates.Handlebars"; @@ -39,6 +43,7 @@ public class HandlebarsMailService : IMailService private readonly IMailDeliveryService _mailDeliveryService; private readonly IMailEnqueuingService _mailEnqueuingService; private readonly IDistributedCache _distributedCache; + private readonly ILogger _logger; private readonly Dictionary> _templateCache = new(); private bool _registeredHelpersAndPartials = false; @@ -47,12 +52,14 @@ public class HandlebarsMailService : IMailService GlobalSettings globalSettings, IMailDeliveryService mailDeliveryService, IMailEnqueuingService mailEnqueuingService, - IDistributedCache distributedCache) + IDistributedCache distributedCache, + ILogger logger) { _globalSettings = globalSettings; _mailDeliveryService = mailDeliveryService; _mailEnqueuingService = mailEnqueuingService; _distributedCache = distributedCache; + _logger = logger; } public async Task SendVerifyEmailEmailAsync(string email, Guid userId, string token) @@ -220,6 +227,27 @@ public class HandlebarsMailService : IMailService await _mailDeliveryService.SendEmailAsync(message); } + public async Task SendSendEmailOtpEmailv2Async(string email, string token, string subject) + { + var message = CreateDefaultMessage(subject, email); + var requestDateTime = DateTime.UtcNow; + var model = new DefaultEmailOtpViewModel + { + Token = token, + Expiry = "5", // This should be configured through the OTPDefaultTokenProviderOptions but for now we will hardcode it to 5 minutes. + TheDate = requestDateTime.ToLongDateString(), + TheTime = requestDateTime.ToShortTimeString(), + TimeZone = _utcTimeZoneDisplay, + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName, + }; + await AddMessageContentAsync(message, "Auth.SendAccessEmailOtpEmailv2", model); + message.MetaData.Add("SendGridBypassListManagement", true); + // TODO - PM-25380 change to string constant + message.Category = "SendEmailOtp"; + await _mailDeliveryService.SendEmailAsync(message); + } + public async Task SendFailedTwoFactorAttemptEmailAsync(string email, TwoFactorProviderType failedType, DateTime utcNow, string ip) { // Check if we've sent this email within the last hour @@ -649,7 +677,7 @@ public class HandlebarsMailService : IMailService await _mailDeliveryService.SendEmailAsync(message); } - public async Task SendAdminResetPasswordEmailAsync(string email, string userName, string orgName) + public async Task SendAdminResetPasswordEmailAsync(string email, string? userName, string orgName) { var message = CreateDefaultMessage("Your admin has initiated account recovery", email); var model = new AdminResetPasswordViewModel() @@ -708,6 +736,12 @@ public class HandlebarsMailService : IMailService private async Task ReadSourceAsync(string templateName) { + var diskSource = await ReadSourceFromDiskAsync(templateName); + if (!string.IsNullOrWhiteSpace(diskSource)) + { + return diskSource; + } + var assembly = typeof(HandlebarsMailService).GetTypeInfo().Assembly; var fullTemplateName = $"{Namespace}.{templateName}.hbs"; if (!assembly.GetManifestResourceNames().Any(f => f == fullTemplateName)) @@ -721,6 +755,42 @@ public class HandlebarsMailService : IMailService } } + private async Task ReadSourceFromDiskAsync(string templateName) + { + if (!_globalSettings.SelfHosted) + { + return null; + } + try + { + var templateFileSuffix = ".html"; + if (templateName.EndsWith(".txt")) + { + templateFileSuffix = ".txt"; + } + else if (!templateName.EndsWith(".html")) + { + // unexpected suffix + return null; + } + var suffixPosition = templateName.LastIndexOf(templateFileSuffix); + var templateNameNoSuffix = templateName.Substring(0, suffixPosition); + var templatePathNoSuffix = templateNameNoSuffix.Replace(".", "/"); + var diskPath = $"{_globalSettings.MailTemplateDirectory}/{templatePathNoSuffix}{templateFileSuffix}.hbs"; + var directory = Path.GetDirectoryName(diskPath); + if (Directory.Exists(directory) && File.Exists(diskPath)) + { + var fileContents = await File.ReadAllTextAsync(diskPath); + return fileContents; + } + } + catch (Exception e) + { + _logger.LogError(e, "Failed to read mail template from disk."); + } + return null; + } + private async Task RegisterHelpersAndPartialsAsync() { if (_registeredHelpersAndPartials) diff --git a/src/Core/Services/IMailService.cs b/src/Core/Platform/Mail/IMailService.cs similarity index 88% rename from src/Core/Services/IMailService.cs rename to src/Core/Platform/Mail/IMailService.cs index 6e61c4f8dd..52fbdb9b6d 100644 --- a/src/Core/Services/IMailService.cs +++ b/src/Core/Platform/Mail/IMailService.cs @@ -1,9 +1,8 @@ -#nullable enable - -using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Auth.Entities; using Bit.Core.Auth.Enums; +using Bit.Core.Auth.Identity.TokenProviders; using Bit.Core.Billing.Enums; using Bit.Core.Entities; using Bit.Core.Models.Data.Organizations; @@ -13,6 +12,7 @@ using Core.Auth.Enums; namespace Bit.Core.Services; +[Obsolete("The IMailService has been deprecated in favor of the IMailer. All new emails should be sent with an IMailer implementation.")] public interface IMailService { Task SendWelcomeEmailAsync(User user); @@ -31,6 +31,16 @@ public interface IMailService Task SendChangeEmailEmailAsync(string newEmailAddress, string token); Task SendTwoFactorEmailAsync(string email, string accountEmail, string token, string deviceIp, string deviceType, TwoFactorEmailPurpose purpose); Task SendSendEmailOtpEmailAsync(string email, string token, string subject); + /// + /// has a default expiry of 5 minutes so we set the expiry to that value int he view model. + /// Sends OTP code token to the specified email address. + /// will replace when MJML templates are fully accepted. + /// + /// Email address to send the OTP to + /// Otp code token + /// subject line of the email + /// Task + Task SendSendEmailOtpEmailv2Async(string email, string token, string subject); Task SendFailedTwoFactorAttemptEmailAsync(string email, TwoFactorProviderType type, DateTime utcNow, string ip); Task SendNoMasterPasswordHintEmailAsync(string email); Task SendMasterPasswordHintEmailAsync(string email, string hint); @@ -81,7 +91,7 @@ public interface IMailService Task SendEmergencyAccessRecoveryReminder(EmergencyAccess emergencyAccess, string initiatingName, string email); Task SendEmergencyAccessRecoveryTimedOut(EmergencyAccess ea, string initiatingName, string email); Task SendEnqueuedMailMessageAsync(IMailQueueMessage queueMessage); - Task SendAdminResetPasswordEmailAsync(string email, string userName, string orgName); + Task SendAdminResetPasswordEmailAsync(string email, string? userName, string orgName); Task SendProviderSetupInviteEmailAsync(Provider provider, string token, string email); Task SendBusinessUnitConversionInviteAsync(Organization organization, string token, string email); Task SendProviderInviteEmailAsync(string providerName, ProviderUser providerUser, string token, string email); diff --git a/src/Core/Platform/Mail/Mailer/BaseMail.cs b/src/Core/Platform/Mail/Mailer/BaseMail.cs new file mode 100644 index 0000000000..0fd6b79aba --- /dev/null +++ b/src/Core/Platform/Mail/Mailer/BaseMail.cs @@ -0,0 +1,54 @@ +namespace Bit.Core.Platform.Mail.Mailer; + +#nullable enable + +/// +/// BaseMail describes a model for emails. It contains metadata about the email such as recipients, +/// subject, and an optional category for processing at the upstream email delivery service. +/// +/// Each BaseMail must have a view model that inherits from BaseMailView. The view model is used to +/// generate the text part and HTML body. +/// +public abstract class BaseMail where TView : BaseMailView +{ + /// + /// Email recipients. + /// + public required IEnumerable ToEmails { get; set; } + + /// + /// The subject of the email. + /// + public abstract string Subject { get; } + + /// + /// An optional category for processing at the upstream email delivery service. + /// + public string? Category { get; } + + /// + /// Allows you to override and ignore the suppression list for this email. + /// + /// Warning: This should be used with caution, valid reasons are primarily account recovery, email OTP. + /// + public virtual bool IgnoreSuppressList { get; } = false; + + /// + /// View model for the email body. + /// + public required TView View { get; set; } +} + +/// +/// Each MailView consists of two body parts: a text part and an HTML part and the filename must be +/// relative to the viewmodel and match the following pattern: +/// - `{ClassName}.html.hbs` for the HTML part +/// - `{ClassName}.text.hbs` for the text part +/// +public abstract class BaseMailView +{ + /// + /// Current year. + /// + public string CurrentYear => DateTime.UtcNow.Year.ToString(); +} diff --git a/src/Core/Platform/Mail/Mailer/HandlebarMailRenderer.cs b/src/Core/Platform/Mail/Mailer/HandlebarMailRenderer.cs new file mode 100644 index 0000000000..baba5b8015 --- /dev/null +++ b/src/Core/Platform/Mail/Mailer/HandlebarMailRenderer.cs @@ -0,0 +1,132 @@ +#nullable enable +using System.Collections.Concurrent; +using System.Reflection; +using Bit.Core.Settings; +using HandlebarsDotNet; +using Microsoft.Extensions.Logging; + +namespace Bit.Core.Platform.Mail.Mailer; +public class HandlebarMailRenderer : IMailRenderer +{ + /// + /// Lazy-initialized Handlebars instance. Thread-safe and ensures initialization occurs only once. + /// + private readonly Lazy> _handlebarsTask; + + /// + /// Helper function that returns the handlebar instance. + /// + private Task GetHandlebars() => _handlebarsTask.Value; + + /// + /// This dictionary is used to cache compiled templates in a thread-safe manner. + /// + private readonly ConcurrentDictionary>>> _templateCache = new(); + + private readonly ILogger _logger; + private readonly GlobalSettings _globalSettings; + + public HandlebarMailRenderer(ILogger logger, GlobalSettings globalSettings) + { + _logger = logger; + _globalSettings = globalSettings; + + _handlebarsTask = new Lazy>(InitializeHandlebarsAsync, LazyThreadSafetyMode.ExecutionAndPublication); + } + + public async Task<(string html, string txt)> RenderAsync(BaseMailView model) + { + var html = await CompileTemplateAsync(model, "html"); + var txt = await CompileTemplateAsync(model, "text"); + + return (html, txt); + } + + private async Task CompileTemplateAsync(BaseMailView model, string type) + { + var templateName = $"{model.GetType().FullName}.{type}.hbs"; + var assembly = model.GetType().Assembly; + + // GetOrAdd is atomic - only one Lazy will be stored per templateName. + // The Lazy with ExecutionAndPublication ensures the compilation happens exactly once. + var lazyTemplate = _templateCache.GetOrAdd( + templateName, + key => new Lazy>>( + () => CompileTemplateInternalAsync(assembly, key), + LazyThreadSafetyMode.ExecutionAndPublication)); + + var template = await lazyTemplate.Value; + return template(model); + } + + private async Task> CompileTemplateInternalAsync(Assembly assembly, string templateName) + { + var source = await ReadSourceAsync(assembly, templateName); + var handlebars = await GetHandlebars(); + return handlebars.Compile(source); + } + + private async Task ReadSourceAsync(Assembly assembly, string template) + { + if (assembly.GetManifestResourceNames().All(f => f != template)) + { + throw new FileNotFoundException("Template not found: " + template); + } + + var diskSource = await ReadSourceFromDiskAsync(template); + if (!string.IsNullOrWhiteSpace(diskSource)) + { + return diskSource; + } + + await using var s = assembly.GetManifestResourceStream(template)!; + using var sr = new StreamReader(s); + return await sr.ReadToEndAsync(); + } + + private async Task ReadSourceFromDiskAsync(string template) + { + if (!_globalSettings.SelfHosted) + { + return null; + } + + try + { + var diskPath = Path.GetFullPath(Path.Combine(_globalSettings.MailTemplateDirectory, template)); + var baseDirectory = Path.GetFullPath(_globalSettings.MailTemplateDirectory); + + // Ensure the resolved path is within the configured directory + if (!diskPath.StartsWith(baseDirectory + Path.DirectorySeparatorChar, StringComparison.OrdinalIgnoreCase) && + !diskPath.Equals(baseDirectory, StringComparison.OrdinalIgnoreCase)) + { + _logger.LogWarning("Template path traversal attempt detected: {Template}", template); + return null; + } + + if (File.Exists(diskPath)) + { + var fileContents = await File.ReadAllTextAsync(diskPath); + return fileContents; + } + } + catch (Exception e) + { + _logger.LogError(e, "Failed to read mail template from disk: {TemplateName}", template); + } + + return null; + } + + private async Task InitializeHandlebarsAsync() + { + var handlebars = Handlebars.Create(); + + // TODO: Do we still need layouts with MJML? + var assembly = typeof(HandlebarMailRenderer).Assembly; + var layoutSource = await ReadSourceAsync(assembly, "Bit.Core.MailTemplates.Handlebars.Layouts.Full.html.hbs"); + handlebars.RegisterTemplate("FullHtmlLayout", layoutSource); + + return handlebars; + } +} diff --git a/src/Core/Platform/Mail/Mailer/IMailRenderer.cs b/src/Core/Platform/Mail/Mailer/IMailRenderer.cs new file mode 100644 index 0000000000..7f392df479 --- /dev/null +++ b/src/Core/Platform/Mail/Mailer/IMailRenderer.cs @@ -0,0 +1,7 @@ +#nullable enable +namespace Bit.Core.Platform.Mail.Mailer; + +public interface IMailRenderer +{ + Task<(string html, string txt)> RenderAsync(BaseMailView model); +} diff --git a/src/Core/Platform/Mail/Mailer/IMailer.cs b/src/Core/Platform/Mail/Mailer/IMailer.cs new file mode 100644 index 0000000000..6dc3eec46f --- /dev/null +++ b/src/Core/Platform/Mail/Mailer/IMailer.cs @@ -0,0 +1,15 @@ +namespace Bit.Core.Platform.Mail.Mailer; + +#nullable enable + +/// +/// Generic mailer interface for sending email messages. +/// +public interface IMailer +{ + /// + /// Sends an email message. + /// + /// + public Task SendEmail(BaseMail message) where T : BaseMailView; +} diff --git a/src/Core/Platform/Mail/Mailer/Mailer.cs b/src/Core/Platform/Mail/Mailer/Mailer.cs new file mode 100644 index 0000000000..f5e8d35d58 --- /dev/null +++ b/src/Core/Platform/Mail/Mailer/Mailer.cs @@ -0,0 +1,32 @@ +using Bit.Core.Models.Mail; +using Bit.Core.Platform.Mail.Delivery; + +namespace Bit.Core.Platform.Mail.Mailer; + +#nullable enable + +public class Mailer(IMailRenderer renderer, IMailDeliveryService mailDeliveryService) : IMailer +{ + public async Task SendEmail(BaseMail message) where T : BaseMailView + { + var content = await renderer.RenderAsync(message.View); + + var metadata = new Dictionary(); + if (message.IgnoreSuppressList) + { + metadata.Add("SendGridBypassListManagement", true); + } + + var mailMessage = new MailMessage + { + ToEmails = message.ToEmails, + Subject = message.Subject, + MetaData = metadata, + HtmlContent = content.html, + TextContent = content.txt, + Category = message.Category, + }; + + await mailDeliveryService.SendEmailAsync(mailMessage); + } +} diff --git a/src/Core/Platform/Mail/Mailer/MailerServiceCollectionExtensions.cs b/src/Core/Platform/Mail/Mailer/MailerServiceCollectionExtensions.cs new file mode 100644 index 0000000000..cc56b3ec5a --- /dev/null +++ b/src/Core/Platform/Mail/Mailer/MailerServiceCollectionExtensions.cs @@ -0,0 +1,27 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; + +namespace Bit.Core.Platform.Mail.Mailer; + +#nullable enable + +/// +/// Extension methods for adding the Mailer feature to the service collection. +/// +public static class MailerServiceCollectionExtensions +{ + /// + /// Adds the Mailer services to the . + /// This includes the mail renderer and mailer for sending templated emails. + /// This method is safe to be run multiple times. + /// + /// The to add services to. + /// The for additional chaining. + public static IServiceCollection AddMailer(this IServiceCollection services) + { + services.TryAddSingleton(); + services.TryAddSingleton(); + + return services; + } +} diff --git a/src/Core/Services/NoopImplementations/NoopMailService.cs b/src/Core/Platform/Mail/NoopMailService.cs similarity index 97% rename from src/Core/Services/NoopImplementations/NoopMailService.cs rename to src/Core/Platform/Mail/NoopMailService.cs index 7ec05bb1f9..45a860a155 100644 --- a/src/Core/Services/NoopImplementations/NoopMailService.cs +++ b/src/Core/Platform/Mail/NoopMailService.cs @@ -13,6 +13,7 @@ using Core.Auth.Enums; namespace Bit.Core.Services; +[Obsolete("The IMailService has been deprecated in favor of the IMailer. All new emails should be sent with an IMailer implementation.")] public class NoopMailService : IMailService { public Task SendChangeEmailAlreadyExistsEmailAsync(string fromEmail, string toEmail) @@ -98,6 +99,11 @@ public class NoopMailService : IMailService return Task.FromResult(0); } + public Task SendSendEmailOtpEmailv2Async(string email, string token, string subject) + { + return Task.FromResult(0); + } + public Task SendFailedTwoFactorAttemptEmailAsync(string email, TwoFactorProviderType failedType, DateTime utcNow, string ip) { return Task.FromResult(0); @@ -216,7 +222,7 @@ public class NoopMailService : IMailService return Task.FromResult(0); } - public Task SendAdminResetPasswordEmailAsync(string email, string userName, string orgName) + public Task SendAdminResetPasswordEmailAsync(string email, string? userName, string orgName) { return Task.FromResult(0); } diff --git a/src/Core/Platform/Mail/README.md b/src/Core/Platform/Mail/README.md new file mode 100644 index 0000000000..b5caca62be --- /dev/null +++ b/src/Core/Platform/Mail/README.md @@ -0,0 +1,213 @@ +# Mail Services +## `MailService` + +The `MailService` and its implementation in `HandlebarsMailService` has been deprecated in favor of the `Mailer` implementation. + +New emails should be implemented using [MJML](../../MailTemplates/README.md) and the `Mailer`. + +## `Mailer` + +The Mailer feature provides a structured, type-safe approach to sending emails in the Bitwarden server application. It +uses Handlebars templates to render both HTML and plain text email content. + +### Architecture + +The Mailer system consists of four main components: + +1. **IMailer** - Service interface for sending emails +2. **BaseMail** - Abstract base class defining email metadata (recipients, subject, category) +3. **BaseMailView** - Abstract base class for email template view models +4. **IMailRenderer** - Internal interface for rendering templates (implemented by `HandlebarMailRenderer`) + +### How To Use + +1. Define a view model that inherits from `BaseMailView` with properties for template data +2. Create Handlebars templates (`.html.hbs` and `.text.hbs`) as embedded resources, preferably using the MJML pipeline, + `/src/Core/MailTemplates/Mjml`. +3. Define an email class that inherits from `BaseMail` with metadata like subject +4. Use `IMailer.SendEmail()` to render and send the email + +### Creating a New Email + +#### Step 1: Define the Email & View Model + +Create a class that inherits from `BaseMailView`: + +```csharp +using Bit.Core.Platform.Mailer; + +namespace MyApp.Emails; + +public class WelcomeEmailView : BaseMailView +{ + public required string UserName { get; init; } + public required string ActivationUrl { get; init; } +} + +public class WelcomeEmail : BaseMail +{ + public override string Subject => "Welcome to Bitwarden"; +} +``` + +#### Step 2: Create Handlebars Templates + +Create two template files as embedded resources next to your view model. **Important**: The file names must be located +directly next to the `ViewClass` and match the name of the view. + +**WelcomeEmailView.html.hbs** (HTML version): + +```handlebars +

Welcome, {{ UserName }}!

+

Thank you for joining Bitwarden.

+

+ Activate your account +

+

© {{ CurrentYear }} Bitwarden Inc.

+``` + +**WelcomeEmailView.text.hbs** (plain text version): + +```handlebars +Welcome, {{ UserName }}! + +Thank you for joining Bitwarden. + +Activate your account: {{ ActivationUrl }} + +� {{ CurrentYear }} Bitwarden Inc. +``` + +**Important**: Template files must be configured as embedded resources in your `.csproj`: + +```xml + + + + +``` + +#### Step 3: Send the Email + +Inject `IMailer` and send the email, this may be done in a service, command or some other application layer. + +```csharp +public class SomeService +{ + private readonly IMailer _mailer; + + public SomeService(IMailer mailer) + { + _mailer = mailer; + } + + public async Task SendWelcomeEmailAsync(string email, string userName, string activationUrl) + { + var mail = new WelcomeEmail + { + ToEmails = [email], + View = new WelcomeEmailView + { + UserName = userName, + ActivationUrl = activationUrl + } + }; + + await _mailer.SendEmail(mail); + } +} +``` + +### Advanced Features + +#### Multiple Recipients + +Send to multiple recipients by providing multiple email addresses: + +```csharp +var mail = new WelcomeEmail +{ + ToEmails = ["user1@example.com", "user2@example.com"], + View = new WelcomeEmailView { /* ... */ } +}; +``` + +#### Bypass Suppression List + +For critical emails like account recovery or email OTP, you can bypass the suppression list: + +```csharp +public class PasswordResetEmail : BaseMail +{ + public override string Subject => "Reset Your Password"; + public override bool IgnoreSuppressList => true; // Use with caution +} +``` + +**Warning**: Only use `IgnoreSuppressList = true` for critical account recovery or authentication emails. + +#### Email Categories + +Optionally categorize emails for processing at the upstream email delivery service: + +```csharp +public class MarketingEmail : BaseMail +{ + public override string Subject => "Latest Updates"; + public string? Category => "marketing"; +} +``` + +### Built-in View Properties + +All view models inherit from `BaseMailView`, which provides: + +- **CurrentYear** - The current UTC year (useful for copyright notices) + +```handlebars + +
© {{ CurrentYear }} Bitwarden Inc.
+``` + +### Template Naming Convention + +Templates must follow this naming convention: + +- HTML template: `{ViewModelFullName}.html.hbs` +- Text template: `{ViewModelFullName}.text.hbs` + +For example, if your view model is `Bit.Core.Auth.Models.Mail.VerifyEmailView`, the templates must be: + +- `Bit.Core.Auth.Models.Mail.VerifyEmailView.html.hbs` +- `Bit.Core.Auth.Models.Mail.VerifyEmailView.text.hbs` + +## Dependency Injection + +Register the Mailer services in your DI container using the extension method: + +```csharp +using Bit.Core.Platform.Mailer; + +services.AddMailer(); +``` + +Or manually register the services: + +```csharp +using Microsoft.Extensions.DependencyInjection.Extensions; + +services.TryAddSingleton(); +services.TryAddSingleton(); +``` + +### Performance Notes + +- **Template caching** - `HandlebarMailRenderer` automatically caches compiled templates +- **Lazy initialization** - Handlebars is initialized only when first needed +- **Thread-safe** - The renderer is thread-safe for concurrent email rendering + +# Overriding email templates from disk + +The mail services support loading the mail template from disk. This is intended to be used by self-hosted customers who want to modify their email appearance. These overrides are not intended to be used during local development, as any changes there would not be reflected in the templates used in a normal deployment configuration. + +Any customer using this override has worked with Bitwarden support on an approved implementation and has acknowledged that they are responsible for reacting to any changes made to the templates as a part of the Bitwarden development process. This includes, but is not limited to, changes in Handlebars property names, removal of properties from the `ViewModel` classes, and changes in template names. **Bitwarden is not responsible for maintaining backward compatibility between releases in order to support any overridden emails.** \ No newline at end of file diff --git a/src/Core/Platform/Push/IPushNotificationService.cs b/src/Core/Platform/Push/IPushNotificationService.cs index 32a488b827..b6d7d4d416 100644 --- a/src/Core/Platform/Push/IPushNotificationService.cs +++ b/src/Core/Platform/Push/IPushNotificationService.cs @@ -167,18 +167,17 @@ public interface IPushNotificationService ExcludeCurrentContext = false, }); - Task PushLogOutAsync(Guid userId, bool excludeCurrentContextFromPush = false) - => PushAsync(new PushNotification + Task PushLogOutAsync(Guid userId, bool excludeCurrentContextFromPush = false, + PushNotificationLogOutReason? reason = null) + => PushAsync(new PushNotification { Type = PushType.LogOut, Target = NotificationTarget.User, TargetId = userId, - Payload = new UserPushNotification + Payload = new LogOutPushNotification { UserId = userId, -#pragma warning disable BWP0001 // Type or member is obsolete - Date = TimeProvider.GetUtcNow().UtcDateTime, -#pragma warning restore BWP0001 // Type or member is obsolete + Reason = reason }, ExcludeCurrentContext = excludeCurrentContextFromPush, }); diff --git a/src/Core/Platform/Push/PushType.cs b/src/Core/Platform/Push/PushType.cs index 7765c1aa66..93eca86243 100644 --- a/src/Core/Platform/Push/PushType.cs +++ b/src/Core/Platform/Push/PushType.cs @@ -55,7 +55,7 @@ public enum PushType : byte [NotificationInfo("not-specified", typeof(Models.UserPushNotification))] SyncSettings = 10, - [NotificationInfo("not-specified", typeof(Models.UserPushNotification))] + [NotificationInfo("not-specified", typeof(Models.LogOutPushNotification))] LogOut = 11, [NotificationInfo("@bitwarden/team-tools-dev", typeof(Models.SyncSendPushNotification))] diff --git a/src/Core/Resources/SharedResources.en.resx b/src/Core/Resources/SharedResources.en.resx index 17b4489454..ca150f2106 100644 --- a/src/Core/Resources/SharedResources.en.resx +++ b/src/Core/Resources/SharedResources.en.resx @@ -389,7 +389,7 @@ If SAML Binding Type is set to artifact, identity provider resolution service URL is required. - If Identity Provider Entity ID is not a URL, single sign on service URL is required. + Single sign on service URL is required. The configured authentication scheme is not valid: "{0}" @@ -508,9 +508,15 @@ Supplied userId and token did not match. + + User should have been defined by this point. + Could not find organization for '{0}' + + Could not find organization user for user '{0}' organization '{1}' + No seats available for organization, '{0}' diff --git a/src/Core/SecretsManager/Entities/SecretVersion.cs b/src/Core/SecretsManager/Entities/SecretVersion.cs new file mode 100644 index 0000000000..cee447bd2a --- /dev/null +++ b/src/Core/SecretsManager/Entities/SecretVersion.cs @@ -0,0 +1,28 @@ +#nullable enable +using Bit.Core.Entities; +using Bit.Core.Utilities; + +namespace Bit.Core.SecretsManager.Entities; + +public class SecretVersion : ITableObject +{ + public Guid Id { get; set; } + + public Guid SecretId { get; set; } + + public string Value { get; set; } = string.Empty; + + public DateTime VersionDate { get; set; } + + public Guid? EditorServiceAccountId { get; set; } + + public Guid? EditorOrganizationUserId { get; set; } + + public void SetNewId() + { + if (Id == default(Guid)) + { + Id = CoreHelpers.GenerateComb(); + } + } +} diff --git a/src/Core/Services/IStripeAdapter.cs b/src/Core/Services/IStripeAdapter.cs index 8a41263956..6b2c3c299e 100644 --- a/src/Core/Services/IStripeAdapter.cs +++ b/src/Core/Services/IStripeAdapter.cs @@ -3,58 +3,47 @@ using Bit.Core.Models.BitStripe; using Stripe; +using Stripe.Tax; namespace Bit.Core.Services; public interface IStripeAdapter { - Task CustomerCreateAsync(Stripe.CustomerCreateOptions customerCreateOptions); - Task CustomerGetAsync(string id, Stripe.CustomerGetOptions options = null); - Task CustomerUpdateAsync(string id, Stripe.CustomerUpdateOptions options = null); - Task CustomerDeleteAsync(string id); - Task> CustomerListPaymentMethods(string id, CustomerListPaymentMethodsOptions options = null); + Task CustomerCreateAsync(CustomerCreateOptions customerCreateOptions); + Task CustomerDeleteDiscountAsync(string customerId, CustomerDeleteDiscountOptions options = null); + Task CustomerGetAsync(string id, CustomerGetOptions options = null); + Task CustomerUpdateAsync(string id, CustomerUpdateOptions options = null); + Task CustomerDeleteAsync(string id); + Task> CustomerListPaymentMethods(string id, CustomerPaymentMethodListOptions options = null); Task CustomerBalanceTransactionCreate(string customerId, CustomerBalanceTransactionCreateOptions options); - Task SubscriptionCreateAsync(Stripe.SubscriptionCreateOptions subscriptionCreateOptions); - Task SubscriptionGetAsync(string id, Stripe.SubscriptionGetOptions options = null); - - /// - /// Retrieves a subscription object for a provider. - /// - /// The subscription ID. - /// The provider ID. - /// Additional options. - /// The subscription object. - /// Thrown when the subscription doesn't belong to the provider. - Task ProviderSubscriptionGetAsync(string id, Guid providerId, Stripe.SubscriptionGetOptions options = null); - - Task> SubscriptionListAsync(StripeSubscriptionListOptions subscriptionSearchOptions); - Task SubscriptionUpdateAsync(string id, Stripe.SubscriptionUpdateOptions options = null); - Task SubscriptionCancelAsync(string Id, Stripe.SubscriptionCancelOptions options = null); - Task InvoiceUpcomingAsync(Stripe.UpcomingInvoiceOptions options); - Task InvoiceGetAsync(string id, Stripe.InvoiceGetOptions options); - Task> InvoiceListAsync(StripeInvoiceListOptions options); - Task InvoiceCreatePreviewAsync(InvoiceCreatePreviewOptions options); - Task> InvoiceSearchAsync(InvoiceSearchOptions options); - Task InvoiceUpdateAsync(string id, Stripe.InvoiceUpdateOptions options); - Task InvoiceFinalizeInvoiceAsync(string id, Stripe.InvoiceFinalizeOptions options); - Task InvoiceSendInvoiceAsync(string id, Stripe.InvoiceSendOptions options); - Task InvoicePayAsync(string id, Stripe.InvoicePayOptions options = null); - Task InvoiceDeleteAsync(string id, Stripe.InvoiceDeleteOptions options = null); - Task InvoiceVoidInvoiceAsync(string id, Stripe.InvoiceVoidOptions options = null); - IEnumerable PaymentMethodListAutoPaging(Stripe.PaymentMethodListOptions options); - IAsyncEnumerable PaymentMethodListAutoPagingAsync(Stripe.PaymentMethodListOptions options); - Task PaymentMethodAttachAsync(string id, Stripe.PaymentMethodAttachOptions options = null); - Task PaymentMethodDetachAsync(string id, Stripe.PaymentMethodDetachOptions options = null); - Task TaxIdCreateAsync(string id, Stripe.TaxIdCreateOptions options); - Task TaxIdDeleteAsync(string customerId, string taxIdId, Stripe.TaxIdDeleteOptions options = null); - Task> TaxRegistrationsListAsync(Stripe.Tax.RegistrationListOptions options = null); - Task> ChargeListAsync(Stripe.ChargeListOptions options); - Task RefundCreateAsync(Stripe.RefundCreateOptions options); - Task CardDeleteAsync(string customerId, string cardId, Stripe.CardDeleteOptions options = null); - Task BankAccountCreateAsync(string customerId, Stripe.BankAccountCreateOptions options = null); - Task BankAccountDeleteAsync(string customerId, string bankAccount, Stripe.BankAccountDeleteOptions options = null); - Task> PriceListAsync(Stripe.PriceListOptions options = null); + Task SubscriptionCreateAsync(SubscriptionCreateOptions subscriptionCreateOptions); + Task SubscriptionGetAsync(string id, SubscriptionGetOptions options = null); + Task SubscriptionUpdateAsync(string id, SubscriptionUpdateOptions options = null); + Task SubscriptionCancelAsync(string Id, SubscriptionCancelOptions options = null); + Task InvoiceGetAsync(string id, InvoiceGetOptions options); + Task> InvoiceListAsync(StripeInvoiceListOptions options); + Task InvoiceCreatePreviewAsync(InvoiceCreatePreviewOptions options); + Task> InvoiceSearchAsync(InvoiceSearchOptions options); + Task InvoiceUpdateAsync(string id, InvoiceUpdateOptions options); + Task InvoiceFinalizeInvoiceAsync(string id, InvoiceFinalizeOptions options); + Task InvoiceSendInvoiceAsync(string id, InvoiceSendOptions options); + Task InvoicePayAsync(string id, InvoicePayOptions options = null); + Task InvoiceDeleteAsync(string id, InvoiceDeleteOptions options = null); + Task InvoiceVoidInvoiceAsync(string id, InvoiceVoidOptions options = null); + IEnumerable PaymentMethodListAutoPaging(PaymentMethodListOptions options); + IAsyncEnumerable PaymentMethodListAutoPagingAsync(PaymentMethodListOptions options); + Task PaymentMethodAttachAsync(string id, PaymentMethodAttachOptions options = null); + Task PaymentMethodDetachAsync(string id, PaymentMethodDetachOptions options = null); + Task TaxIdCreateAsync(string id, TaxIdCreateOptions options); + Task TaxIdDeleteAsync(string customerId, string taxIdId, TaxIdDeleteOptions options = null); + Task> TaxRegistrationsListAsync(RegistrationListOptions options = null); + Task> ChargeListAsync(ChargeListOptions options); + Task RefundCreateAsync(RefundCreateOptions options); + Task CardDeleteAsync(string customerId, string cardId, CardDeleteOptions options = null); + Task BankAccountCreateAsync(string customerId, BankAccountCreateOptions options = null); + Task BankAccountDeleteAsync(string customerId, string bankAccount, BankAccountDeleteOptions options = null); + Task> PriceListAsync(PriceListOptions options = null); Task SetupIntentCreate(SetupIntentCreateOptions options); Task> SetupIntentList(SetupIntentListOptions options); Task SetupIntentCancel(string id, SetupIntentCancelOptions options = null); diff --git a/src/Core/Services/Implementations/StripeAdapter.cs b/src/Core/Services/Implementations/StripeAdapter.cs index 03d1776e90..3d1663f021 100644 --- a/src/Core/Services/Implementations/StripeAdapter.cs +++ b/src/Core/Services/Implementations/StripeAdapter.cs @@ -3,68 +3,74 @@ using Bit.Core.Models.BitStripe; using Stripe; +using Stripe.Tax; namespace Bit.Core.Services; public class StripeAdapter : IStripeAdapter { - private readonly Stripe.CustomerService _customerService; - private readonly Stripe.SubscriptionService _subscriptionService; - private readonly Stripe.InvoiceService _invoiceService; - private readonly Stripe.PaymentMethodService _paymentMethodService; - private readonly Stripe.TaxIdService _taxIdService; - private readonly Stripe.ChargeService _chargeService; - private readonly Stripe.RefundService _refundService; - private readonly Stripe.CardService _cardService; - private readonly Stripe.BankAccountService _bankAccountService; - private readonly Stripe.PlanService _planService; - private readonly Stripe.PriceService _priceService; - private readonly Stripe.SetupIntentService _setupIntentService; + private readonly CustomerService _customerService; + private readonly SubscriptionService _subscriptionService; + private readonly InvoiceService _invoiceService; + private readonly PaymentMethodService _paymentMethodService; + private readonly TaxIdService _taxIdService; + private readonly ChargeService _chargeService; + private readonly RefundService _refundService; + private readonly CardService _cardService; + private readonly BankAccountService _bankAccountService; + private readonly PlanService _planService; + private readonly PriceService _priceService; + private readonly SetupIntentService _setupIntentService; private readonly Stripe.TestHelpers.TestClockService _testClockService; private readonly CustomerBalanceTransactionService _customerBalanceTransactionService; private readonly Stripe.Tax.RegistrationService _taxRegistrationService; + private readonly CalculationService _calculationService; public StripeAdapter() { - _customerService = new Stripe.CustomerService(); - _subscriptionService = new Stripe.SubscriptionService(); - _invoiceService = new Stripe.InvoiceService(); - _paymentMethodService = new Stripe.PaymentMethodService(); - _taxIdService = new Stripe.TaxIdService(); - _chargeService = new Stripe.ChargeService(); - _refundService = new Stripe.RefundService(); - _cardService = new Stripe.CardService(); - _bankAccountService = new Stripe.BankAccountService(); - _priceService = new Stripe.PriceService(); - _planService = new Stripe.PlanService(); + _customerService = new CustomerService(); + _subscriptionService = new SubscriptionService(); + _invoiceService = new InvoiceService(); + _paymentMethodService = new PaymentMethodService(); + _taxIdService = new TaxIdService(); + _chargeService = new ChargeService(); + _refundService = new RefundService(); + _cardService = new CardService(); + _bankAccountService = new BankAccountService(); + _priceService = new PriceService(); + _planService = new PlanService(); _setupIntentService = new SetupIntentService(); _testClockService = new Stripe.TestHelpers.TestClockService(); _customerBalanceTransactionService = new CustomerBalanceTransactionService(); _taxRegistrationService = new Stripe.Tax.RegistrationService(); + _calculationService = new CalculationService(); } - public Task CustomerCreateAsync(Stripe.CustomerCreateOptions options) + public Task CustomerCreateAsync(CustomerCreateOptions options) { return _customerService.CreateAsync(options); } - public Task CustomerGetAsync(string id, Stripe.CustomerGetOptions options = null) + public Task CustomerDeleteDiscountAsync(string customerId, CustomerDeleteDiscountOptions options = null) => + _customerService.DeleteDiscountAsync(customerId, options); + + public Task CustomerGetAsync(string id, CustomerGetOptions options = null) { return _customerService.GetAsync(id, options); } - public Task CustomerUpdateAsync(string id, Stripe.CustomerUpdateOptions options = null) + public Task CustomerUpdateAsync(string id, CustomerUpdateOptions options = null) { return _customerService.UpdateAsync(id, options); } - public Task CustomerDeleteAsync(string id) + public Task CustomerDeleteAsync(string id) { return _customerService.DeleteAsync(id); } public async Task> CustomerListPaymentMethods(string id, - CustomerListPaymentMethodsOptions options = null) + CustomerPaymentMethodListOptions options = null) { var paymentMethods = await _customerService.ListPaymentMethodsAsync(id, options); return paymentMethods.Data; @@ -74,12 +80,12 @@ public class StripeAdapter : IStripeAdapter CustomerBalanceTransactionCreateOptions options) => await _customerBalanceTransactionService.CreateAsync(customerId, options); - public Task SubscriptionCreateAsync(Stripe.SubscriptionCreateOptions options) + public Task SubscriptionCreateAsync(SubscriptionCreateOptions options) { return _subscriptionService.CreateAsync(options); } - public Task SubscriptionGetAsync(string id, Stripe.SubscriptionGetOptions options = null) + public Task SubscriptionGetAsync(string id, SubscriptionGetOptions options = null) { return _subscriptionService.GetAsync(id, options); } @@ -98,28 +104,23 @@ public class StripeAdapter : IStripeAdapter throw new InvalidOperationException("Subscription does not belong to the provider."); } - public Task SubscriptionUpdateAsync(string id, - Stripe.SubscriptionUpdateOptions options = null) + public Task SubscriptionUpdateAsync(string id, + SubscriptionUpdateOptions options = null) { return _subscriptionService.UpdateAsync(id, options); } - public Task SubscriptionCancelAsync(string Id, Stripe.SubscriptionCancelOptions options = null) + public Task SubscriptionCancelAsync(string Id, SubscriptionCancelOptions options = null) { return _subscriptionService.CancelAsync(Id, options); } - public Task InvoiceUpcomingAsync(Stripe.UpcomingInvoiceOptions options) - { - return _invoiceService.UpcomingAsync(options); - } - - public Task InvoiceGetAsync(string id, Stripe.InvoiceGetOptions options) + public Task InvoiceGetAsync(string id, InvoiceGetOptions options) { return _invoiceService.GetAsync(id, options); } - public async Task> InvoiceListAsync(StripeInvoiceListOptions options) + public async Task> InvoiceListAsync(StripeInvoiceListOptions options) { if (!options.SelectAll) { @@ -128,7 +129,7 @@ public class StripeAdapter : IStripeAdapter options.Limit = 100; - var invoices = new List(); + var invoices = new List(); await foreach (var invoice in _invoiceService.ListAutoPagingAsync(options.ToInvoiceListOptions())) { @@ -143,120 +144,104 @@ public class StripeAdapter : IStripeAdapter return _invoiceService.CreatePreviewAsync(options); } - public async Task> InvoiceSearchAsync(InvoiceSearchOptions options) + public async Task> InvoiceSearchAsync(InvoiceSearchOptions options) => (await _invoiceService.SearchAsync(options)).Data; - public Task InvoiceUpdateAsync(string id, Stripe.InvoiceUpdateOptions options) + public Task InvoiceUpdateAsync(string id, InvoiceUpdateOptions options) { return _invoiceService.UpdateAsync(id, options); } - public Task InvoiceFinalizeInvoiceAsync(string id, Stripe.InvoiceFinalizeOptions options) + public Task InvoiceFinalizeInvoiceAsync(string id, InvoiceFinalizeOptions options) { return _invoiceService.FinalizeInvoiceAsync(id, options); } - public Task InvoiceSendInvoiceAsync(string id, Stripe.InvoiceSendOptions options) + public Task InvoiceSendInvoiceAsync(string id, InvoiceSendOptions options) { return _invoiceService.SendInvoiceAsync(id, options); } - public Task InvoicePayAsync(string id, Stripe.InvoicePayOptions options = null) + public Task InvoicePayAsync(string id, InvoicePayOptions options = null) { return _invoiceService.PayAsync(id, options); } - public Task InvoiceDeleteAsync(string id, Stripe.InvoiceDeleteOptions options = null) + public Task InvoiceDeleteAsync(string id, InvoiceDeleteOptions options = null) { return _invoiceService.DeleteAsync(id, options); } - public Task InvoiceVoidInvoiceAsync(string id, Stripe.InvoiceVoidOptions options = null) + public Task InvoiceVoidInvoiceAsync(string id, InvoiceVoidOptions options = null) { return _invoiceService.VoidInvoiceAsync(id, options); } - public IEnumerable PaymentMethodListAutoPaging(Stripe.PaymentMethodListOptions options) + public IEnumerable PaymentMethodListAutoPaging(PaymentMethodListOptions options) { return _paymentMethodService.ListAutoPaging(options); } - public IAsyncEnumerable PaymentMethodListAutoPagingAsync(Stripe.PaymentMethodListOptions options) + public IAsyncEnumerable PaymentMethodListAutoPagingAsync(PaymentMethodListOptions options) => _paymentMethodService.ListAutoPagingAsync(options); - public Task PaymentMethodAttachAsync(string id, Stripe.PaymentMethodAttachOptions options = null) + public Task PaymentMethodAttachAsync(string id, PaymentMethodAttachOptions options = null) { return _paymentMethodService.AttachAsync(id, options); } - public Task PaymentMethodDetachAsync(string id, Stripe.PaymentMethodDetachOptions options = null) + public Task PaymentMethodDetachAsync(string id, PaymentMethodDetachOptions options = null) { return _paymentMethodService.DetachAsync(id, options); } - public Task PlanGetAsync(string id, Stripe.PlanGetOptions options = null) + public Task PlanGetAsync(string id, PlanGetOptions options = null) { return _planService.GetAsync(id, options); } - public Task TaxIdCreateAsync(string id, Stripe.TaxIdCreateOptions options) + public Task TaxIdCreateAsync(string id, TaxIdCreateOptions options) { return _taxIdService.CreateAsync(id, options); } - public Task TaxIdDeleteAsync(string customerId, string taxIdId, - Stripe.TaxIdDeleteOptions options = null) + public Task TaxIdDeleteAsync(string customerId, string taxIdId, + TaxIdDeleteOptions options = null) { return _taxIdService.DeleteAsync(customerId, taxIdId); } - public Task> TaxRegistrationsListAsync(Stripe.Tax.RegistrationListOptions options = null) + public Task> TaxRegistrationsListAsync(RegistrationListOptions options = null) { return _taxRegistrationService.ListAsync(options); } - public Task> ChargeListAsync(Stripe.ChargeListOptions options) + public Task> ChargeListAsync(ChargeListOptions options) { return _chargeService.ListAsync(options); } - public Task RefundCreateAsync(Stripe.RefundCreateOptions options) + public Task RefundCreateAsync(RefundCreateOptions options) { return _refundService.CreateAsync(options); } - public Task CardDeleteAsync(string customerId, string cardId, Stripe.CardDeleteOptions options = null) + public Task CardDeleteAsync(string customerId, string cardId, CardDeleteOptions options = null) { return _cardService.DeleteAsync(customerId, cardId, options); } - public Task BankAccountCreateAsync(string customerId, Stripe.BankAccountCreateOptions options = null) + public Task BankAccountCreateAsync(string customerId, BankAccountCreateOptions options = null) { return _bankAccountService.CreateAsync(customerId, options); } - public Task BankAccountDeleteAsync(string customerId, string bankAccount, Stripe.BankAccountDeleteOptions options = null) + public Task BankAccountDeleteAsync(string customerId, string bankAccount, BankAccountDeleteOptions options = null) { return _bankAccountService.DeleteAsync(customerId, bankAccount, options); } - public async Task> SubscriptionListAsync(StripeSubscriptionListOptions options) - { - if (!options.SelectAll) - { - return (await _subscriptionService.ListAsync(options.ToStripeApiOptions())).Data; - } - - options.Limit = 100; - var items = new List(); - await foreach (var i in _subscriptionService.ListAutoPagingAsync(options.ToStripeApiOptions())) - { - items.Add(i); - } - return items; - } - - public async Task> PriceListAsync(Stripe.PriceListOptions options = null) + public async Task> PriceListAsync(PriceListOptions options = null) { return await _priceService.ListAsync(options); } diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Services/Implementations/StripePaymentService.cs index 5b68906d8a..5dd1ff50e7 100644 --- a/src/Core/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Services/Implementations/StripePaymentService.cs @@ -8,6 +8,7 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; using Bit.Core.Billing.Organizations.Models; +using Bit.Core.Billing.Premium.Commands; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Tax.Requests; using Bit.Core.Billing.Tax.Responses; @@ -65,19 +66,20 @@ public class StripePaymentService : IPaymentService bool applySponsorship) { var existingPlan = await _pricingClient.GetPlanOrThrow(org.PlanType); - var sponsoredPlan = sponsorship?.PlanSponsorshipType != null ? - Utilities.StaticStore.GetSponsoredPlan(sponsorship.PlanSponsorshipType.Value) : - null; - var subscriptionUpdate = new SponsorOrganizationSubscriptionUpdate(existingPlan, sponsoredPlan, applySponsorship); + var sponsoredPlan = sponsorship?.PlanSponsorshipType != null + ? Utilities.StaticStore.GetSponsoredPlan(sponsorship.PlanSponsorshipType.Value) + : null; + var subscriptionUpdate = + new SponsorOrganizationSubscriptionUpdate(existingPlan, sponsoredPlan, applySponsorship); await FinalizeSubscriptionChangeAsync(org, subscriptionUpdate, true); var sub = await _stripeAdapter.SubscriptionGetAsync(org.GatewaySubscriptionId); - org.ExpirationDate = sub.CurrentPeriodEnd; + org.ExpirationDate = sub.GetCurrentPeriodEnd(); if (sponsorship is not null) { - sponsorship.ValidUntil = sub.CurrentPeriodEnd; + sponsorship.ValidUntil = sub.GetCurrentPeriodEnd(); } } @@ -100,7 +102,8 @@ public class StripePaymentService : IPaymentService if (sub.Status == SubscriptionStatuses.Canceled) { - throw new BadRequestException("You do not have an active subscription. Reinstate your subscription to make changes."); + throw new BadRequestException( + "You do not have an active subscription. Reinstate your subscription to make changes."); } var existingCoupon = sub.Customer.Discount?.Coupon?.Id; @@ -191,24 +194,24 @@ public class StripePaymentService : IPaymentService throw; } } - else if (!invoice.Paid) + else if (invoice.Status != StripeConstants.InvoiceStatus.Paid) { // Pay invoice with no charge to the customer this completes the invoice immediately without waiting the scheduled 1h invoice = await _stripeAdapter.InvoicePayAsync(subResponse.LatestInvoiceId); paymentIntentClientSecret = null; } - } finally { // Change back the subscription collection method and/or days until due if (collectionMethod != "send_invoice" || daysUntilDue == null) { - await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, new SubscriptionUpdateOptions - { - CollectionMethod = collectionMethod, - DaysUntilDue = daysUntilDue, - }); + await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, + new SubscriptionUpdateOptions + { + CollectionMethod = collectionMethod, + DaysUntilDue = daysUntilDue, + }); } var customer = await _stripeAdapter.CustomerGetAsync(sub.CustomerId); @@ -218,9 +221,15 @@ public class StripePaymentService : IPaymentService if (!string.IsNullOrEmpty(existingCoupon) && string.IsNullOrEmpty(newCoupon)) { // Re-add the lost coupon due to the update. - await _stripeAdapter.CustomerUpdateAsync(sub.CustomerId, new CustomerUpdateOptions + await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, new SubscriptionUpdateOptions { - Coupon = existingCoupon + Discounts = + [ + new SubscriptionDiscountOptions + { + Coupon = existingCoupon + } + ] }); } } @@ -352,7 +361,7 @@ public class StripePaymentService : IPaymentService { var hasDefaultCardPaymentMethod = customer.InvoiceSettings?.DefaultPaymentMethod?.Type == "card"; var hasDefaultValidSource = customer.DefaultSource != null && - (customer.DefaultSource is Card || customer.DefaultSource is BankAccount); + (customer.DefaultSource is Card || customer.DefaultSource is BankAccount); if (!hasDefaultCardPaymentMethod && !hasDefaultValidSource) { cardPaymentMethodId = GetLatestCardPaymentMethod(customer.Id)?.Id; @@ -365,12 +374,11 @@ public class StripePaymentService : IPaymentService } catch { - await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id, new InvoiceFinalizeOptions - { - AutoAdvance = false - }); + await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id, + new InvoiceFinalizeOptions { AutoAdvance = false }); await _stripeAdapter.InvoiceVoidInvoiceAsync(invoice.Id); } + throw new BadRequestException("No payment method is available."); } } @@ -381,14 +389,9 @@ public class StripePaymentService : IPaymentService { // Finalize the invoice (from Draft) w/o auto-advance so we // can attempt payment manually. - invoice = await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id, new InvoiceFinalizeOptions - { - AutoAdvance = false, - }); - var invoicePayOptions = new InvoicePayOptions - { - PaymentMethod = cardPaymentMethodId, - }; + invoice = await _stripeAdapter.InvoiceFinalizeInvoiceAsync(invoice.Id, + new InvoiceFinalizeOptions { AutoAdvance = false, }); + var invoicePayOptions = new InvoicePayOptions { PaymentMethod = cardPaymentMethodId, }; if (customer?.Metadata?.ContainsKey("btCustomerId") ?? false) { invoicePayOptions.PaidOutOfBand = true; @@ -403,13 +406,15 @@ public class StripePaymentService : IPaymentService SubmitForSettlement = true, PayPal = new Braintree.TransactionOptionsPayPalRequest { - CustomField = $"{subscriber.BraintreeIdField()}:{subscriber.Id},{subscriber.BraintreeCloudRegionField()}:{_globalSettings.BaseServiceUri.CloudRegion}" + CustomField = + $"{subscriber.BraintreeIdField()}:{subscriber.Id},{subscriber.BraintreeCloudRegionField()}:{_globalSettings.BaseServiceUri.CloudRegion}" } }, CustomFields = new Dictionary { [subscriber.BraintreeIdField()] = subscriber.Id.ToString(), - [subscriber.BraintreeCloudRegionField()] = _globalSettings.BaseServiceUri.CloudRegion + [subscriber.BraintreeCloudRegionField()] = + _globalSettings.BaseServiceUri.CloudRegion } }); @@ -442,9 +447,9 @@ public class StripePaymentService : IPaymentService { // SCA required, get intent client secret var invoiceGetOptions = new InvoiceGetOptions(); - invoiceGetOptions.AddExpand("payment_intent"); + invoiceGetOptions.AddExpand("confirmation_secret"); invoice = await _stripeAdapter.InvoiceGetAsync(invoice.Id, invoiceGetOptions); - paymentIntentClientSecret = invoice?.PaymentIntent?.ClientSecret; + paymentIntentClientSecret = invoice?.ConfirmationSecret?.ClientSecret; } else { @@ -458,6 +463,7 @@ public class StripePaymentService : IPaymentService { await _btGateway.Transaction.RefundAsync(braintreeTransaction.Id); } + if (invoice != null) { if (invoice.Status == "paid") @@ -479,10 +485,8 @@ public class StripePaymentService : IPaymentService // Assumption: Customer balance should now be $0, otherwise payment would not have failed. if (customer.Balance == 0) { - await _stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions - { - Balance = invoice.StartingBalance - }); + await _stripeAdapter.CustomerUpdateAsync(customer.Id, + new CustomerUpdateOptions { Balance = invoice.StartingBalance }); } } } @@ -496,6 +500,7 @@ public class StripePaymentService : IPaymentService // Let the caller perform any subscription change cleanup throw; } + return paymentIntentClientSecret; } @@ -526,10 +531,10 @@ public class StripePaymentService : IPaymentService try { - var canceledSub = endOfPeriod ? - await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, - new SubscriptionUpdateOptions { CancelAtPeriodEnd = true }) : - await _stripeAdapter.SubscriptionCancelAsync(sub.Id, new SubscriptionCancelOptions()); + var canceledSub = endOfPeriod + ? await _stripeAdapter.SubscriptionUpdateAsync(sub.Id, + new SubscriptionUpdateOptions { CancelAtPeriodEnd = true }) + : await _stripeAdapter.SubscriptionCancelAsync(sub.Id, new SubscriptionCancelOptions()); if (!canceledSub.CanceledAt.HasValue) { throw new GatewayException("Unable to cancel subscription."); @@ -580,7 +585,7 @@ public class StripePaymentService : IPaymentService { Customer customer = null; var customerExists = subscriber.Gateway == GatewayType.Stripe && - !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId); + !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId); if (customerExists) { customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId); @@ -595,10 +600,10 @@ public class StripePaymentService : IPaymentService subscriber.Gateway = GatewayType.Stripe; subscriber.GatewayCustomerId = customer.Id; } - await _stripeAdapter.CustomerUpdateAsync(customer.Id, new CustomerUpdateOptions - { - Balance = customer.Balance - (long)(creditAmount * 100) - }); + + await _stripeAdapter.CustomerUpdateAsync(customer.Id, + new CustomerUpdateOptions { Balance = customer.Balance - (long)(creditAmount * 100) }); + return !customerExists; } @@ -630,50 +635,57 @@ public class StripePaymentService : IPaymentService { var subscriptionInfo = new SubscriptionInfo(); - if (!string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) - { - var customerGetOptions = new CustomerGetOptions(); - customerGetOptions.AddExpand("discount.coupon.applies_to"); - var customer = await _stripeAdapter.CustomerGetAsync(subscriber.GatewayCustomerId, customerGetOptions); - - if (customer.Discount != null) - { - subscriptionInfo.CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount(customer.Discount); - } - } - - if (string.IsNullOrWhiteSpace(subscriber.GatewaySubscriptionId)) + if (string.IsNullOrEmpty(subscriber.GatewaySubscriptionId)) { return subscriptionInfo; } - var sub = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, new SubscriptionGetOptions + var subscription = await _stripeAdapter.SubscriptionGetAsync(subscriber.GatewaySubscriptionId, + new SubscriptionGetOptions { Expand = ["customer.discount.coupon.applies_to", "discounts.coupon.applies_to", "test_clock"] }); + + if (subscription == null) { - Expand = ["test_clock"] - }); - - if (sub != null) - { - subscriptionInfo.Subscription = new SubscriptionInfo.BillingSubscription(sub); - - var (suspensionDate, unpaidPeriodEndDate) = await GetSuspensionDateAsync(sub); - - if (suspensionDate.HasValue && unpaidPeriodEndDate.HasValue) - { - subscriptionInfo.Subscription.SuspensionDate = suspensionDate; - subscriptionInfo.Subscription.UnpaidPeriodEndDate = unpaidPeriodEndDate; - } + return subscriptionInfo; } - if (sub is { CanceledAt: not null } || string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) + subscriptionInfo.Subscription = new SubscriptionInfo.BillingSubscription(subscription); + + // Discount selection priority: + // 1. Customer-level discount (applies to all subscriptions for the customer) + // 2. First subscription-level discount (if multiple exist, FirstOrDefault() selects the first one) + // Note: When multiple subscription-level discounts exist, only the first one is used. + // This matches Stripe's behavior where the first discount in the list is applied. + // Defensive null checks: Even though we expand "customer" and "discounts", external APIs + // may not always return the expected data structure, so we use null-safe operators. + var discount = subscription.Customer?.Discount ?? subscription.Discounts?.FirstOrDefault(); + + if (discount != null) + { + subscriptionInfo.CustomerDiscount = new SubscriptionInfo.BillingCustomerDiscount(discount); + } + + var (suspensionDate, unpaidPeriodEndDate) = await GetSuspensionDateAsync(subscription); + + if (suspensionDate.HasValue && unpaidPeriodEndDate.HasValue) + { + subscriptionInfo.Subscription.SuspensionDate = suspensionDate; + subscriptionInfo.Subscription.UnpaidPeriodEndDate = unpaidPeriodEndDate; + } + + if (subscription is { CanceledAt: not null } || string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) { return subscriptionInfo; } try { - var upcomingInvoiceOptions = new UpcomingInvoiceOptions { Customer = subscriber.GatewayCustomerId }; - var upcomingInvoice = await _stripeAdapter.InvoiceUpcomingAsync(upcomingInvoiceOptions); + var invoiceCreatePreviewOptions = new InvoiceCreatePreviewOptions + { + Customer = subscriber.GatewayCustomerId, + Subscription = subscriber.GatewaySubscriptionId + }; + + var upcomingInvoice = await _stripeAdapter.InvoiceCreatePreviewAsync(invoiceCreatePreviewOptions); if (upcomingInvoice != null) { @@ -682,7 +694,12 @@ public class StripePaymentService : IPaymentService } catch (StripeException ex) { - _logger.LogWarning(ex, "Encountered an unexpected Stripe error"); + _logger.LogWarning( + ex, + "Failed to retrieve upcoming invoice for customer {CustomerId}, subscription {SubscriptionId}. Error Code: {ErrorCode}", + subscriber.GatewayCustomerId, + subscriber.GatewaySubscriptionId, + ex.StripeError?.Code); } return subscriptionInfo; @@ -788,7 +805,11 @@ public class StripePaymentService : IPaymentService if (taxInfo.TaxIdType == StripeConstants.TaxIdType.SpanishNIF) { await _stripeAdapter.TaxIdCreateAsync(customer.Id, - new TaxIdCreateOptions { Type = StripeConstants.TaxIdType.EUVAT, Value = $"ES{taxInfo.TaxIdNumber}" }); + new TaxIdCreateOptions + { + Type = StripeConstants.TaxIdType.EUVAT, + Value = $"ES{taxInfo.TaxIdNumber}" + }); } } catch (StripeException e) @@ -829,7 +850,8 @@ public class StripePaymentService : IPaymentService await HasSecretsManagerStandaloneAsync(gatewayCustomerId: organization.GatewayCustomerId, organizationHasSecretsManager: organization.UseSecretsManager); - private async Task HasSecretsManagerStandaloneAsync(string gatewayCustomerId, bool organizationHasSecretsManager) + private async Task HasSecretsManagerStandaloneAsync(string gatewayCustomerId, + bool organizationHasSecretsManager) { if (string.IsNullOrEmpty(gatewayCustomerId)) { @@ -887,32 +909,32 @@ public class StripePaymentService : IPaymentService } } + [Obsolete($"Use {nameof(PreviewPremiumTaxCommand)} instead.")] public async Task PreviewInvoiceAsync( PreviewIndividualInvoiceRequestBody parameters, string gatewayCustomerId, string gatewaySubscriptionId) { + var premiumPlan = await _pricingClient.GetAvailablePremiumPlan(); + var options = new InvoiceCreatePreviewOptions { - AutomaticTax = new InvoiceAutomaticTaxOptions - { - Enabled = true, - }, + AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true, }, Currency = "usd", SubscriptionDetails = new InvoiceSubscriptionDetailsOptions { Items = [ - new() + new InvoiceSubscriptionDetailsItemOptions { Quantity = 1, - Plan = StripeConstants.Prices.PremiumAnnually + Plan = premiumPlan.Seat.StripePriceId }, - new() + new InvoiceSubscriptionDetailsItemOptions { Quantity = parameters.PasswordManager.AdditionalStorage, - Plan = "storage-gb-annually" + Plan = premiumPlan.Storage.StripePriceId } ] }, @@ -940,12 +962,9 @@ public class StripePaymentService : IPaymentService throw new BadRequestException("billingPreviewInvalidTaxIdError"); } - options.CustomerDetails.TaxIds = [ - new InvoiceCustomerDetailsTaxIdOptions - { - Type = taxIdType, - Value = parameters.TaxInformation.TaxId - } + options.CustomerDetails.TaxIds = + [ + new InvoiceCustomerDetailsTaxIdOptions { Type = taxIdType, Value = parameters.TaxInformation.TaxId } ]; if (taxIdType == StripeConstants.TaxIdType.SpanishNIF) @@ -964,7 +983,7 @@ public class StripePaymentService : IPaymentService if (gatewayCustomer.Discount != null) { - options.Coupon = gatewayCustomer.Discount.Coupon.Id; + options.Discounts = [new InvoiceDiscountOptions { Coupon = gatewayCustomer.Discount.Coupon.Id }]; } } @@ -972,24 +991,31 @@ public class StripePaymentService : IPaymentService { var gatewaySubscription = await _stripeAdapter.SubscriptionGetAsync(gatewaySubscriptionId); - if (gatewaySubscription?.Discount != null) + if (gatewaySubscription?.Discounts is { Count: > 0 }) { - options.Coupon ??= gatewaySubscription.Discount.Coupon.Id; + options.Discounts = gatewaySubscription.Discounts.Select(x => new InvoiceDiscountOptions { Coupon = x.Coupon.Id }).ToList(); } } + if (options.Discounts is { Count: > 0 }) + { + options.Discounts = options.Discounts.DistinctBy(invoiceDiscountOptions => invoiceDiscountOptions.Coupon).ToList(); + } + try { var invoice = await _stripeAdapter.InvoiceCreatePreviewAsync(options); - var effectiveTaxRate = invoice.Tax != null && invoice.TotalExcludingTax != null && invoice.TotalExcludingTax.Value != 0 - ? invoice.Tax.Value.ToMajor() / invoice.TotalExcludingTax.Value.ToMajor() + var tax = invoice.TotalTaxes.Sum(invoiceTotalTax => invoiceTotalTax.Amount); + + var effectiveTaxRate = invoice.TotalExcludingTax != null && invoice.TotalExcludingTax.Value != 0 + ? tax.ToMajor() / invoice.TotalExcludingTax.Value.ToMajor() : 0M; var result = new PreviewInvoiceResponseModel( effectiveTaxRate, invoice.TotalExcludingTax.ToMajor() ?? 0, - invoice.Tax.ToMajor() ?? 0, + tax.ToMajor(), invoice.Total.ToMajor()); return result; } @@ -1003,7 +1029,8 @@ public class StripePaymentService : IPaymentService parameters.TaxInformation.Country); throw new BadRequestException("billingPreviewInvalidTaxIdError"); default: - _logger.LogError(e, "Unexpected error previewing invoice with tax ID '{TaxId}' in country '{Country}'.", + _logger.LogError(e, + "Unexpected error previewing invoice with tax ID '{TaxId}' in country '{Country}'.", parameters.TaxInformation.TaxId, parameters.TaxInformation.Country); throw new BadRequestException("billingPreviewInvoiceError"); @@ -1026,7 +1053,7 @@ public class StripePaymentService : IPaymentService { Items = [ - new() + new InvoiceSubscriptionDetailsItemOptions { Quantity = parameters.PasswordManager.AdditionalStorage, Plan = plan.PasswordManager.StripeStoragePlanId @@ -1047,7 +1074,7 @@ public class StripePaymentService : IPaymentService { var sponsoredPlan = Utilities.StaticStore.GetSponsoredPlan(parameters.PasswordManager.SponsoredPlan.Value); options.SubscriptionDetails.Items.Add( - new() { Quantity = 1, Plan = sponsoredPlan.StripePlanId } + new InvoiceSubscriptionDetailsItemOptions { Quantity = 1, Plan = sponsoredPlan.StripePlanId } ); } else @@ -1055,13 +1082,13 @@ public class StripePaymentService : IPaymentService if (plan.PasswordManager.HasAdditionalSeatsOption) { options.SubscriptionDetails.Items.Add( - new() { Quantity = parameters.PasswordManager.Seats, Plan = plan.PasswordManager.StripeSeatPlanId } + new InvoiceSubscriptionDetailsItemOptions { Quantity = parameters.PasswordManager.Seats, Plan = plan.PasswordManager.StripeSeatPlanId } ); } else { options.SubscriptionDetails.Items.Add( - new() { Quantity = 1, Plan = plan.PasswordManager.StripePlanId } + new InvoiceSubscriptionDetailsItemOptions { Quantity = 1, Plan = plan.PasswordManager.StripePlanId } ); } @@ -1069,7 +1096,7 @@ public class StripePaymentService : IPaymentService { if (plan.SecretsManager.HasAdditionalSeatsOption) { - options.SubscriptionDetails.Items.Add(new() + options.SubscriptionDetails.Items.Add(new InvoiceSubscriptionDetailsItemOptions { Quantity = parameters.SecretsManager?.Seats ?? 0, Plan = plan.SecretsManager.StripeSeatPlanId @@ -1078,7 +1105,7 @@ public class StripePaymentService : IPaymentService if (plan.SecretsManager.HasAdditionalServiceAccountOption) { - options.SubscriptionDetails.Items.Add(new() + options.SubscriptionDetails.Items.Add(new InvoiceSubscriptionDetailsItemOptions { Quantity = parameters.SecretsManager?.AdditionalMachineAccounts ?? 0, Plan = plan.SecretsManager.StripeServiceAccountPlanId @@ -1101,12 +1128,9 @@ public class StripePaymentService : IPaymentService throw new BadRequestException("billingTaxIdTypeInferenceError"); } - options.CustomerDetails.TaxIds = [ - new InvoiceCustomerDetailsTaxIdOptions - { - Type = taxIdType, - Value = parameters.TaxInformation.TaxId - } + options.CustomerDetails.TaxIds = + [ + new InvoiceCustomerDetailsTaxIdOptions { Type = taxIdType, Value = parameters.TaxInformation.TaxId } ]; if (taxIdType == StripeConstants.TaxIdType.SpanishNIF) @@ -1127,7 +1151,10 @@ public class StripePaymentService : IPaymentService if (gatewayCustomer.Discount != null) { - options.Coupon = gatewayCustomer.Discount.Coupon.Id; + options.Discounts = + [ + new InvoiceDiscountOptions { Coupon = gatewayCustomer.Discount.Coupon.Id } + ]; } } @@ -1135,9 +1162,10 @@ public class StripePaymentService : IPaymentService { var gatewaySubscription = await _stripeAdapter.SubscriptionGetAsync(gatewaySubscriptionId); - if (gatewaySubscription?.Discount != null) + if (gatewaySubscription?.Discounts != null) { - options.Coupon ??= gatewaySubscription.Discount.Coupon.Id; + options.Discounts = gatewaySubscription.Discounts + .Select(discount => new InvoiceDiscountOptions { Coupon = discount.Coupon.Id }).ToList(); } } @@ -1152,14 +1180,16 @@ public class StripePaymentService : IPaymentService { var invoice = await _stripeAdapter.InvoiceCreatePreviewAsync(options); - var effectiveTaxRate = invoice.Tax != null && invoice.TotalExcludingTax != null && invoice.TotalExcludingTax.Value != 0 - ? invoice.Tax.Value.ToMajor() / invoice.TotalExcludingTax.Value.ToMajor() + var tax = invoice.TotalTaxes.Sum(invoiceTotalTax => invoiceTotalTax.Amount); + + var effectiveTaxRate = invoice.TotalExcludingTax != null && invoice.TotalExcludingTax.Value != 0 + ? tax.ToMajor() / invoice.TotalExcludingTax.Value.ToMajor() : 0M; var result = new PreviewInvoiceResponseModel( effectiveTaxRate, invoice.TotalExcludingTax.ToMajor() ?? 0, - invoice.Tax.ToMajor() ?? 0, + tax.ToMajor(), invoice.Total.ToMajor()); return result; } @@ -1173,7 +1203,8 @@ public class StripePaymentService : IPaymentService parameters.TaxInformation.Country); throw new BadRequestException("billingPreviewInvalidTaxIdError"); default: - _logger.LogError(e, "Unexpected error previewing invoice with tax ID '{TaxId}' in country '{Country}'.", + _logger.LogError(e, + "Unexpected error previewing invoice with tax ID '{TaxId}' in country '{Country}'.", parameters.TaxInformation.TaxId, parameters.TaxInformation.Country); throw new BadRequestException("billingPreviewInvoiceError"); @@ -1207,7 +1238,9 @@ public class StripePaymentService : IPaymentService braintreeCustomer.DefaultPaymentMethod); } } - catch (Braintree.Exceptions.NotFoundException) { } + catch (Braintree.Exceptions.NotFoundException) + { + } } if (customer.InvoiceSettings?.DefaultPaymentMethod?.Type == "card") @@ -1246,12 +1279,15 @@ public class StripePaymentService : IPaymentService { customer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId, options); } - catch (StripeException) { } + catch (StripeException) + { + } return customer; } - private async Task> GetBillingTransactionsAsync(ISubscriber subscriber, int? limit = null) + private async Task> GetBillingTransactionsAsync( + ISubscriber subscriber, int? limit = null) { var transactions = subscriber switch { diff --git a/src/Core/Services/Implementations/UserService.cs b/src/Core/Services/Implementations/UserService.cs index a36b9e37cc..daf1b2078d 100644 --- a/src/Core/Services/Implementations/UserService.cs +++ b/src/Core/Services/Implementations/UserService.cs @@ -14,10 +14,10 @@ using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; -using Bit.Core.Billing.Constants; using Bit.Core.Billing.Models; using Bit.Core.Billing.Models.Business; using Bit.Core.Billing.Models.Sales; +using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Billing.Tax.Models; using Bit.Core.Context; @@ -72,6 +72,7 @@ public class UserService : UserManager, IUserService private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; private readonly IDistributedCache _distributedCache; private readonly IPolicyRequirementQuery _policyRequirementQuery; + private readonly IPricingClient _pricingClient; public UserService( IUserRepository userRepository, @@ -106,7 +107,8 @@ public class UserService : UserManager, IUserService IRevokeNonCompliantOrganizationUserCommand revokeNonCompliantOrganizationUserCommand, ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, IDistributedCache distributedCache, - IPolicyRequirementQuery policyRequirementQuery) + IPolicyRequirementQuery policyRequirementQuery, + IPricingClient pricingClient) : base( store, optionsAccessor, @@ -146,6 +148,7 @@ public class UserService : UserManager, IUserService _twoFactorIsEnabledQuery = twoFactorIsEnabledQuery; _distributedCache = distributedCache; _policyRequirementQuery = policyRequirementQuery; + _pricingClient = pricingClient; } public Guid? GetProperUserId(ClaimsPrincipal principal) @@ -972,8 +975,9 @@ public class UserService : UserManager, IUserService throw new BadRequestException("Not a premium user."); } - var secret = await BillingHelpers.AdjustStorageAsync(_paymentService, user, storageAdjustmentGb, - StripeConstants.Prices.StoragePlanPersonal); + var premiumPlan = await _pricingClient.GetAvailablePremiumPlan(); + + var secret = await BillingHelpers.AdjustStorageAsync(_paymentService, user, storageAdjustmentGb, premiumPlan.Storage.StripePriceId); await SaveUserAsync(user); return secret; } diff --git a/src/Core/Settings/GlobalSettings.cs b/src/Core/Settings/GlobalSettings.cs index 546e668093..c467d1e652 100644 --- a/src/Core/Settings/GlobalSettings.cs +++ b/src/Core/Settings/GlobalSettings.cs @@ -8,6 +8,7 @@ namespace Bit.Core.Settings; public class GlobalSettings : IGlobalSettings { + private string _mailTemplateDirectory; private string _logDirectory; private string _licenseDirectory; @@ -37,6 +38,11 @@ public class GlobalSettings : IGlobalSettings get => BuildDirectory(_licenseDirectory, "/core/licenses"); set => _licenseDirectory = value; } + public virtual string MailTemplateDirectory + { + get => BuildDirectory(_mailTemplateDirectory, "/mail-templates"); + set => _mailTemplateDirectory = value; + } public string LicenseCertificatePassword { get; set; } public virtual string PushRelayBaseUri { get; set; } public virtual string InternalIdentityKey { get; set; } @@ -56,6 +62,7 @@ public class GlobalSettings : IGlobalSettings public virtual SqlSettings MySql { get; set; } = new SqlSettings(); public virtual SqlSettings Sqlite { get; set; } = new SqlSettings() { ConnectionString = "Data Source=:memory:" }; public virtual SlackSettings Slack { get; set; } = new SlackSettings(); + public virtual TeamsSettings Teams { get; set; } = new TeamsSettings(); public virtual EventLoggingSettings EventLogging { get; set; } = new EventLoggingSettings(); public virtual MailSettings Mail { get; set; } = new MailSettings(); public virtual IConnectionStringSettings Storage { get; set; } = new ConnectionStringSettings(); @@ -97,6 +104,7 @@ public class GlobalSettings : IGlobalSettings ///
public virtual string SendDefaultHashKey { get; set; } public virtual string PricingUri { get; set; } + public virtual Fido2Settings Fido2 { get; set; } = new Fido2Settings(); public string BuildExternalUri(string explicitValue, string name) { @@ -288,6 +296,15 @@ public class GlobalSettings : IGlobalSettings public virtual string Scopes { get; set; } } + public class TeamsSettings + { + public virtual string LoginBaseUrl { get; set; } = "https://login.microsoftonline.com"; + public virtual string GraphBaseUrl { get; set; } = "https://graph.microsoft.com/v1.0"; + public virtual string ClientId { get; set; } + public virtual string ClientSecret { get; set; } + public virtual string Scopes { get; set; } + } + public class EventLoggingSettings { public AzureServiceBusSettings AzureServiceBus { get; set; } = new AzureServiceBusSettings(); @@ -301,6 +318,9 @@ public class GlobalSettings : IGlobalSettings private string _eventTopicName; private string _integrationTopicName; + public virtual int DefaultMaxConcurrentCalls { get; set; } = 1; + public virtual int DefaultPrefetchCount { get; set; } = 0; + public virtual string EventRepositorySubscriptionName { get; set; } = "events-write-subscription"; public virtual string SlackEventSubscriptionName { get; set; } = "events-slack-subscription"; public virtual string SlackIntegrationSubscriptionName { get; set; } = "integration-slack-subscription"; @@ -310,6 +330,8 @@ public class GlobalSettings : IGlobalSettings public virtual string HecIntegrationSubscriptionName { get; set; } = "integration-hec-subscription"; public virtual string DatadogEventSubscriptionName { get; set; } = "events-datadog-subscription"; public virtual string DatadogIntegrationSubscriptionName { get; set; } = "integration-datadog-subscription"; + public virtual string TeamsEventSubscriptionName { get; set; } = "events-teams-subscription"; + public virtual string TeamsIntegrationSubscriptionName { get; set; } = "integration-teams-subscription"; public string ConnectionString { @@ -354,6 +376,9 @@ public class GlobalSettings : IGlobalSettings public virtual string DatadogEventsQueueName { get; set; } = "events-datadog-queue"; public virtual string DatadogIntegrationQueueName { get; set; } = "integration-datadog-queue"; public virtual string DatadogIntegrationRetryQueueName { get; set; } = "integration-datadog-retry-queue"; + public virtual string TeamsEventsQueueName { get; set; } = "events-teams-queue"; + public virtual string TeamsIntegrationQueueName { get; set; } = "integration-teams-queue"; + public virtual string TeamsIntegrationRetryQueueName { get; set; } = "integration-teams-retry-queue"; public string HostName { @@ -652,6 +677,7 @@ public class GlobalSettings : IGlobalSettings public bool Production { get; set; } public string Token { get; set; } public string NotificationUrl { get; set; } + public string WebhookKey { get; set; } } public class InstallationSettings : IInstallationSettings @@ -763,4 +789,9 @@ public class GlobalSettings : IGlobalSettings { public string VapidPublicKey { get; set; } } + + public class Fido2Settings + { + public HashSet Origins { get; set; } + } } diff --git a/src/Core/Tools/ImportFeatures/ImportCiphersCommand.cs b/src/Core/Tools/ImportFeatures/ImportCiphersCommand.cs index ce269bc68c..c7f7e3aff7 100644 --- a/src/Core/Tools/ImportFeatures/ImportCiphersCommand.cs +++ b/src/Core/Tools/ImportFeatures/ImportCiphersCommand.cs @@ -108,15 +108,7 @@ public class ImportCiphersCommand : IImportCiphersCommand } // Create it all - var useBulkResourceCreationService = _featureService.IsEnabled(FeatureFlagKeys.CipherRepositoryBulkResourceCreation); - if (useBulkResourceCreationService) - { - await _cipherRepository.CreateAsync_vNext(importingUserId, ciphers, newFolders); - } - else - { - await _cipherRepository.CreateAsync(importingUserId, ciphers, newFolders); - } + await _cipherRepository.CreateAsync(importingUserId, ciphers, newFolders); // push await _pushService.PushSyncVaultAsync(importingUserId); @@ -191,15 +183,7 @@ public class ImportCiphersCommand : IImportCiphersCommand } // Create it all - var useBulkResourceCreationService = _featureService.IsEnabled(FeatureFlagKeys.CipherRepositoryBulkResourceCreation); - if (useBulkResourceCreationService) - { - await _cipherRepository.CreateAsync_vNext(ciphers, newCollections, collectionCiphers, newCollectionUsers); - } - else - { - await _cipherRepository.CreateAsync(ciphers, newCollections, collectionCiphers, newCollectionUsers); - } + await _cipherRepository.CreateAsync(ciphers, newCollections, collectionCiphers, newCollectionUsers); // push await _pushService.PushSyncVaultAsync(importingUserId); diff --git a/src/Core/Utilities/AssemblyHelpers.cs b/src/Core/Utilities/AssemblyHelpers.cs index 0cc01efdf3..03f7ff986d 100644 --- a/src/Core/Utilities/AssemblyHelpers.cs +++ b/src/Core/Utilities/AssemblyHelpers.cs @@ -1,46 +1,46 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - +using System.Diagnostics; using System.Reflection; namespace Bit.Core.Utilities; public static class AssemblyHelpers { - private static readonly IEnumerable _assemblyMetadataAttributes; - private static readonly AssemblyInformationalVersionAttribute _assemblyInformationalVersionAttributes; - private const string GIT_HASH_ASSEMBLY_KEY = "GitHash"; - private static string _version; - private static string _gitHash; + private static string? _version; + private static string? _gitHash; static AssemblyHelpers() { - _assemblyMetadataAttributes = Assembly.GetEntryAssembly().GetCustomAttributes(); - _assemblyInformationalVersionAttributes = Assembly.GetEntryAssembly().GetCustomAttribute(); - } - - public static string GetVersion() - { - if (string.IsNullOrWhiteSpace(_version)) + var assemblyInformationalVersionAttribute = typeof(AssemblyHelpers).Assembly.GetCustomAttribute(); + if (assemblyInformationalVersionAttribute == null) { - _version = _assemblyInformationalVersionAttributes.InformationalVersion; + Debug.Fail("The AssemblyInformationalVersionAttribute is expected to exist in this assembly, possibly its generation was turned off."); + return; } + var informationalVersion = assemblyInformationalVersionAttribute.InformationalVersion.AsSpan(); + + if (!informationalVersion.TrySplitBy('+', out var version, out var gitHash)) + { + // Treat the whole thing as the version + _version = informationalVersion.ToString(); + return; + } + + _version = version.ToString(); + if (gitHash.Length < 8) + { + return; + } + _gitHash = gitHash[..8].ToString(); + } + + public static string? GetVersion() + { return _version; } - public static string GetGitHash() + public static string? GetGitHash() { - if (string.IsNullOrWhiteSpace(_gitHash)) - { - try - { - _gitHash = _assemblyMetadataAttributes.Where(i => i.Key == GIT_HASH_ASSEMBLY_KEY).First().Value; - } - catch (Exception) - { } - } - return _gitHash; } } diff --git a/src/Core/Utilities/BitPayClient.cs b/src/Core/Utilities/BitPayClient.cs deleted file mode 100644 index cf241d5723..0000000000 --- a/src/Core/Utilities/BitPayClient.cs +++ /dev/null @@ -1,30 +0,0 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.Settings; - -namespace Bit.Core.Utilities; - -public class BitPayClient -{ - private readonly BitPayLight.BitPay _bpClient; - - public BitPayClient(GlobalSettings globalSettings) - { - if (CoreHelpers.SettingHasValue(globalSettings.BitPay.Token)) - { - _bpClient = new BitPayLight.BitPay(globalSettings.BitPay.Token, - globalSettings.BitPay.Production ? BitPayLight.Env.Prod : BitPayLight.Env.Test); - } - } - - public Task GetInvoiceAsync(string id) - { - return _bpClient.GetInvoice(id); - } - - public Task CreateInvoiceAsync(BitPayLight.Models.Invoice.Invoice invoice) - { - return _bpClient.CreateInvoice(invoice); - } -} diff --git a/src/Core/Utilities/LoggingExceptionHandlerFilterAttribute.cs b/src/Core/Utilities/LoggingExceptionHandlerFilterAttribute.cs index 6709bbb271..300c30641e 100644 --- a/src/Core/Utilities/LoggingExceptionHandlerFilterAttribute.cs +++ b/src/Core/Utilities/LoggingExceptionHandlerFilterAttribute.cs @@ -17,6 +17,6 @@ public class LoggingExceptionHandlerFilterAttribute : ExceptionFilterAttribute var logger = context.HttpContext.RequestServices .GetRequiredService>(); - logger.LogError(0, exception, exception.Message); + logger.LogError(0, exception, "Unhandled exception"); } } diff --git a/src/Core/Utilities/StaticStore.cs b/src/Core/Utilities/StaticStore.cs index 1ddd926569..36c4a54ae4 100644 --- a/src/Core/Utilities/StaticStore.cs +++ b/src/Core/Utilities/StaticStore.cs @@ -137,6 +137,7 @@ public static class StaticStore new Teams2019Plan(true), new Teams2019Plan(false), new Families2019Plan(), + new Families2025Plan() }.ToImmutableList(); } diff --git a/src/Core/Vault/Repositories/ICipherRepository.cs b/src/Core/Vault/Repositories/ICipherRepository.cs index 32acf3cbc9..94518bae2a 100644 --- a/src/Core/Vault/Repositories/ICipherRepository.cs +++ b/src/Core/Vault/Repositories/ICipherRepository.cs @@ -33,28 +33,12 @@ public interface ICipherRepository : IRepository Task DeleteByUserIdAsync(Guid userId); Task DeleteByOrganizationIdAsync(Guid organizationId); Task UpdateCiphersAsync(Guid userId, IEnumerable ciphers); - /// - /// - /// This version uses the bulk resource creation service to create the temp table. - /// - Task UpdateCiphersAsync_vNext(Guid userId, IEnumerable ciphers); /// /// Create ciphers and folders for the specified UserId. Must not be used to create organization owned items. /// Task CreateAsync(Guid userId, IEnumerable ciphers, IEnumerable folders); - /// - /// - /// This version uses the bulk resource creation service to create the temp tables. - /// - Task CreateAsync_vNext(Guid userId, IEnumerable ciphers, IEnumerable folders); Task CreateAsync(IEnumerable ciphers, IEnumerable collections, IEnumerable collectionCiphers, IEnumerable collectionUsers); - /// - /// - /// This version uses the bulk resource creation service to create the temp tables. - /// - Task CreateAsync_vNext(IEnumerable ciphers, IEnumerable collections, - IEnumerable collectionCiphers, IEnumerable collectionUsers); Task SoftDeleteAsync(IEnumerable ids, Guid userId); Task SoftDeleteByIdsOrganizationIdAsync(IEnumerable ids, Guid organizationId); Task UnarchiveAsync(IEnumerable ids, Guid userId); @@ -92,10 +76,4 @@ public interface ICipherRepository : IRepository ///
Task> GetManyCipherOrganizationDetailsExcludingDefaultCollectionsAsync(Guid organizationId); - /// - /// - /// This version uses the bulk resource creation service to create the temp table. - /// - UpdateEncryptedDataForKeyRotation UpdateForKeyRotation_vNext(Guid userId, - IEnumerable ciphers); } diff --git a/src/Core/Vault/Repositories/ISecurityTaskRepository.cs b/src/Core/Vault/Repositories/ISecurityTaskRepository.cs index 4b88f1c0e8..0be3bbd545 100644 --- a/src/Core/Vault/Repositories/ISecurityTaskRepository.cs +++ b/src/Core/Vault/Repositories/ISecurityTaskRepository.cs @@ -35,4 +35,10 @@ public interface ISecurityTaskRepository : IRepository /// The id of the organization /// A collection of security task metrics Task GetTaskMetricsAsync(Guid organizationId); + + /// + /// Marks all tasks associated with the respective ciphers as complete. + /// + /// Collection of cipher IDs + Task MarkAsCompleteByCipherIds(IEnumerable cipherIds); } diff --git a/src/Core/Vault/Services/ICipherService.cs b/src/Core/Vault/Services/ICipherService.cs index dac535433c..110d4b6ea4 100644 --- a/src/Core/Vault/Services/ICipherService.cs +++ b/src/Core/Vault/Services/ICipherService.cs @@ -13,11 +13,11 @@ public interface ICipherService Task SaveDetailsAsync(CipherDetails cipher, Guid savingUserId, DateTime? lastKnownRevisionDate, IEnumerable collectionIds = null, bool skipPermissionCheck = false); Task<(string attachmentId, string uploadUrl)> CreateAttachmentForDelayedUploadAsync(Cipher cipher, - string key, string fileName, long fileSize, bool adminRequest, Guid savingUserId); + string key, string fileName, long fileSize, bool adminRequest, Guid savingUserId, DateTime? lastKnownRevisionDate = null); Task CreateAttachmentAsync(Cipher cipher, Stream stream, string fileName, string key, - long requestLength, Guid savingUserId, bool orgAdmin = false); + long requestLength, Guid savingUserId, bool orgAdmin = false, DateTime? lastKnownRevisionDate = null); Task CreateAttachmentShareAsync(Cipher cipher, Stream stream, string fileName, string key, long requestLength, - string attachmentId, Guid organizationShareId); + string attachmentId, Guid organizationShareId, DateTime? lastKnownRevisionDate = null); Task DeleteAsync(CipherDetails cipherDetails, Guid deletingUserId, bool orgAdmin = false); Task DeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false); Task DeleteAttachmentAsync(Cipher cipher, string attachmentId, Guid deletingUserId, bool orgAdmin = false); @@ -34,7 +34,8 @@ public interface ICipherService Task SoftDeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false); Task RestoreAsync(CipherDetails cipherDetails, Guid restoringUserId, bool orgAdmin = false); Task> RestoreManyAsync(IEnumerable cipherIds, Guid restoringUserId, Guid? organizationId = null, bool orgAdmin = false); - Task UploadFileForExistingAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentId); + Task UploadFileForExistingAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentId, DateTime? lastKnownRevisionDate = null); Task GetAttachmentDownloadDataAsync(Cipher cipher, string attachmentId); Task ValidateCipherAttachmentFile(Cipher cipher, CipherAttachment.MetaData attachmentData); + Task ValidateBulkCollectionAssignmentAsync(IEnumerable collectionIds, IEnumerable cipherIds, Guid userId); } diff --git a/src/Core/Vault/Services/Implementations/CipherService.cs b/src/Core/Vault/Services/Implementations/CipherService.cs index ebfb2a4a2a..4e980f66b6 100644 --- a/src/Core/Vault/Services/Implementations/CipherService.cs +++ b/src/Core/Vault/Services/Implementations/CipherService.cs @@ -33,6 +33,7 @@ public class CipherService : ICipherService private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationUserRepository _organizationUserRepository; private readonly ICollectionCipherRepository _collectionCipherRepository; + private readonly ISecurityTaskRepository _securityTaskRepository; private readonly IPushNotificationService _pushService; private readonly IAttachmentStorageService _attachmentStorageService; private readonly IEventService _eventService; @@ -53,6 +54,7 @@ public class CipherService : ICipherService IOrganizationRepository organizationRepository, IOrganizationUserRepository organizationUserRepository, ICollectionCipherRepository collectionCipherRepository, + ISecurityTaskRepository securityTaskRepository, IPushNotificationService pushService, IAttachmentStorageService attachmentStorageService, IEventService eventService, @@ -71,6 +73,7 @@ public class CipherService : ICipherService _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; _collectionCipherRepository = collectionCipherRepository; + _securityTaskRepository = securityTaskRepository; _pushService = pushService; _attachmentStorageService = attachmentStorageService; _eventService = eventService; @@ -113,7 +116,7 @@ public class CipherService : ICipherService } else { - ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate); + ValidateCipherLastKnownRevisionDate(cipher, lastKnownRevisionDate); cipher.RevisionDate = DateTime.UtcNow; await _cipherRepository.ReplaceAsync(cipher); await _eventService.LogCipherEventAsync(cipher, Bit.Core.Enums.EventType.Cipher_Updated); @@ -168,8 +171,9 @@ public class CipherService : ICipherService } else { - ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate); + ValidateCipherLastKnownRevisionDate(cipher, lastKnownRevisionDate); cipher.RevisionDate = DateTime.UtcNow; + await ValidateChangeInCollectionsAsync(cipher, collectionIds, savingUserId); await ValidateViewPasswordUserAsync(cipher); await _cipherRepository.ReplaceAsync(cipher); await _eventService.LogCipherEventAsync(cipher, Bit.Core.Enums.EventType.Cipher_Updated); @@ -179,8 +183,9 @@ public class CipherService : ICipherService } } - public async Task UploadFileForExistingAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachment) + public async Task UploadFileForExistingAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachment, DateTime? lastKnownRevisionDate = null) { + ValidateCipherLastKnownRevisionDate(cipher, lastKnownRevisionDate); if (attachment == null) { throw new BadRequestException("Cipher attachment does not exist"); @@ -195,8 +200,9 @@ public class CipherService : ICipherService } public async Task<(string attachmentId, string uploadUrl)> CreateAttachmentForDelayedUploadAsync(Cipher cipher, - string key, string fileName, long fileSize, bool adminRequest, Guid savingUserId) + string key, string fileName, long fileSize, bool adminRequest, Guid savingUserId, DateTime? lastKnownRevisionDate = null) { + ValidateCipherLastKnownRevisionDate(cipher, lastKnownRevisionDate); await ValidateCipherEditForAttachmentAsync(cipher, savingUserId, adminRequest, fileSize); var attachmentId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false); @@ -231,8 +237,9 @@ public class CipherService : ICipherService } public async Task CreateAttachmentAsync(Cipher cipher, Stream stream, string fileName, string key, - long requestLength, Guid savingUserId, bool orgAdmin = false) + long requestLength, Guid savingUserId, bool orgAdmin = false, DateTime? lastKnownRevisionDate = null) { + ValidateCipherLastKnownRevisionDate(cipher, lastKnownRevisionDate); await ValidateCipherEditForAttachmentAsync(cipher, savingUserId, orgAdmin, requestLength); var attachmentId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false); @@ -283,10 +290,11 @@ public class CipherService : ICipherService } public async Task CreateAttachmentShareAsync(Cipher cipher, Stream stream, string fileName, string key, - long requestLength, string attachmentId, Guid organizationId) + long requestLength, string attachmentId, Guid organizationId, DateTime? lastKnownRevisionDate = null) { try { + ValidateCipherLastKnownRevisionDate(cipher, lastKnownRevisionDate); if (requestLength < 1) { throw new BadRequestException("No data to attach."); @@ -539,6 +547,7 @@ public class CipherService : ICipherService try { await ValidateCipherCanBeShared(cipher, sharingUserId, organizationId, lastKnownRevisionDate); + await ValidateChangeInCollectionsAsync(cipher, collectionIds, sharingUserId); // Sproc will not save this UserId on the cipher. It is used limit scope of the collectionIds. cipher.UserId = sharingUserId; @@ -642,15 +651,7 @@ public class CipherService : ICipherService cipherIds.Add(cipher.Id); } - var useBulkResourceCreationService = _featureService.IsEnabled(FeatureFlagKeys.CipherRepositoryBulkResourceCreation); - if (useBulkResourceCreationService) - { - await _cipherRepository.UpdateCiphersAsync_vNext(sharingUserId, cipherInfos.Select(c => c.cipher)); - } - else - { - await _cipherRepository.UpdateCiphersAsync(sharingUserId, cipherInfos.Select(c => c.cipher)); - } + await _cipherRepository.UpdateCiphersAsync(sharingUserId, cipherInfos.Select(c => c.cipher)); await _collectionCipherRepository.UpdateCollectionsForCiphersAsync(cipherIds, sharingUserId, organizationId, collectionIds); @@ -678,6 +679,7 @@ public class CipherService : ICipherService { throw new BadRequestException("Cipher must belong to an organization."); } + await ValidateChangeInCollectionsAsync(cipher, collectionIds, savingUserId); cipher.RevisionDate = DateTime.UtcNow; @@ -718,6 +720,14 @@ public class CipherService : ICipherService cipherDetails.DeletedDate = cipherDetails.RevisionDate = DateTime.UtcNow; + if (cipherDetails.ArchivedDate.HasValue) + { + // If the cipher was archived, clear the archived date when soft deleting + // If a user were to restore an archived cipher, it should go back to the vault not the archive vault + cipherDetails.ArchivedDate = null; + } + + await _securityTaskRepository.MarkAsCompleteByCipherIds([cipherDetails.Id]); await _cipherRepository.UpsertAsync(cipherDetails); await _eventService.LogCipherEventAsync(cipherDetails, EventType.Cipher_SoftDeleted); @@ -744,6 +754,8 @@ public class CipherService : ICipherService await _cipherRepository.SoftDeleteAsync(deletingCiphers.Select(c => c.Id), deletingUserId); } + await _securityTaskRepository.MarkAsCompleteByCipherIds(deletingCiphers.Select(c => c.Id)); + var events = deletingCiphers.Select(c => new Tuple(c, EventType.Cipher_SoftDeleted, null)); foreach (var eventsBatch in events.Chunk(100)) @@ -820,6 +832,15 @@ public class CipherService : ICipherService return restoringCiphers; } + public async Task ValidateBulkCollectionAssignmentAsync(IEnumerable collectionIds, IEnumerable cipherIds, Guid userId) + { + foreach (var cipherId in cipherIds) + { + var cipher = await _cipherRepository.GetByIdAsync(cipherId); + await ValidateChangeInCollectionsAsync(cipher, collectionIds, userId); + } + } + private async Task UserCanEditAsync(Cipher cipher, Guid userId) { if (!cipher.OrganizationId.HasValue && cipher.UserId.HasValue && cipher.UserId.Value == userId) @@ -848,7 +869,7 @@ public class CipherService : ICipherService return NormalCipherPermissions.CanRestore(user, cipher, organizationAbility); } - private void ValidateCipherLastKnownRevisionDateAsync(Cipher cipher, DateTime? lastKnownRevisionDate) + private void ValidateCipherLastKnownRevisionDate(Cipher cipher, DateTime? lastKnownRevisionDate) { if (cipher.Id == default || !lastKnownRevisionDate.HasValue) { @@ -996,7 +1017,7 @@ public class CipherService : ICipherService throw new BadRequestException("Not enough storage available for this organization."); } - ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate); + ValidateCipherLastKnownRevisionDate(cipher, lastKnownRevisionDate); } private async Task ValidateViewPasswordUserAsync(Cipher cipher) @@ -1038,6 +1059,44 @@ public class CipherService : ICipherService } } + // Validates that a cipher is not being added to a default collection when it is only currently only in shared collections + private async Task ValidateChangeInCollectionsAsync(Cipher updatedCipher, IEnumerable newCollectionIds, Guid userId) + { + + if (updatedCipher.Id == Guid.Empty || !updatedCipher.OrganizationId.HasValue) + { + return; + } + + var currentCollectionsForCipher = await _collectionCipherRepository.GetManyByUserIdCipherIdAsync(userId, updatedCipher.Id); + + if (!currentCollectionsForCipher.Any()) + { + // When a cipher is not currently in any collections it can be assigned to any type of collection + return; + } + + var currentCollections = await _collectionRepository.GetManyByManyIdsAsync(currentCollectionsForCipher.Select(c => c.CollectionId)); + + var currentCollectionsContainDefault = currentCollections.Any(c => c.Type == CollectionType.DefaultUserCollection); + + // When the current cipher already contains the default collection, no check is needed for if they added or removed + // a default collection, because it is already there. + if (currentCollectionsContainDefault) + { + return; + } + + var newCollections = await _collectionRepository.GetManyByManyIdsAsync(newCollectionIds); + var newCollectionsContainDefault = newCollections.Any(c => c.Type == CollectionType.DefaultUserCollection); + + if (newCollectionsContainDefault) + { + // User is trying to add the default collection when the cipher is only in shared collections + throw new BadRequestException("The cipher(s) cannot be assigned to a default collection when only assigned to non-default collections."); + } + } + private string SerializeCipherData(CipherData data) { return data switch diff --git a/src/Events/Controllers/CollectController.cs b/src/Events/Controllers/CollectController.cs index d7fbbbc595..bae1575134 100644 --- a/src/Events/Controllers/CollectController.cs +++ b/src/Events/Controllers/CollectController.cs @@ -21,23 +21,17 @@ public class CollectController : Controller private readonly IEventService _eventService; private readonly ICipherRepository _cipherRepository; private readonly IOrganizationRepository _organizationRepository; - private readonly IFeatureService _featureService; - private readonly IApplicationCacheService _applicationCacheService; public CollectController( ICurrentContext currentContext, IEventService eventService, ICipherRepository cipherRepository, - IOrganizationRepository organizationRepository, - IFeatureService featureService, - IApplicationCacheService applicationCacheService) + IOrganizationRepository organizationRepository) { _currentContext = currentContext; _eventService = eventService; _cipherRepository = cipherRepository; _organizationRepository = organizationRepository; - _featureService = featureService; - _applicationCacheService = applicationCacheService; } [HttpPost] @@ -47,8 +41,10 @@ public class CollectController : Controller { return new BadRequestResult(); } + var cipherEvents = new List>(); var ciphersCache = new Dictionary(); + foreach (var eventModel in model) { switch (eventModel.Type) @@ -57,6 +53,7 @@ public class CollectController : Controller case EventType.User_ClientExportedVault: await _eventService.LogUserEventAsync(_currentContext.UserId.Value, eventModel.Type, eventModel.Date); break; + // Cipher events case EventType.Cipher_ClientAutofilled: case EventType.Cipher_ClientCopiedHiddenField: @@ -71,7 +68,8 @@ public class CollectController : Controller { continue; } - Cipher cipher = null; + + Cipher cipher; if (ciphersCache.TryGetValue(eventModel.CipherId.Value, out var cachedCipher)) { cipher = cachedCipher; @@ -81,6 +79,7 @@ public class CollectController : Controller cipher = await _cipherRepository.GetByIdAsync(eventModel.CipherId.Value, _currentContext.UserId.Value); } + if (cipher == null) { // When the user cannot access the cipher directly, check if the organization allows for @@ -91,29 +90,44 @@ public class CollectController : Controller } cipher = await _cipherRepository.GetByIdAsync(eventModel.CipherId.Value); + if (cipher == null) + { + continue; + } + var cipherBelongsToOrg = cipher.OrganizationId == eventModel.OrganizationId; var org = _currentContext.GetOrganization(eventModel.OrganizationId.Value); - if (!cipherBelongsToOrg || org == null || cipher == null) + if (!cipherBelongsToOrg || org == null) { continue; } } + ciphersCache.TryAdd(eventModel.CipherId.Value, cipher); cipherEvents.Add(new Tuple(cipher, eventModel.Type, eventModel.Date)); break; + case EventType.Organization_ClientExportedVault: if (!eventModel.OrganizationId.HasValue) { continue; } + var organization = await _organizationRepository.GetByIdAsync(eventModel.OrganizationId.Value); + if (organization == null) + { + continue; + } + await _eventService.LogOrganizationEventAsync(organization, eventModel.Type, eventModel.Date); break; + default: continue; } } + if (cipherEvents.Any()) { foreach (var eventsBatch in cipherEvents.Chunk(50)) @@ -121,6 +135,7 @@ public class CollectController : Controller await _eventService.LogCipherEventsAsync(eventsBatch); } } + return new OkResult(); } } diff --git a/src/Identity/IdentityServer/CustomValidatorRequestContext.cs b/src/Identity/IdentityServer/CustomValidatorRequestContext.cs index a709a47cb2..e16c8ad695 100644 --- a/src/Identity/IdentityServer/CustomValidatorRequestContext.cs +++ b/src/Identity/IdentityServer/CustomValidatorRequestContext.cs @@ -27,6 +27,12 @@ public class CustomValidatorRequestContext ///
public bool TwoFactorRequired { get; set; } = false; /// + /// Whether the user has requested recovery of their 2FA methods using their one-time + /// recovery code. + /// + /// + public bool TwoFactorRecoveryRequested { get; set; } = false; + /// /// This communicates whether or not SSO is required for the user to authenticate. /// public bool SsoRequired { get; set; } = false; @@ -42,10 +48,13 @@ public class CustomValidatorRequestContext /// This will be null if the authentication request is successful. ///
public Dictionary CustomResponse { get; set; } - /// /// A validated auth request /// /// public AuthRequest ValidatedAuthRequest { get; set; } + /// + /// Whether the user has requested a Remember Me token for their current device. + /// + public bool RememberMeRequested { get; set; } = false; } diff --git a/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs index e57ed1c85f..224c7a1866 100644 --- a/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs @@ -1,4 +1,5 @@ // FIXME: Update this file to be null safe and then delete the line below + #nullable disable using System.Security.Claims; @@ -14,6 +15,8 @@ using Bit.Core.Auth.Repositories; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; +using Bit.Core.KeyManagement.Models.Api.Response; +using Bit.Core.KeyManagement.Queries.Interfaces; using Bit.Core.Models.Api; using Bit.Core.Models.Api.Response; using Bit.Core.Repositories; @@ -45,6 +48,7 @@ public abstract class BaseRequestValidator where T : class protected IUserService _userService { get; } protected IUserDecryptionOptionsBuilder UserDecryptionOptionsBuilder { get; } protected IPolicyRequirementQuery PolicyRequirementQuery { get; } + protected IUserAccountKeysQuery _accountKeysQuery { get; } public BaseRequestValidator( UserManager userManager, @@ -63,8 +67,9 @@ public abstract class BaseRequestValidator where T : class IUserDecryptionOptionsBuilder userDecryptionOptionsBuilder, IPolicyRequirementQuery policyRequirementQuery, IAuthRequestRepository authRequestRepository, - IMailService mailService - ) + IMailService mailService, + IUserAccountKeysQuery userAccountKeysQuery + ) { _userManager = userManager; _userService = userService; @@ -83,130 +88,147 @@ public abstract class BaseRequestValidator where T : class PolicyRequirementQuery = policyRequirementQuery; _authRequestRepository = authRequestRepository; _mailService = mailService; + _accountKeysQuery = userAccountKeysQuery; } protected async Task ValidateAsync(T context, ValidatedTokenRequest request, CustomValidatorRequestContext validatorContext) { - // 1. We need to check if the user's master password hash is correct. - var valid = await ValidateContextAsync(context, validatorContext); - var user = validatorContext.User; - if (!valid) + if (FeatureService.IsEnabled(FeatureFlagKeys.RecoveryCodeSupportForSsoRequiredUsers)) { - await UpdateFailedAuthDetailsAsync(user); - - await BuildErrorResultAsync("Username or password is incorrect. Try again.", false, context, user); - return; - } - - // 2. Decide if this user belongs to an organization that requires SSO. - validatorContext.SsoRequired = await RequireSsoLoginAsync(user, request.GrantType); - if (validatorContext.SsoRequired) - { - SetSsoResult(context, - new Dictionary - { - { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") } - }); - return; - } - - // 3. Check if 2FA is required. - (validatorContext.TwoFactorRequired, var twoFactorOrganization) = - await _twoFactorAuthenticationValidator.RequiresTwoFactorAsync(user, request); - - // This flag is used to determine if the user wants a rememberMe token sent when - // authentication is successful. - var returnRememberMeToken = false; - - if (validatorContext.TwoFactorRequired) - { - var twoFactorToken = request.Raw["TwoFactorToken"]; - var twoFactorProvider = request.Raw["TwoFactorProvider"]; - var validTwoFactorRequest = !string.IsNullOrWhiteSpace(twoFactorToken) && - !string.IsNullOrWhiteSpace(twoFactorProvider); - - // 3a. Response for 2FA required and not provided state. - if (!validTwoFactorRequest || - !Enum.TryParse(twoFactorProvider, out TwoFactorProviderType twoFactorProviderType)) + var validators = DetermineValidationOrder(context, request, validatorContext); + var allValidationSchemesSuccessful = await ProcessValidatorsAsync(validators); + if (!allValidationSchemesSuccessful) { - var resultDict = await _twoFactorAuthenticationValidator - .BuildTwoFactorResultAsync(user, twoFactorOrganization); - if (resultDict == null) + // Each validation task is responsible for setting its own non-success status, if applicable. + return; + } + await BuildSuccessResultAsync(validatorContext.User, context, validatorContext.Device, + validatorContext.RememberMeRequested); + } + else + { + // 1. We need to check if the user's master password hash is correct. + var valid = await ValidateContextAsync(context, validatorContext); + var user = validatorContext.User; + if (!valid) + { + await UpdateFailedAuthDetailsAsync(user); + + await BuildErrorResultAsync("Username or password is incorrect. Try again.", false, context, user); + return; + } + + // 2. Decide if this user belongs to an organization that requires SSO. + validatorContext.SsoRequired = await RequireSsoLoginAsync(user, request.GrantType); + if (validatorContext.SsoRequired) + { + SetSsoResult(context, + new Dictionary + { + { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") } + }); + return; + } + + // 3. Check if 2FA is required. + (validatorContext.TwoFactorRequired, var twoFactorOrganization) = + await _twoFactorAuthenticationValidator.RequiresTwoFactorAsync(user, request); + + // This flag is used to determine if the user wants a rememberMe token sent when + // authentication is successful. + var returnRememberMeToken = false; + + if (validatorContext.TwoFactorRequired) + { + var twoFactorToken = request.Raw["TwoFactorToken"]; + var twoFactorProvider = request.Raw["TwoFactorProvider"]; + var validTwoFactorRequest = !string.IsNullOrWhiteSpace(twoFactorToken) && + !string.IsNullOrWhiteSpace(twoFactorProvider); + + // 3a. Response for 2FA required and not provided state. + if (!validTwoFactorRequest || + !Enum.TryParse(twoFactorProvider, out TwoFactorProviderType twoFactorProviderType)) { - await BuildErrorResultAsync("No two-step providers enabled.", false, context, user); + var resultDict = await _twoFactorAuthenticationValidator + .BuildTwoFactorResultAsync(user, twoFactorOrganization); + if (resultDict == null) + { + await BuildErrorResultAsync("No two-step providers enabled.", false, context, user); + return; + } + + // Include Master Password Policy in 2FA response. + resultDict.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(user)); + SetTwoFactorResult(context, resultDict); return; } - // Include Master Password Policy in 2FA response. - resultDict.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(user)); - SetTwoFactorResult(context, resultDict); + var twoFactorTokenValid = + await _twoFactorAuthenticationValidator + .VerifyTwoFactorAsync(user, twoFactorOrganization, twoFactorProviderType, twoFactorToken); + + // 3b. Response for 2FA required but request is not valid or remember token expired state. + if (!twoFactorTokenValid) + { + // The remember me token has expired. + if (twoFactorProviderType == TwoFactorProviderType.Remember) + { + var resultDict = await _twoFactorAuthenticationValidator + .BuildTwoFactorResultAsync(user, twoFactorOrganization); + + // Include Master Password Policy in 2FA response + resultDict.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(user)); + SetTwoFactorResult(context, resultDict); + } + else + { + await SendFailedTwoFactorEmail(user, twoFactorProviderType); + await UpdateFailedAuthDetailsAsync(user); + await BuildErrorResultAsync("Two-step token is invalid. Try again.", true, context, user); + } + + return; + } + + // 3c. When the 2FA authentication is successful, we can check if the user wants a + // rememberMe token. + var twoFactorRemember = request.Raw["TwoFactorRemember"] == "1"; + // Check if the user wants a rememberMe token. + if (twoFactorRemember + // if the 2FA auth was rememberMe do not send another token. + && twoFactorProviderType != TwoFactorProviderType.Remember) + { + returnRememberMeToken = true; + } + } + + // 4. Check if the user is logging in from a new device. + var deviceValid = await _deviceValidator.ValidateRequestDeviceAsync(request, validatorContext); + if (!deviceValid) + { + SetValidationErrorResult(context, validatorContext); + await LogFailedLoginEvent(validatorContext.User, EventType.User_FailedLogIn); return; } - var twoFactorTokenValid = - await _twoFactorAuthenticationValidator - .VerifyTwoFactorAsync(user, twoFactorOrganization, twoFactorProviderType, twoFactorToken); - - // 3b. Response for 2FA required but request is not valid or remember token expired state. - if (!twoFactorTokenValid) + // 5. Force legacy users to the web for migration. + if (UserService.IsLegacyUser(user) && request.ClientId != "web") { - // The remember me token has expired. - if (twoFactorProviderType == TwoFactorProviderType.Remember) - { - var resultDict = await _twoFactorAuthenticationValidator - .BuildTwoFactorResultAsync(user, twoFactorOrganization); - - // Include Master Password Policy in 2FA response - resultDict.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(user)); - SetTwoFactorResult(context, resultDict); - } - else - { - await SendFailedTwoFactorEmail(user, twoFactorProviderType); - await UpdateFailedAuthDetailsAsync(user); - await BuildErrorResultAsync("Two-step token is invalid. Try again.", true, context, user); - } + await FailAuthForLegacyUserAsync(user, context); return; } - // 3c. When the 2FA authentication is successful, we can check if the user wants a - // rememberMe token. - var twoFactorRemember = request.Raw["TwoFactorRemember"] == "1"; - // Check if the user wants a rememberMe token. - if (twoFactorRemember - // if the 2FA auth was rememberMe do not send another token. - && twoFactorProviderType != TwoFactorProviderType.Remember) + // TODO: PM-24324 - This should be its own validator at some point. + // 6. Auth request handling + if (validatorContext.ValidatedAuthRequest != null) { - returnRememberMeToken = true; + validatorContext.ValidatedAuthRequest.AuthenticationDate = DateTime.UtcNow; + await _authRequestRepository.ReplaceAsync(validatorContext.ValidatedAuthRequest); } - } - // 4. Check if the user is logging in from a new device. - var deviceValid = await _deviceValidator.ValidateRequestDeviceAsync(request, validatorContext); - if (!deviceValid) - { - SetValidationErrorResult(context, validatorContext); - await LogFailedLoginEvent(validatorContext.User, EventType.User_FailedLogIn); - return; + await BuildSuccessResultAsync(user, context, validatorContext.Device, returnRememberMeToken); } - - // 5. Force legacy users to the web for migration. - if (UserService.IsLegacyUser(user) && request.ClientId != "web") - { - await FailAuthForLegacyUserAsync(user, context); - return; - } - - // TODO: PM-24324 - This should be its own validator at some point. - // 6. Auth request handling - if (validatorContext.ValidatedAuthRequest != null) - { - validatorContext.ValidatedAuthRequest.AuthenticationDate = DateTime.UtcNow; - await _authRequestRepository.ReplaceAsync(validatorContext.ValidatedAuthRequest); - } - - await BuildSuccessResultAsync(user, context, validatorContext.Device, returnRememberMeToken); } protected async Task FailAuthForLegacyUserAsync(User user, T context) @@ -218,6 +240,302 @@ public abstract class BaseRequestValidator where T : class protected abstract Task ValidateContextAsync(T context, CustomValidatorRequestContext validatorContext); + /// + /// Composer for validation schemes. + /// + /// The current request context. + /// + /// + /// A composed array of validation scheme delegates to evaluate in order. + private Func>[] DetermineValidationOrder(T context, ValidatedTokenRequest request, + CustomValidatorRequestContext validatorContext) + { + if (RecoveryCodeRequestForSsoRequiredUserScenario()) + { + // Support valid requests to recover 2FA (with account code) for users who require SSO + // by organization membership. + // This requires an evaluation of 2FA validity in front of SSO, and an opportunity for the 2FA + // validation to perform the recovery as part of scheme validation based on the request. + return + [ + () => ValidateMasterPasswordAsync(context, validatorContext), + () => ValidateTwoFactorAsync(context, request, validatorContext), + () => ValidateSsoAsync(context, request, validatorContext), + () => ValidateNewDeviceAsync(context, request, validatorContext), + () => ValidateLegacyMigrationAsync(context, request, validatorContext), + () => ValidateAuthRequestAsync(validatorContext) + ]; + } + else + { + // The typical validation scenario. + return + [ + () => ValidateMasterPasswordAsync(context, validatorContext), + () => ValidateSsoAsync(context, request, validatorContext), + () => ValidateTwoFactorAsync(context, request, validatorContext), + () => ValidateNewDeviceAsync(context, request, validatorContext), + () => ValidateLegacyMigrationAsync(context, request, validatorContext), + () => ValidateAuthRequestAsync(validatorContext) + ]; + } + + bool RecoveryCodeRequestForSsoRequiredUserScenario() + { + var twoFactorProvider = request.Raw["TwoFactorProvider"]; + var twoFactorToken = request.Raw["TwoFactorToken"]; + + // Both provider and token must be present; + // Validity of the token for a given provider will be evaluated by the TwoFactorAuthenticationValidator. + if (string.IsNullOrWhiteSpace(twoFactorProvider) || string.IsNullOrWhiteSpace(twoFactorToken)) + { + return false; + } + + if (!int.TryParse(twoFactorProvider, out var providerValue)) + { + return false; + } + + return providerValue == (int)TwoFactorProviderType.RecoveryCode; + } + } + + /// + /// Processes the validation schemes sequentially. + /// Each validator is responsible for setting error context responses on failure and adding itself to the + /// validatorContext's CompletedValidationSchemes (only) on success. + /// Failure of any scheme to validate will short-circuit the collection, causing the validation error to be + /// returned and further schemes to not be evaluated. + /// + /// The collection of validation schemes as composed in + /// true if all schemes validated successfully, false if any failed. + private static async Task ProcessValidatorsAsync(params Func>[] validators) + { + foreach (var validator in validators) + { + if (!await validator()) + { + return false; + } + } + + return true; + } + + /// + /// Validates the user's Master Password hash. + /// + /// The current request context. + /// + /// true if the scheme successfully passed validation, otherwise false. + private async Task ValidateMasterPasswordAsync(T context, CustomValidatorRequestContext validatorContext) + { + var valid = await ValidateContextAsync(context, validatorContext); + var user = validatorContext.User; + if (valid) + { + return true; + } + + await UpdateFailedAuthDetailsAsync(user); + + await BuildErrorResultAsync("Username or password is incorrect. Try again.", false, context, user); + return false; + } + + /// + /// Validates the user's organization-enforced Single Sign-on (SSO) requirement. + /// + /// The current request context. + /// + /// + /// true if the scheme successfully passed validation, otherwise false. + /// + private async Task ValidateSsoAsync(T context, ValidatedTokenRequest request, + CustomValidatorRequestContext validatorContext) + { + validatorContext.SsoRequired = await RequireSsoLoginAsync(validatorContext.User, request.GrantType); + if (!validatorContext.SsoRequired) + { + return true; + } + + // Users without SSO requirement requesting 2FA recovery will be fast-forwarded through login and are + // presented with their 2FA management area as a reminder to re-evaluate their 2FA posture after recovery and + // review their new recovery token if desired. + // SSO users cannot be assumed to be authenticated, and must prove authentication with their IdP after recovery. + // As described in validation order determination, if TwoFactorRequired, the 2FA validation scheme will have been + // evaluated, and recovery will have been performed if requested. + // We will send a descriptive message in these cases so clients can give the appropriate feedback and redirect + // to /login. + if (validatorContext.TwoFactorRequired && + validatorContext.TwoFactorRecoveryRequested) + { + SetSsoResult(context, new Dictionary + { + { "ErrorModel", new ErrorResponseModel("Two-factor recovery has been performed. SSO authentication is required.") } + }); + return false; + } + + SetSsoResult(context, + new Dictionary + { + { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") } + }); + return false; + } + + /// + /// Validates the user's Multi-Factor Authentication (2FA) scheme. + /// + /// The current request context. + /// + /// + /// true if the scheme successfully passed validation, otherwise false. + private async Task ValidateTwoFactorAsync(T context, ValidatedTokenRequest request, + CustomValidatorRequestContext validatorContext) + { + (validatorContext.TwoFactorRequired, var twoFactorOrganization) = + await _twoFactorAuthenticationValidator.RequiresTwoFactorAsync(validatorContext.User, request); + + if (!validatorContext.TwoFactorRequired) + { + return true; + } + + var twoFactorToken = request.Raw["TwoFactorToken"]; + var twoFactorProvider = request.Raw["TwoFactorProvider"]; + var validTwoFactorRequest = !string.IsNullOrWhiteSpace(twoFactorToken) && + !string.IsNullOrWhiteSpace(twoFactorProvider); + + // 3a. Response for 2FA required and not provided state. + if (!validTwoFactorRequest || + !Enum.TryParse(twoFactorProvider, out TwoFactorProviderType twoFactorProviderType)) + { + var resultDict = await _twoFactorAuthenticationValidator + .BuildTwoFactorResultAsync(validatorContext.User, twoFactorOrganization); + if (resultDict == null) + { + await BuildErrorResultAsync("No two-step providers enabled.", false, context, validatorContext.User); + return false; + } + + // Include Master Password Policy in 2FA response. + resultDict.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(validatorContext.User)); + SetTwoFactorResult(context, resultDict); + return false; + } + + var twoFactorTokenValid = + await _twoFactorAuthenticationValidator + .VerifyTwoFactorAsync(validatorContext.User, twoFactorOrganization, twoFactorProviderType, + twoFactorToken); + + // 3b. Response for 2FA required but request is not valid or remember token expired state. + if (!twoFactorTokenValid) + { + // The remember me token has expired. + if (twoFactorProviderType == TwoFactorProviderType.Remember) + { + var resultDict = await _twoFactorAuthenticationValidator + .BuildTwoFactorResultAsync(validatorContext.User, twoFactorOrganization); + + // Include Master Password Policy in 2FA response + resultDict.Add("MasterPasswordPolicy", await GetMasterPasswordPolicyAsync(validatorContext.User)); + SetTwoFactorResult(context, resultDict); + } + else + { + await SendFailedTwoFactorEmail(validatorContext.User, twoFactorProviderType); + await UpdateFailedAuthDetailsAsync(validatorContext.User); + await BuildErrorResultAsync("Two-step token is invalid. Try again.", true, context, + validatorContext.User); + } + + return false; + } + + // 3c. Given a valid token and a successful two-factor verification, if the provider type is Recovery Code, + // recovery will have been performed as part of 2FA validation. This will be relevant for, e.g., SSO users + // who are requesting recovery, but who will still need to log in after 2FA recovery. + if (twoFactorProviderType == TwoFactorProviderType.RecoveryCode) + { + validatorContext.TwoFactorRecoveryRequested = true; + } + + // 3d. When the 2FA authentication is successful, we can check if the user wants a + // rememberMe token. + var twoFactorRemember = request.Raw["TwoFactorRemember"] == "1"; + // Check if the user wants a rememberMe token. + if (twoFactorRemember + // if the 2FA auth was rememberMe do not send another token. + && twoFactorProviderType != TwoFactorProviderType.Remember) + { + validatorContext.RememberMeRequested = true; + } + + return true; + } + + /// + /// Validates whether the user is logging in from a known device. + /// + /// The current request context. + /// + /// + /// true if the scheme successfully passed validation, otherwise false. + private async Task ValidateNewDeviceAsync(T context, ValidatedTokenRequest request, + CustomValidatorRequestContext validatorContext) + { + var deviceValid = await _deviceValidator.ValidateRequestDeviceAsync(request, validatorContext); + if (deviceValid) + { + return true; + } + + SetValidationErrorResult(context, validatorContext); + await LogFailedLoginEvent(validatorContext.User, EventType.User_FailedLogIn); + return false; + } + + /// + /// Validates whether the user should be denied access on a given non-Web client and sent to the Web client + /// for Legacy migration. + /// + /// The current request context. + /// + /// + /// true if the scheme successfully passed validation, otherwise false. + private async Task ValidateLegacyMigrationAsync(T context, ValidatedTokenRequest request, + CustomValidatorRequestContext validatorContext) + { + if (!UserService.IsLegacyUser(validatorContext.User) || request.ClientId == "web") + { + return true; + } + + await FailAuthForLegacyUserAsync(validatorContext.User, context); + return false; + } + + /// + /// Validates and updates the auth request's timestamp. + /// + /// + /// true on evaluation and/or completed update of the AuthRequest. + private async Task ValidateAuthRequestAsync(CustomValidatorRequestContext validatorContext) + { + // TODO: PM-24324 - This should be its own validator at some point. + if (validatorContext.ValidatedAuthRequest != null) + { + validatorContext.ValidatedAuthRequest.AuthenticationDate = DateTime.UtcNow; + await _authRequestRepository.ReplaceAsync(validatorContext.ValidatedAuthRequest); + } + + return true; + } /// /// Responsible for building the response to the client when the user has successfully authenticated. @@ -251,7 +569,7 @@ public abstract class BaseRequestValidator where T : class /// used to associate the failed login with a user /// void [Obsolete("Consider using SetValidationErrorResult to set the validation result, and LogFailedLoginEvent " + - "to log the failure.")] + "to log the failure.")] protected async Task BuildErrorResultAsync(string message, bool twoFactorRequest, T context, User user) { if (user != null) @@ -263,8 +581,8 @@ public abstract class BaseRequestValidator where T : class if (_globalSettings.SelfHosted) { _logger.LogWarning(Constants.BypassFiltersEventId, - string.Format("Failed login attempt{0}{1}", twoFactorRequest ? ", 2FA invalid." : ".", - $" {CurrentContext.IpAddress}")); + "Failed login attempt. Is2FARequest: {Is2FARequest} IpAddress: {IpAddress}", twoFactorRequest, + CurrentContext.IpAddress); } await Task.Delay(2000); // Delay for brute force. @@ -288,21 +606,26 @@ public abstract class BaseRequestValidator where T : class formattedMessage = string.Format("Failed login attempt. {0}", $" {CurrentContext.IpAddress}"); break; case EventType.User_FailedLogIn2fa: - formattedMessage = string.Format("Failed login attempt, 2FA invalid.{0}", $" {CurrentContext.IpAddress}"); + formattedMessage = string.Format("Failed login attempt, 2FA invalid.{0}", + $" {CurrentContext.IpAddress}"); break; default: formattedMessage = "Failed login attempt."; break; } - _logger.LogWarning(Constants.BypassFiltersEventId, formattedMessage); + + _logger.LogWarning(Constants.BypassFiltersEventId, "{FailedLoginMessage}", formattedMessage); } + await Task.Delay(2000); // Delay for brute force. } [Obsolete("Consider using SetValidationErrorResult instead.")] protected abstract void SetTwoFactorResult(T context, Dictionary customResponse); + [Obsolete("Consider using SetValidationErrorResult instead.")] protected abstract void SetSsoResult(T context, Dictionary customResponse); + [Obsolete("Consider using SetValidationErrorResult instead.")] protected abstract void SetErrorResult(T context, Dictionary customResponse); @@ -313,6 +636,7 @@ public abstract class BaseRequestValidator where T : class /// The current grant or token context /// The modified request context containing material used to build the response object protected abstract void SetValidationErrorResult(T context, CustomValidatorRequestContext requestContext); + protected abstract Task SetSuccessResult(T context, User user, List claims, Dictionary customResponse); @@ -339,7 +663,7 @@ public abstract class BaseRequestValidator where T : class // Check if user belongs to any organization with an active SSO policy var ssoRequired = FeatureService.IsEnabled(FeatureFlagKeys.PolicyRequirements) ? (await PolicyRequirementQuery.GetAsync(user.Id)) - .SsoRequired + .SsoRequired : await PolicyService.AnyPoliciesApplicableToUserAsync( user.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed); if (ssoRequired) @@ -381,7 +705,8 @@ public abstract class BaseRequestValidator where T : class { if (FeatureService.IsEnabled(FeatureFlagKeys.FailedTwoFactorEmail)) { - await _mailService.SendFailedTwoFactorAttemptEmailAsync(user.Email, failedAttemptType, DateTime.UtcNow, CurrentContext.IpAddress); + await _mailService.SendFailedTwoFactorAttemptEmailAsync(user.Email, failedAttemptType, DateTime.UtcNow, + CurrentContext.IpAddress); } } @@ -412,16 +737,14 @@ public abstract class BaseRequestValidator where T : class // We need this because we check for changes in the stamp to determine if we need to invalidate token refresh requests, // in the `ProfileService.IsActiveAsync` method. // If we don't store the security stamp in the persisted grant, we won't have the previous value to compare against. - var claims = new List - { - new Claim(Claims.SecurityStamp, user.SecurityStamp) - }; + var claims = new List { new Claim(Claims.SecurityStamp, user.SecurityStamp) }; if (device != null) { claims.Add(new Claim(Claims.Device, device.Identifier)); claims.Add(new Claim(Claims.DeviceType, device.Type.ToString())); } + return claims; } @@ -433,12 +756,15 @@ public abstract class BaseRequestValidator where T : class /// The current request context. /// The device used for authentication. /// Whether to send a 2FA remember token. - private async Task> BuildCustomResponse(User user, T context, Device device, bool sendRememberToken) + private async Task> BuildCustomResponse(User user, T context, Device device, + bool sendRememberToken) { var customResponse = new Dictionary(); if (!string.IsNullOrWhiteSpace(user.PrivateKey)) { customResponse.Add("PrivateKey", user.PrivateKey); + var accountKeys = await _accountKeysQuery.Run(user); + customResponse.Add("AccountKeys", new PrivateKeysResponseModel(accountKeys)); } if (!string.IsNullOrWhiteSpace(user.Key)) @@ -453,7 +779,8 @@ public abstract class BaseRequestValidator where T : class customResponse.Add("KdfIterations", user.KdfIterations); customResponse.Add("KdfMemory", user.KdfMemory); customResponse.Add("KdfParallelism", user.KdfParallelism); - customResponse.Add("UserDecryptionOptions", await CreateUserDecryptionOptionsAsync(user, device, GetSubject(context))); + customResponse.Add("UserDecryptionOptions", + await CreateUserDecryptionOptionsAsync(user, device, GetSubject(context))); if (sendRememberToken) { @@ -461,6 +788,7 @@ public abstract class BaseRequestValidator where T : class CoreHelpers.CustomProviderName(TwoFactorProviderType.Remember)); customResponse.Add("TwoFactorToken", token); } + return customResponse; } @@ -468,7 +796,8 @@ public abstract class BaseRequestValidator where T : class /// /// Used to create a list of all possible ways the newly authenticated user can decrypt their vault contents /// - private async Task CreateUserDecryptionOptionsAsync(User user, Device device, ClaimsPrincipal subject) + private async Task CreateUserDecryptionOptionsAsync(User user, Device device, + ClaimsPrincipal subject) { var ssoConfig = await GetSsoConfigurationDataAsync(subject); return await UserDecryptionOptionsBuilder diff --git a/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs index 1495973b80..64156ea5f3 100644 --- a/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs @@ -8,6 +8,7 @@ using Bit.Core.Auth.Models.Api.Response; using Bit.Core.Auth.Repositories; using Bit.Core.Context; using Bit.Core.Entities; +using Bit.Core.KeyManagement.Queries.Interfaces; using Bit.Core.Platform.Installations; using Bit.Core.Repositories; using Bit.Core.Services; @@ -47,7 +48,8 @@ public class CustomTokenRequestValidator : BaseRequestValidator otpTokenProvider, IMailService mailService) : ISendAuthenticationMethodValidator { @@ -60,11 +62,20 @@ public class SendEmailOtpRequestValidator( { return BuildErrorResult(SendAccessConstants.EmailOtpValidatorResults.OtpGenerationFailed); } - - await mailService.SendSendEmailOtpEmailAsync( - email, - token, - string.Format(SendAccessConstants.OtpEmail.Subject, token)); + if (featureService.IsEnabled(FeatureFlagKeys.MJMLBasedEmailTemplates)) + { + await mailService.SendSendEmailOtpEmailv2Async( + email, + token, + string.Format(SendAccessConstants.OtpEmail.Subject, token)); + } + else + { + await mailService.SendSendEmailOtpEmailAsync( + email, + token, + string.Format(SendAccessConstants.OtpEmail.Subject, token)); + } return BuildErrorResult(SendAccessConstants.EmailOtpValidatorResults.EmailOtpSent); } diff --git a/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs b/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs index e679c48433..294df1c18d 100644 --- a/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs @@ -12,6 +12,7 @@ using Bit.Core.Auth.Repositories; using Bit.Core.Auth.UserFeatures.WebAuthnLogin; using Bit.Core.Context; using Bit.Core.Entities; +using Bit.Core.KeyManagement.Queries.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; @@ -50,7 +51,8 @@ public class WebAuthnGrantValidator : BaseRequestValidator endpoints.MapDefaultControllerRoute()); // Log startup - logger.LogInformation(Constants.BypassFiltersEventId, globalSettings.ProjectName + " started."); + logger.LogInformation(Constants.BypassFiltersEventId, "{Project} started.", globalSettings.ProjectName); } } diff --git a/src/Identity/Utilities/LoginApprovingClientTypes.cs b/src/Identity/Utilities/LoginApprovingClientTypes.cs index f0c7b831b7..28049ed16b 100644 --- a/src/Identity/Utilities/LoginApprovingClientTypes.cs +++ b/src/Identity/Utilities/LoginApprovingClientTypes.cs @@ -1,6 +1,4 @@ -using Bit.Core; -using Bit.Core.Enums; -using Bit.Core.Services; +using Bit.Core.Enums; namespace Bit.Identity.Utilities; @@ -11,28 +9,15 @@ public interface ILoginApprovingClientTypes public class LoginApprovingClientTypes : ILoginApprovingClientTypes { - public LoginApprovingClientTypes( - IFeatureService featureService) + public LoginApprovingClientTypes() { - if (featureService.IsEnabled(FeatureFlagKeys.BrowserExtensionLoginApproval)) + TypesThatCanApprove = new List { - TypesThatCanApprove = new List - { - ClientType.Desktop, - ClientType.Mobile, - ClientType.Web, - ClientType.Browser, - }; - } - else - { - TypesThatCanApprove = new List - { - ClientType.Desktop, - ClientType.Mobile, - ClientType.Web, - }; - } + ClientType.Desktop, + ClientType.Mobile, + ClientType.Web, + ClientType.Browser, + }; } public IReadOnlyCollection TypesThatCanApprove { get; } diff --git a/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs b/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs index 5a743ba028..2be33e8846 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs +++ b/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs @@ -218,6 +218,8 @@ public static class BulkResourceCreationService ciphersTable.Columns.Add(revisionDateColumn); var deletedDateColumn = new DataColumn(nameof(c.DeletedDate), typeof(DateTime)); ciphersTable.Columns.Add(deletedDateColumn); + var archivedDateColumn = new DataColumn(nameof(c.ArchivedDate), typeof(DateTime)); + ciphersTable.Columns.Add(archivedDateColumn); var repromptColumn = new DataColumn(nameof(c.Reprompt), typeof(short)); ciphersTable.Columns.Add(repromptColumn); var keyColummn = new DataColumn(nameof(c.Key), typeof(string)); @@ -247,6 +249,7 @@ public static class BulkResourceCreationService row[creationDateColumn] = cipher.CreationDate; row[revisionDateColumn] = cipher.RevisionDate; row[deletedDateColumn] = cipher.DeletedDate.HasValue ? (object)cipher.DeletedDate : DBNull.Value; + row[archivedDateColumn] = cipher.ArchivedDate.HasValue ? cipher.ArchivedDate : DBNull.Value; row[repromptColumn] = cipher.Reprompt.HasValue ? cipher.Reprompt.Value : DBNull.Value; row[keyColummn] = cipher.Key; diff --git a/src/Infrastructure.Dapper/AdminConsole/Repositories/EventRepository.cs b/src/Infrastructure.Dapper/AdminConsole/Repositories/EventRepository.cs index b034f31f39..2ddc5679d5 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Repositories/EventRepository.cs +++ b/src/Infrastructure.Dapper/AdminConsole/Repositories/EventRepository.cs @@ -230,6 +230,8 @@ public class EventRepository : Repository, IEventRepository eventsTable.Columns.Add(serviceAccountIdColumn); var projectIdColumn = new DataColumn(nameof(e.ProjectId), typeof(Guid)); eventsTable.Columns.Add(projectIdColumn); + var grantedServiceAccountIdColumn = new DataColumn(nameof(e.GrantedServiceAccountId), typeof(Guid)); + eventsTable.Columns.Add(grantedServiceAccountIdColumn); foreach (DataColumn col in eventsTable.Columns) { @@ -263,6 +265,7 @@ public class EventRepository : Repository, IEventRepository row[secretIdColumn] = ev.SecretId.HasValue ? ev.SecretId.Value : DBNull.Value; row[serviceAccountIdColumn] = ev.ServiceAccountId.HasValue ? ev.ServiceAccountId.Value : DBNull.Value; row[projectIdColumn] = ev.ProjectId.HasValue ? ev.ProjectId.Value : DBNull.Value; + row[grantedServiceAccountIdColumn] = ev.GrantedServiceAccountId.HasValue ? ev.GrantedServiceAccountId.Value : DBNull.Value; eventsTable.Rows.Add(row); } diff --git a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationIntegrationRepository.cs b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationIntegrationRepository.cs index ece9697a31..4f8fb979d3 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationIntegrationRepository.cs +++ b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationIntegrationRepository.cs @@ -29,4 +29,17 @@ public class OrganizationIntegrationRepository : Repository GetByTeamsConfigurationTenantIdTeamId(string tenantId, string teamId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + var result = await connection.QuerySingleOrDefaultAsync( + "[dbo].[OrganizationIntegration_ReadByTeamsConfigurationTenantIdTeamId]", + new { TenantId = tenantId, TeamId = teamId }, + commandType: CommandType.StoredProcedure); + + return result; + } + } } diff --git a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs index 5f389ae56d..ed5708844d 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs +++ b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs @@ -15,8 +15,6 @@ using Dapper; using Microsoft.Data.SqlClient; using Microsoft.Extensions.Logging; -#nullable enable - namespace Bit.Infrastructure.Dapper.Repositories; public class OrganizationUserRepository : Repository, IOrganizationUserRepository @@ -672,4 +670,21 @@ public class OrganizationUserRepository : Repository, IO }, commandType: CommandType.StoredProcedure); } + + public async Task ConfirmOrganizationUserAsync(OrganizationUser organizationUser) + { + await using var connection = new SqlConnection(_marsConnectionString); + + var rowCount = await connection.ExecuteScalarAsync( + $"[{Schema}].[OrganizationUser_ConfirmById]", + new + { + organizationUser.Id, + organizationUser.UserId, + RevisionDate = DateTime.UtcNow.Date, + Key = organizationUser.Key + }); + + return rowCount > 0; + } } diff --git a/src/Infrastructure.Dapper/AdminConsole/Repositories/PolicyRepository.cs b/src/Infrastructure.Dapper/AdminConsole/Repositories/PolicyRepository.cs index 83d5ef6a70..865c4f8e5c 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Repositories/PolicyRepository.cs +++ b/src/Infrastructure.Dapper/AdminConsole/Repositories/PolicyRepository.cs @@ -61,19 +61,6 @@ public class PolicyRepository : Repository, IPolicyRepository } } - public async Task> GetPolicyDetailsByUserId(Guid userId) - { - using (var connection = new SqlConnection(ConnectionString)) - { - var results = await connection.QueryAsync( - $"[{Schema}].[PolicyDetails_ReadByUserId]", - new { UserId = userId }, - commandType: CommandType.StoredProcedure); - - return results.ToList(); - } - } - public async Task> GetPolicyDetailsByUserIdsAndPolicyType(IEnumerable userIds, PolicyType type) { await using var connection = new SqlConnection(ConnectionString); diff --git a/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs b/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs index 35fc094973..445ff77109 100644 --- a/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs +++ b/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs @@ -71,6 +71,7 @@ public static class DapperServiceCollectionExtensions services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); diff --git a/src/Infrastructure.Dapper/Dirt/OrganizationReportRepository.cs b/src/Infrastructure.Dapper/Dirt/OrganizationReportRepository.cs index 3d001cce92..c704a208d1 100644 --- a/src/Infrastructure.Dapper/Dirt/OrganizationReportRepository.cs +++ b/src/Infrastructure.Dapper/Dirt/OrganizationReportRepository.cs @@ -4,6 +4,7 @@ using System.Data; using Bit.Core.Dirt.Entities; using Bit.Core.Dirt.Models.Data; +using Bit.Core.Dirt.Reports.Models.Data; using Bit.Core.Dirt.Repositories; using Bit.Core.Settings; using Bit.Infrastructure.Dapper.Repositories; @@ -173,4 +174,31 @@ public class OrganizationReportRepository : Repository commandType: CommandType.StoredProcedure); } } + + public async Task UpdateMetricsAsync(Guid reportId, OrganizationReportMetricsData metrics) + { + using var connection = new SqlConnection(ConnectionString); + var parameters = new + { + Id = reportId, + ApplicationCount = metrics.ApplicationCount, + ApplicationAtRiskCount = metrics.ApplicationAtRiskCount, + CriticalApplicationCount = metrics.CriticalApplicationCount, + CriticalApplicationAtRiskCount = metrics.CriticalApplicationAtRiskCount, + MemberCount = metrics.MemberCount, + MemberAtRiskCount = metrics.MemberAtRiskCount, + CriticalMemberCount = metrics.CriticalMemberCount, + CriticalMemberAtRiskCount = metrics.CriticalMemberAtRiskCount, + PasswordCount = metrics.PasswordCount, + PasswordAtRiskCount = metrics.PasswordAtRiskCount, + CriticalPasswordCount = metrics.CriticalPasswordCount, + CriticalPasswordAtRiskCount = metrics.CriticalPasswordAtRiskCount, + RevisionDate = DateTime.UtcNow + }; + + await connection.ExecuteAsync( + $"[{Schema}].[OrganizationReport_UpdateMetrics]", + parameters, + commandType: CommandType.StoredProcedure); + } } diff --git a/src/Infrastructure.Dapper/KeyManagement/Repositories/UserSignatureKeyPairRepository.cs b/src/Infrastructure.Dapper/KeyManagement/Repositories/UserSignatureKeyPairRepository.cs new file mode 100644 index 0000000000..5dcc2943b8 --- /dev/null +++ b/src/Infrastructure.Dapper/KeyManagement/Repositories/UserSignatureKeyPairRepository.cs @@ -0,0 +1,79 @@ +using System.Data; +using Bit.Core.KeyManagement.Entities; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Repositories; +using Bit.Core.KeyManagement.UserKey; +using Bit.Core.Settings; +using Bit.Core.Utilities; +using Bit.Infrastructure.Dapper.Repositories; +using Dapper; +using Microsoft.Data.SqlClient; + +namespace Bit.Infrastructure.Dapper.KeyManagement.Repositories; + +public class UserSignatureKeyPairRepository : Repository, IUserSignatureKeyPairRepository +{ + public UserSignatureKeyPairRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { + } + + public UserSignatureKeyPairRepository(string connectionString, string readOnlyConnectionString) : base( + connectionString, readOnlyConnectionString) + { + } + + public async Task GetByUserIdAsync(Guid userId) + { + using (var connection = new SqlConnection(ConnectionString)) + { + return (await connection.QuerySingleOrDefaultAsync( + "[dbo].[UserSignatureKeyPair_ReadByUserId]", + new + { + UserId = userId + }, + commandType: CommandType.StoredProcedure))?.ToSignatureKeyPairData(); + } + } + + public UpdateEncryptedDataForKeyRotation SetUserSignatureKeyPair(Guid userId, SignatureKeyPairData signingKeys) + { + return async (SqlConnection connection, SqlTransaction transaction) => + { + await connection.QueryAsync( + "[dbo].[UserSignatureKeyPair_SetForRotation]", + new + { + Id = CoreHelpers.GenerateComb(), + UserId = userId, + SignatureAlgorithm = (byte)signingKeys.SignatureAlgorithm, + SigningKey = signingKeys.WrappedSigningKey, + VerifyingKey = signingKeys.VerifyingKey, + CreationDate = DateTime.UtcNow, + RevisionDate = DateTime.UtcNow + }, + commandType: CommandType.StoredProcedure, + transaction: transaction); + }; + } + + public UpdateEncryptedDataForKeyRotation UpdateForKeyRotation(Guid grantorId, SignatureKeyPairData signingKeys) + { + return async (SqlConnection connection, SqlTransaction transaction) => + { + await connection.QueryAsync( + "[dbo].[UserSignatureKeyPair_UpdateForRotation]", + new + { + UserId = grantorId, + SignatureAlgorithm = (byte)signingKeys.SignatureAlgorithm, + SigningKey = signingKeys.WrappedSigningKey, + VerifyingKey = signingKeys.VerifyingKey, + RevisionDate = DateTime.UtcNow + }, + commandType: CommandType.StoredProcedure, + transaction: transaction); + }; + } +} diff --git a/src/Infrastructure.Dapper/Vault/Repositories/CipherRepository.cs b/src/Infrastructure.Dapper/Vault/Repositories/CipherRepository.cs index 4904574eee..48232ef484 100644 --- a/src/Infrastructure.Dapper/Vault/Repositories/CipherRepository.cs +++ b/src/Infrastructure.Dapper/Vault/Repositories/CipherRepository.cs @@ -13,7 +13,6 @@ using Bit.Core.Vault.Models.Data; using Bit.Core.Vault.Repositories; using Bit.Infrastructure.Dapper.AdminConsole.Helpers; using Bit.Infrastructure.Dapper.Repositories; -using Bit.Infrastructure.Dapper.Vault.Helpers; using Dapper; using Microsoft.Data.SqlClient; @@ -383,63 +382,6 @@ public class CipherRepository : Repository, ICipherRepository cmd.ExecuteNonQuery(); } - // Bulk copy data into temp table - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) - { - bulkCopy.DestinationTableName = "#TempCipher"; - var ciphersTable = ciphers.ToDataTable(); - foreach (DataColumn col in ciphersTable.Columns) - { - bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); - } - - ciphersTable.PrimaryKey = new DataColumn[] { ciphersTable.Columns[0] }; - await bulkCopy.WriteToServerAsync(ciphersTable); - } - - // Update cipher table from temp table - var sql = @" - UPDATE - [dbo].[Cipher] - SET - [Data] = TC.[Data], - [Attachments] = TC.[Attachments], - [RevisionDate] = TC.[RevisionDate], - [Key] = TC.[Key] - FROM - [dbo].[Cipher] C - INNER JOIN - #TempCipher TC ON C.Id = TC.Id - WHERE - C.[UserId] = @UserId - - DROP TABLE #TempCipher"; - - await using (var cmd = new SqlCommand(sql, connection, transaction)) - { - cmd.Parameters.Add("@UserId", SqlDbType.UniqueIdentifier).Value = userId; - cmd.ExecuteNonQuery(); - } - }; - } - - /// - public UpdateEncryptedDataForKeyRotation UpdateForKeyRotation_vNext( - Guid userId, IEnumerable ciphers) - { - return async (SqlConnection connection, SqlTransaction transaction) => - { - // Create temp table - var sqlCreateTemp = @" - SELECT TOP 0 * - INTO #TempCipher - FROM [dbo].[Cipher]"; - - await using (var cmd = new SqlCommand(sqlCreateTemp, connection, transaction)) - { - cmd.ExecuteNonQuery(); - } - // Bulk copy data into temp table await BulkResourceCreationService.CreateTempCiphersAsync(connection, transaction, ciphers); @@ -476,88 +418,6 @@ public class CipherRepository : Repository, ICipherRepository return; } - using (var connection = new SqlConnection(ConnectionString)) - { - connection.Open(); - - using (var transaction = connection.BeginTransaction()) - { - try - { - // 1. Create temp tables to bulk copy into. - - var sqlCreateTemp = @" - SELECT TOP 0 * - INTO #TempCipher - FROM [dbo].[Cipher]"; - - using (var cmd = new SqlCommand(sqlCreateTemp, connection, transaction)) - { - cmd.ExecuteNonQuery(); - } - - // 2. Bulk copy into temp tables. - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) - { - bulkCopy.DestinationTableName = "#TempCipher"; - var dataTable = BuildCiphersTable(bulkCopy, ciphers); - bulkCopy.WriteToServer(dataTable); - } - - // 3. Insert into real tables from temp tables and clean up. - - // Intentionally not including Favorites, Folders, and CreationDate - // since those are not meant to be bulk updated at this time - var sql = @" - UPDATE - [dbo].[Cipher] - SET - [UserId] = TC.[UserId], - [OrganizationId] = TC.[OrganizationId], - [Type] = TC.[Type], - [Data] = TC.[Data], - [Attachments] = TC.[Attachments], - [RevisionDate] = TC.[RevisionDate], - [DeletedDate] = TC.[DeletedDate], - [Key] = TC.[Key] - FROM - [dbo].[Cipher] C - INNER JOIN - #TempCipher TC ON C.Id = TC.Id - WHERE - C.[UserId] = @UserId - - DROP TABLE #TempCipher"; - - using (var cmd = new SqlCommand(sql, connection, transaction)) - { - cmd.Parameters.Add("@UserId", SqlDbType.UniqueIdentifier).Value = userId; - cmd.ExecuteNonQuery(); - } - - await connection.ExecuteAsync( - $"[{Schema}].[User_BumpAccountRevisionDate]", - new { Id = userId }, - commandType: CommandType.StoredProcedure, transaction: transaction); - - transaction.Commit(); - } - catch - { - transaction.Rollback(); - throw; - } - } - } - } - - public async Task UpdateCiphersAsync_vNext(Guid userId, IEnumerable ciphers) - { - if (!ciphers.Any()) - { - return; - } - using (var connection = new SqlConnection(ConnectionString)) { connection.Open(); @@ -635,54 +495,6 @@ public class CipherRepository : Repository, ICipherRepository return; } - using (var connection = new SqlConnection(ConnectionString)) - { - connection.Open(); - - using (var transaction = connection.BeginTransaction()) - { - try - { - if (folders.Any()) - { - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) - { - bulkCopy.DestinationTableName = "[dbo].[Folder]"; - var dataTable = BuildFoldersTable(bulkCopy, folders); - bulkCopy.WriteToServer(dataTable); - } - } - - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) - { - bulkCopy.DestinationTableName = "[dbo].[Cipher]"; - var dataTable = BuildCiphersTable(bulkCopy, ciphers); - bulkCopy.WriteToServer(dataTable); - } - - await connection.ExecuteAsync( - $"[{Schema}].[User_BumpAccountRevisionDate]", - new { Id = userId }, - commandType: CommandType.StoredProcedure, transaction: transaction); - - transaction.Commit(); - } - catch - { - transaction.Rollback(); - throw; - } - } - } - } - - public async Task CreateAsync_vNext(Guid userId, IEnumerable ciphers, IEnumerable folders) - { - if (!ciphers.Any()) - { - return; - } - using (var connection = new SqlConnection(ConnectionString)) { connection.Open(); @@ -722,75 +534,6 @@ public class CipherRepository : Repository, ICipherRepository return; } - using (var connection = new SqlConnection(ConnectionString)) - { - connection.Open(); - - using (var transaction = connection.BeginTransaction()) - { - try - { - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) - { - bulkCopy.DestinationTableName = "[dbo].[Cipher]"; - var dataTable = BuildCiphersTable(bulkCopy, ciphers); - bulkCopy.WriteToServer(dataTable); - } - - if (collections.Any()) - { - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) - { - bulkCopy.DestinationTableName = "[dbo].[Collection]"; - var dataTable = BuildCollectionsTable(bulkCopy, collections); - bulkCopy.WriteToServer(dataTable); - } - } - - if (collectionCiphers.Any()) - { - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) - { - bulkCopy.DestinationTableName = "[dbo].[CollectionCipher]"; - var dataTable = BuildCollectionCiphersTable(bulkCopy, collectionCiphers); - bulkCopy.WriteToServer(dataTable); - } - } - - if (collectionUsers.Any()) - { - using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) - { - bulkCopy.DestinationTableName = "[dbo].[CollectionUser]"; - var dataTable = BuildCollectionUsersTable(bulkCopy, collectionUsers); - bulkCopy.WriteToServer(dataTable); - } - } - - await connection.ExecuteAsync( - $"[{Schema}].[User_BumpAccountRevisionDateByOrganizationId]", - new { OrganizationId = ciphers.First().OrganizationId }, - commandType: CommandType.StoredProcedure, transaction: transaction); - - transaction.Commit(); - } - catch - { - transaction.Rollback(); - throw; - } - } - } - } - - public async Task CreateAsync_vNext(IEnumerable ciphers, IEnumerable collections, - IEnumerable collectionCiphers, IEnumerable collectionUsers) - { - if (!ciphers.Any()) - { - return; - } - using (var connection = new SqlConnection(ConnectionString)) { connection.Open(); diff --git a/src/Infrastructure.Dapper/Vault/Repositories/SecurityTaskRepository.cs b/src/Infrastructure.Dapper/Vault/Repositories/SecurityTaskRepository.cs index 292e99d6ad..869321f280 100644 --- a/src/Infrastructure.Dapper/Vault/Repositories/SecurityTaskRepository.cs +++ b/src/Infrastructure.Dapper/Vault/Repositories/SecurityTaskRepository.cs @@ -85,4 +85,19 @@ public class SecurityTaskRepository : Repository, ISecurityT return tasksList; } + + /// + public async Task MarkAsCompleteByCipherIds(IEnumerable cipherIds) + { + if (!cipherIds.Any()) + { + return; + } + + await using var connection = new SqlConnection(ConnectionString); + await connection.ExecuteAsync( + $"[{Schema}].[SecurityTask_MarkCompleteByCipherIds]", + new { CipherIds = cipherIds.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); + } } diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Configurations/EventEntityTypeConfiguration.cs b/src/Infrastructure.EntityFramework/AdminConsole/Configurations/EventEntityTypeConfiguration.cs index 76e9b2e912..98f10394f4 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Configurations/EventEntityTypeConfiguration.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Configurations/EventEntityTypeConfiguration.cs @@ -12,9 +12,16 @@ public class EventEntityTypeConfiguration : IEntityTypeConfiguration .Property(e => e.Id) .ValueGeneratedNever(); - builder - .HasIndex(e => new { e.Date, e.OrganizationId, e.ActingUserId, e.CipherId }) - .IsClustered(false); + builder.HasKey(e => e.Id) + .IsClustered(); + + var index = builder.HasIndex(e => new { e.Date, e.OrganizationId, e.ActingUserId, e.CipherId }) + .IsClustered(false) + .HasDatabaseName("IX_Event_DateOrganizationIdUserId"); + + SqlServerIndexBuilderExtensions.IncludeProperties( + index, + e => new { e.ServiceAccountId, e.GrantedServiceAccountId }); builder.ToTable(nameof(Event)); } diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationIntegrationRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationIntegrationRepository.cs index 5670b2ae9b..c11591efcd 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationIntegrationRepository.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationIntegrationRepository.cs @@ -26,4 +26,16 @@ public class OrganizationIntegrationRepository : return await query.Run(dbContext).ToListAsync(); } } + + public async Task GetByTeamsConfigurationTenantIdTeamId( + string tenantId, + string teamId) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + var query = new OrganizationIntegrationReadByTeamsConfigurationTenantIdTeamIdQuery(tenantId: tenantId, teamId: teamId); + return await query.Run(dbContext).SingleOrDefaultAsync(); + } + } } diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs index 200c4aa308..ebc2bc6606 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs @@ -112,7 +112,8 @@ public class OrganizationRepository : Repository GetOccupiedSeatCountByOrganizationIdAsync(Guid organizationId) { using (var scope = ServiceScopeFactory.CreateScope()) diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationUserRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationUserRepository.cs index fae0598c1c..e5016a20d4 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationUserRepository.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationUserRepository.cs @@ -942,4 +942,25 @@ public class OrganizationUserRepository : Repository ConfirmOrganizationUserAsync(Core.Entities.OrganizationUser organizationUser) + { + using var scope = ServiceScopeFactory.CreateScope(); + await using var dbContext = GetDatabaseContext(scope); + + var result = await dbContext.OrganizationUsers + .Where(ou => ou.Id == organizationUser.Id && ou.Status == OrganizationUserStatusType.Accepted) + .ExecuteUpdateAsync(x => x + .SetProperty(y => y.Status, OrganizationUserStatusType.Confirmed) + .SetProperty(y => y.Key, organizationUser.Key)); + + if (result <= 0) + { + return false; + } + + await dbContext.UserBumpAccountRevisionDateByOrganizationUserIdAsync(organizationUser.Id); + return true; + + } } diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/PolicyRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/PolicyRepository.cs index 72c277f1d7..1cca7a9bbb 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/PolicyRepository.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/PolicyRepository.cs @@ -56,45 +56,6 @@ public class PolicyRepository : Repository> GetPolicyDetailsByUserId(Guid userId) - { - using var scope = ServiceScopeFactory.CreateScope(); - var dbContext = GetDatabaseContext(scope); - - var providerOrganizations = from pu in dbContext.ProviderUsers - where pu.UserId == userId - join po in dbContext.ProviderOrganizations - on pu.ProviderId equals po.ProviderId - select po; - - var query = from p in dbContext.Policies - join ou in dbContext.OrganizationUsers - on p.OrganizationId equals ou.OrganizationId - join o in dbContext.Organizations - on p.OrganizationId equals o.Id - where - p.Enabled && - o.Enabled && - o.UsePolicies && - ( - (ou.Status != OrganizationUserStatusType.Invited && ou.UserId == userId) || - // Invited orgUsers do not have a UserId associated with them, so we have to match up their email - (ou.Status == OrganizationUserStatusType.Invited && ou.Email == dbContext.Users.Find(userId).Email) - ) - select new PolicyDetails - { - OrganizationUserId = ou.Id, - OrganizationId = p.OrganizationId, - PolicyType = p.Type, - PolicyData = p.Data, - OrganizationUserType = ou.Type, - OrganizationUserStatus = ou.Status, - OrganizationUserPermissionsData = ou.Permissions, - IsProvider = providerOrganizations.Any(po => po.OrganizationId == p.OrganizationId) - }; - return await query.ToListAsync(); - } - public async Task> GetPolicyDetailsByOrganizationIdAsync(Guid organizationId, PolicyType policyType) { using var scope = ServiceScopeFactory.CreateScope(); diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByOrganizationIdServiceAccountIdQuery.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByOrganizationIdServiceAccountIdQuery.cs index 01f3a1fe14..72dc8db386 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByOrganizationIdServiceAccountIdQuery.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByOrganizationIdServiceAccountIdQuery.cs @@ -30,7 +30,7 @@ public class EventReadPageByOrganizationIdServiceAccountIdQuery : IQuery (_beforeDate != null || e.Date <= _endDate) && (_beforeDate == null || e.Date < _beforeDate.Value) && e.OrganizationId == _organizationId && - e.ServiceAccountId == _serviceAccountId + (e.ServiceAccountId == _serviceAccountId || e.GrantedServiceAccountId == _serviceAccountId) orderby e.Date descending select e; return q.Skip(0).Take(_pageOptions.PageSize); diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByServiceAccountIdQuery.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByServiceAccountIdQuery.cs new file mode 100644 index 0000000000..0d1cd6a656 --- /dev/null +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/EventReadPageByServiceAccountIdQuery.cs @@ -0,0 +1,48 @@ +using Bit.Core.Models.Data; +using Bit.Core.SecretsManager.Entities; +using Event = Bit.Infrastructure.EntityFramework.Models.Event; + +namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; + +public class EventReadPageByServiceAccountQuery : IQuery +{ + private readonly ServiceAccount _serviceAccount; + private readonly DateTime _startDate; + private readonly DateTime _endDate; + private readonly DateTime? _beforeDate; + private readonly PageOptions _pageOptions; + + public EventReadPageByServiceAccountQuery(ServiceAccount serviceAccount, DateTime startDate, DateTime endDate, PageOptions pageOptions) + { + _serviceAccount = serviceAccount; + _startDate = startDate; + _endDate = endDate; + _beforeDate = null; + _pageOptions = pageOptions; + } + + public EventReadPageByServiceAccountQuery(ServiceAccount serviceAccount, DateTime startDate, DateTime endDate, DateTime? beforeDate, PageOptions pageOptions) + { + _serviceAccount = serviceAccount; + _startDate = startDate; + _endDate = endDate; + _beforeDate = beforeDate; + _pageOptions = pageOptions; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var q = from e in dbContext.Events + where e.Date >= _startDate && + (_beforeDate == null || e.Date < _beforeDate.Value) && + ( + (_serviceAccount.OrganizationId == Guid.Empty && !e.OrganizationId.HasValue) || + (_serviceAccount.OrganizationId != Guid.Empty && e.OrganizationId == _serviceAccount.OrganizationId) + ) && + e.GrantedServiceAccountId == _serviceAccount.Id + orderby e.Date descending + select e; + + return q.Take(_pageOptions.PageSize); + } +} diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationIntegrationReadByTeamsConfigurationTenantIdTeamIdQuery.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationIntegrationReadByTeamsConfigurationTenantIdTeamIdQuery.cs new file mode 100644 index 0000000000..a1e86d9add --- /dev/null +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationIntegrationReadByTeamsConfigurationTenantIdTeamIdQuery.cs @@ -0,0 +1,36 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Enums; +using Bit.Infrastructure.EntityFramework.Repositories; +using Bit.Infrastructure.EntityFramework.Repositories.Queries; + +namespace Bit.Infrastructure.EntityFramework.AdminConsole.Repositories.Queries; + +public class OrganizationIntegrationReadByTeamsConfigurationTenantIdTeamIdQuery : IQuery +{ + private readonly string _tenantId; + private readonly string _teamId; + + public OrganizationIntegrationReadByTeamsConfigurationTenantIdTeamIdQuery(string tenantId, string teamId) + { + _tenantId = tenantId; + _teamId = teamId; + } + + public IQueryable Run(DatabaseContext dbContext) + { + var query = + from oi in dbContext.OrganizationIntegrations + where oi.Type == IntegrationType.Teams && + oi.Configuration != null && + oi.Configuration.Contains($"\"TenantId\":\"{_tenantId}\"") && + oi.Configuration.Contains($"\"id\":\"{_teamId}\"") + select new OrganizationIntegration() + { + Id = oi.Id, + OrganizationId = oi.OrganizationId, + Type = oi.Type, + Configuration = oi.Configuration, + }; + return query; + } +} diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationUserOrganizationDetailsViewQuery.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationUserOrganizationDetailsViewQuery.cs index 26d3a128fc..504a75c9f2 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationUserOrganizationDetailsViewQuery.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationUserOrganizationDetailsViewQuery.cs @@ -73,7 +73,8 @@ public class OrganizationUserOrganizationDetailsViewQuery : IQuery new ProviderUserOrganizationDetails { OrganizationId = x.po.OrganizationId, @@ -29,6 +31,9 @@ public class ProviderUserOrganizationDetailsViewQuery : IQuery(updatedReport); } } + + public Task UpdateMetricsAsync(Guid reportId, OrganizationReportMetricsData metrics) + { + using (var scope = ServiceScopeFactory.CreateScope()) + { + var dbContext = GetDatabaseContext(scope); + + return dbContext.OrganizationReports + .Where(p => p.Id == reportId) + .UpdateAsync(p => new Models.OrganizationReport + { + ApplicationCount = metrics.ApplicationCount, + ApplicationAtRiskCount = metrics.ApplicationAtRiskCount, + CriticalApplicationCount = metrics.CriticalApplicationCount, + CriticalApplicationAtRiskCount = metrics.CriticalApplicationAtRiskCount, + MemberCount = metrics.MemberCount, + MemberAtRiskCount = metrics.MemberAtRiskCount, + CriticalMemberCount = metrics.CriticalMemberCount, + CriticalMemberAtRiskCount = metrics.CriticalMemberAtRiskCount, + PasswordCount = metrics.PasswordCount, + PasswordAtRiskCount = metrics.PasswordAtRiskCount, + CriticalPasswordCount = metrics.CriticalPasswordCount, + CriticalPasswordAtRiskCount = metrics.CriticalPasswordAtRiskCount, + RevisionDate = DateTime.UtcNow + }); + } + } } diff --git a/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs b/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs index 7a6507230e..3c35df2a82 100644 --- a/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs +++ b/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs @@ -108,6 +108,7 @@ public static class EntityFrameworkServiceCollectionExtensions services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); diff --git a/src/Infrastructure.EntityFramework/KeyManagement/Configurations/UserSignatureKeyPairEntityTypeConfiguration.cs b/src/Infrastructure.EntityFramework/KeyManagement/Configurations/UserSignatureKeyPairEntityTypeConfiguration.cs new file mode 100644 index 0000000000..aa10a73a88 --- /dev/null +++ b/src/Infrastructure.EntityFramework/KeyManagement/Configurations/UserSignatureKeyPairEntityTypeConfiguration.cs @@ -0,0 +1,22 @@ +using Bit.Infrastructure.EntityFramework.Models; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Metadata.Builders; + +namespace Bit.Infrastructure.EntityFramework.Configurations; + +public class UserSignatureKeyPairEntityTypeConfiguration : IEntityTypeConfiguration +{ + public void Configure(EntityTypeBuilder builder) + { + builder + .Property(s => s.Id) + .ValueGeneratedNever(); + + builder + .HasIndex(s => s.UserId) + .IsUnique() + .IsClustered(false); + + builder.ToTable(nameof(UserSignatureKeyPair)); + } +} diff --git a/src/Infrastructure.EntityFramework/KeyManagement/Models/UserSignatureKeyPair.cs b/src/Infrastructure.EntityFramework/KeyManagement/Models/UserSignatureKeyPair.cs new file mode 100644 index 0000000000..b2bd8a1345 --- /dev/null +++ b/src/Infrastructure.EntityFramework/KeyManagement/Models/UserSignatureKeyPair.cs @@ -0,0 +1,19 @@ +// FIXME: Update this file to be null safe and then delete the line below +#nullable disable + +using AutoMapper; + +namespace Bit.Infrastructure.EntityFramework.Models; + +public class UserSignatureKeyPair : Core.KeyManagement.Entities.UserSignatureKeyPair +{ + public virtual User User { get; set; } +} + +public class UserSignatureKeyPairMapperProfile : Profile +{ + public UserSignatureKeyPairMapperProfile() + { + CreateMap().ReverseMap(); + } +} diff --git a/src/Infrastructure.EntityFramework/KeyManagement/Repositories/UserSignatureKeyPairRepository.cs b/src/Infrastructure.EntityFramework/KeyManagement/Repositories/UserSignatureKeyPairRepository.cs new file mode 100644 index 0000000000..04f055501d --- /dev/null +++ b/src/Infrastructure.EntityFramework/KeyManagement/Repositories/UserSignatureKeyPairRepository.cs @@ -0,0 +1,66 @@ + +using AutoMapper; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Repositories; +using Bit.Core.KeyManagement.UserKey; +using Bit.Core.Utilities; +using Bit.Infrastructure.EntityFramework.Repositories; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; + +namespace Bit.Infrastructure.EntityFramework.KeyManagement.Repositories; + +public class UserSignatureKeyPairRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) : Repository(serviceScopeFactory, mapper, context => context.UserSignatureKeyPairs), IUserSignatureKeyPairRepository +{ + public async Task GetByUserIdAsync(Guid userId) + { + await using var scope = ServiceScopeFactory.CreateAsyncScope(); + var dbContext = GetDatabaseContext(scope); + var signingKeys = await dbContext.UserSignatureKeyPairs.FirstOrDefaultAsync(x => x.UserId == userId); + if (signingKeys == null) + { + return null; + } + + return signingKeys.ToSignatureKeyPairData(); + } + + public UpdateEncryptedDataForKeyRotation SetUserSignatureKeyPair(Guid userId, SignatureKeyPairData signingKeys) + { + return async (_, _) => + { + await using var scope = ServiceScopeFactory.CreateAsyncScope(); + var dbContext = GetDatabaseContext(scope); + var entity = new Models.UserSignatureKeyPair + { + Id = CoreHelpers.GenerateComb(), + UserId = userId, + SignatureAlgorithm = signingKeys.SignatureAlgorithm, + SigningKey = signingKeys.WrappedSigningKey, + VerifyingKey = signingKeys.VerifyingKey, + CreationDate = DateTime.UtcNow, + RevisionDate = DateTime.UtcNow, + }; + await dbContext.UserSignatureKeyPairs.AddAsync(entity); + await dbContext.SaveChangesAsync(); + }; + } + + public UpdateEncryptedDataForKeyRotation UpdateForKeyRotation(Guid grantorId, SignatureKeyPairData signingKeys) + { + return async (_, _) => + { + await using var scope = ServiceScopeFactory.CreateAsyncScope(); + var dbContext = GetDatabaseContext(scope); + var entity = await dbContext.UserSignatureKeyPairs.FirstOrDefaultAsync(x => x.UserId == grantorId); + if (entity != null) + { + entity.SignatureAlgorithm = signingKeys.SignatureAlgorithm; + entity.SigningKey = signingKeys.WrappedSigningKey; + entity.VerifyingKey = signingKeys.VerifyingKey; + entity.RevisionDate = DateTime.UtcNow; + await dbContext.SaveChangesAsync(); + } + }; + } +} diff --git a/src/Infrastructure.EntityFramework/Repositories/CollectionCipherRepository.cs b/src/Infrastructure.EntityFramework/Repositories/CollectionCipherRepository.cs index 6e2805f987..39e3ab8019 100644 --- a/src/Infrastructure.EntityFramework/Repositories/CollectionCipherRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/CollectionCipherRepository.cs @@ -1,4 +1,5 @@ using AutoMapper; +using Bit.Core.Enums; using Bit.Core.Repositories; using Bit.Infrastructure.EntityFramework.Repositories.Queries; using Microsoft.EntityFrameworkCore; @@ -145,9 +146,11 @@ public class CollectionCipherRepository : BaseEntityFrameworkRepository, ICollec using (var scope = ServiceScopeFactory.CreateScope()) { var dbContext = GetDatabaseContext(scope); - var availableCollections = await (from c in dbContext.Collections - where c.OrganizationId == organizationId - select c).ToListAsync(); + + var availableCollectionIds = await (from c in dbContext.Collections + where c.OrganizationId == organizationId + && c.Type != CollectionType.DefaultUserCollection + select c.Id).ToListAsync(); var currentCollectionCiphers = await (from cc in dbContext.CollectionCiphers where cc.CipherId == cipherId @@ -155,6 +158,8 @@ public class CollectionCipherRepository : BaseEntityFrameworkRepository, ICollec foreach (var requestedCollectionId in collectionIds) { + if (!availableCollectionIds.Contains(requestedCollectionId)) continue; + var requestedCollectionCipher = currentCollectionCiphers .FirstOrDefault(cc => cc.CollectionId == requestedCollectionId); @@ -168,7 +173,7 @@ public class CollectionCipherRepository : BaseEntityFrameworkRepository, ICollec } } - dbContext.RemoveRange(currentCollectionCiphers.Where(cc => !collectionIds.Contains(cc.CollectionId))); + dbContext.RemoveRange(currentCollectionCiphers.Where(cc => availableCollectionIds.Contains(cc.CollectionId) && !collectionIds.Contains(cc.CollectionId))); await dbContext.UserBumpAccountRevisionDateByOrganizationIdAsync(organizationId); await dbContext.SaveChangesAsync(); } diff --git a/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs b/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs index 7446abdd97..b748a26db2 100644 --- a/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs +++ b/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs @@ -63,6 +63,7 @@ public class DatabaseContext : DbContext public DbSet Policies { get; set; } public DbSet Providers { get; set; } public DbSet Secret { get; set; } + public DbSet SecretVersion { get; set; } public DbSet ServiceAccount { get; set; } public DbSet Project { get; set; } public DbSet ProviderUsers { get; set; } @@ -73,6 +74,7 @@ public class DatabaseContext : DbContext public DbSet TaxRates { get; set; } public DbSet Transactions { get; set; } public DbSet Users { get; set; } + public DbSet UserSignatureKeyPairs { get; set; } public DbSet AuthRequests { get; set; } public DbSet OrganizationDomains { get; set; } public DbSet WebAuthnCredentials { get; set; } diff --git a/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs b/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs index bd70e27e78..809704edb7 100644 --- a/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/UserRepository.cs @@ -283,6 +283,9 @@ public class UserRepository : Repository, IUserR var transaction = await dbContext.Database.BeginTransactionAsync(); + MigrateDefaultUserCollectionsToShared(dbContext, [user.Id]); + await dbContext.SaveChangesAsync(); + dbContext.WebAuthnCredentials.RemoveRange(dbContext.WebAuthnCredentials.Where(w => w.UserId == user.Id)); dbContext.Ciphers.RemoveRange(dbContext.Ciphers.Where(c => c.UserId == user.Id)); dbContext.Folders.RemoveRange(dbContext.Folders.Where(f => f.UserId == user.Id)); @@ -314,8 +317,8 @@ public class UserRepository : Repository, IUserR var mappedUser = Mapper.Map(user); dbContext.Users.Remove(mappedUser); - await transaction.CommitAsync(); await dbContext.SaveChangesAsync(); + await transaction.CommitAsync(); } } @@ -329,21 +332,30 @@ public class UserRepository : Repository, IUserR var targetIds = users.Select(u => u.Id).ToList(); + MigrateDefaultUserCollectionsToShared(dbContext, targetIds); + await dbContext.SaveChangesAsync(); + await dbContext.WebAuthnCredentials.Where(wa => targetIds.Contains(wa.UserId)).ExecuteDeleteAsync(); await dbContext.Ciphers.Where(c => targetIds.Contains(c.UserId ?? default)).ExecuteDeleteAsync(); await dbContext.Folders.Where(f => targetIds.Contains(f.UserId)).ExecuteDeleteAsync(); await dbContext.AuthRequests.Where(a => targetIds.Contains(a.UserId)).ExecuteDeleteAsync(); await dbContext.Devices.Where(d => targetIds.Contains(d.UserId)).ExecuteDeleteAsync(); - var collectionUsers = from cu in dbContext.CollectionUsers - join ou in dbContext.OrganizationUsers on cu.OrganizationUserId equals ou.Id - where targetIds.Contains(ou.UserId ?? default) - select cu; - dbContext.CollectionUsers.RemoveRange(collectionUsers); - var groupUsers = from gu in dbContext.GroupUsers - join ou in dbContext.OrganizationUsers on gu.OrganizationUserId equals ou.Id - where targetIds.Contains(ou.UserId ?? default) - select gu; - dbContext.GroupUsers.RemoveRange(groupUsers); + await dbContext.CollectionUsers + .Join(dbContext.OrganizationUsers, + cu => cu.OrganizationUserId, + ou => ou.Id, + (cu, ou) => new { CollectionUser = cu, OrganizationUser = ou }) + .Where((joined) => targetIds.Contains(joined.OrganizationUser.UserId ?? default)) + .Select(joined => joined.CollectionUser) + .ExecuteDeleteAsync(); + await dbContext.GroupUsers + .Join(dbContext.OrganizationUsers, + gu => gu.OrganizationUserId, + ou => ou.Id, + (gu, ou) => new { GroupUser = gu, OrganizationUser = ou }) + .Where(joined => targetIds.Contains(joined.OrganizationUser.UserId ?? default)) + .Select(joined => joined.GroupUser) + .ExecuteDeleteAsync(); await dbContext.UserProjectAccessPolicy.Where(ap => targetIds.Contains(ap.OrganizationUser.UserId ?? default)).ExecuteDeleteAsync(); await dbContext.UserServiceAccountAccessPolicy.Where(ap => targetIds.Contains(ap.OrganizationUser.UserId ?? default)).ExecuteDeleteAsync(); await dbContext.OrganizationUsers.Where(ou => targetIds.Contains(ou.UserId ?? default)).ExecuteDeleteAsync(); @@ -354,15 +366,29 @@ public class UserRepository : Repository, IUserR await dbContext.NotificationStatuses.Where(ns => targetIds.Contains(ns.UserId)).ExecuteDeleteAsync(); await dbContext.Notifications.Where(n => targetIds.Contains(n.UserId ?? default)).ExecuteDeleteAsync(); - foreach (var u in users) - { - var mappedUser = Mapper.Map(u); - dbContext.Users.Remove(mappedUser); - } + await dbContext.Users.Where(u => targetIds.Contains(u.Id)).ExecuteDeleteAsync(); - - await transaction.CommitAsync(); await dbContext.SaveChangesAsync(); + await transaction.CommitAsync(); + } + } + + private static void MigrateDefaultUserCollectionsToShared(DatabaseContext dbContext, IEnumerable userIds) + { + var defaultCollections = (from c in dbContext.Collections + join cu in dbContext.CollectionUsers on c.Id equals cu.CollectionId + join ou in dbContext.OrganizationUsers on cu.OrganizationUserId equals ou.Id + join u in dbContext.Users on ou.UserId equals u.Id + where userIds.Contains(ou.UserId!.Value) + && c.Type == Core.Enums.CollectionType.DefaultUserCollection + select new { Collection = c, UserEmail = u.Email }) + .ToList(); + + foreach (var item in defaultCollections) + { + item.Collection.Type = Core.Enums.CollectionType.SharedCollection; + item.Collection.DefaultUserCollectionEmail = item.Collection.DefaultUserCollectionEmail ?? item.UserEmail; + item.Collection.RevisionDate = DateTime.UtcNow; } } } diff --git a/src/Infrastructure.EntityFramework/SecretsManager/Configurations/SecretVersionEntityTypeConfiguration.cs b/src/Infrastructure.EntityFramework/SecretsManager/Configurations/SecretVersionEntityTypeConfiguration.cs new file mode 100644 index 0000000000..069c7e2450 --- /dev/null +++ b/src/Infrastructure.EntityFramework/SecretsManager/Configurations/SecretVersionEntityTypeConfiguration.cs @@ -0,0 +1,42 @@ +using Bit.Infrastructure.EntityFramework.SecretsManager.Models; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Metadata.Builders; + +namespace Bit.Infrastructure.EntityFramework.SecretsManager.Configurations; + +public class SecretVersionEntityTypeConfiguration : IEntityTypeConfiguration +{ + public void Configure(EntityTypeBuilder builder) + { + builder.Property(sv => sv.Id) + .ValueGeneratedNever(); + + builder.HasKey(sv => sv.Id) + .IsClustered(); + + builder.Property(sv => sv.Value) + .IsRequired(); + + builder.Property(sv => sv.VersionDate) + .IsRequired(); + + builder.HasOne(sv => sv.EditorServiceAccount) + .WithMany() + .HasForeignKey(sv => sv.EditorServiceAccountId) + .OnDelete(DeleteBehavior.SetNull); + + builder.HasOne(sv => sv.EditorOrganizationUser) + .WithMany() + .HasForeignKey(sv => sv.EditorOrganizationUserId) + .OnDelete(DeleteBehavior.SetNull); + + builder.HasIndex(sv => sv.SecretId) + .HasDatabaseName("IX_SecretVersion_SecretId"); + + builder.HasIndex(sv => sv.EditorServiceAccountId) + .HasDatabaseName("IX_SecretVersion_EditorServiceAccountId"); + + builder.HasIndex(sv => sv.EditorOrganizationUserId) + .HasDatabaseName("IX_SecretVersion_EditorOrganizationUserId"); + } +} diff --git a/src/Infrastructure.EntityFramework/SecretsManager/Models/Secret.cs b/src/Infrastructure.EntityFramework/SecretsManager/Models/Secret.cs index 5992f32135..09d8c389df 100644 --- a/src/Infrastructure.EntityFramework/SecretsManager/Models/Secret.cs +++ b/src/Infrastructure.EntityFramework/SecretsManager/Models/Secret.cs @@ -13,6 +13,7 @@ public class Secret : Core.SecretsManager.Entities.Secret public virtual ICollection UserAccessPolicies { get; set; } public virtual ICollection GroupAccessPolicies { get; set; } public virtual ICollection ServiceAccountAccessPolicies { get; set; } + public virtual ICollection SecretVersions { get; set; } } public class SecretMapperProfile : Profile diff --git a/src/Infrastructure.EntityFramework/SecretsManager/Models/SecretVersion.cs b/src/Infrastructure.EntityFramework/SecretsManager/Models/SecretVersion.cs new file mode 100644 index 0000000000..d4a364ab0f --- /dev/null +++ b/src/Infrastructure.EntityFramework/SecretsManager/Models/SecretVersion.cs @@ -0,0 +1,24 @@ +#nullable enable + +using AutoMapper; + +namespace Bit.Infrastructure.EntityFramework.SecretsManager.Models; + +public class SecretVersion : Core.SecretsManager.Entities.SecretVersion +{ + public Secret? Secret { get; set; } + + public ServiceAccount? EditorServiceAccount { get; set; } + + public Bit.Infrastructure.EntityFramework.Models.OrganizationUser? EditorOrganizationUser { get; set; } +} + +public class SecretVersionMapperProfile : Profile +{ + public SecretVersionMapperProfile() + { + CreateMap() + .PreserveReferences() + .ReverseMap(); + } +} diff --git a/src/Infrastructure.EntityFramework/Vault/Repositories/CipherRepository.cs b/src/Infrastructure.EntityFramework/Vault/Repositories/CipherRepository.cs index d88f0e98bb..3c45afe530 100644 --- a/src/Infrastructure.EntityFramework/Vault/Repositories/CipherRepository.cs +++ b/src/Infrastructure.EntityFramework/Vault/Repositories/CipherRepository.cs @@ -168,16 +168,6 @@ public class CipherRepository : Repository - /// - /// EF does not use the bulk resource creation service, so we need to use the regular create method. - /// - public async Task CreateAsync_vNext(Guid userId, IEnumerable ciphers, - IEnumerable folders) - { - await CreateAsync(userId, ciphers, folders); - } - public async Task CreateAsync(IEnumerable ciphers, IEnumerable collections, IEnumerable collectionCiphers, @@ -216,18 +206,6 @@ public class CipherRepository : Repository - /// - /// EF does not use the bulk resource creation service, so we need to use the regular create method. - /// - public async Task CreateAsync_vNext(IEnumerable ciphers, - IEnumerable collections, - IEnumerable collectionCiphers, - IEnumerable collectionUsers) - { - await CreateAsync(ciphers, collections, collectionCiphers, collectionUsers); - } - public async Task DeleteAsync(IEnumerable ids, Guid userId) { await ToggleDeleteCipherStatesAsync(ids, userId, CipherStateAction.HardDelete); @@ -986,15 +964,6 @@ public class CipherRepository : Repository - /// - /// EF does not use the bulk resource creation service, so we need to use the regular update method. - /// - public async Task UpdateCiphersAsync_vNext(Guid userId, IEnumerable ciphers) - { - await UpdateCiphersAsync(userId, ciphers); - } - public async Task UpdatePartialAsync(Guid id, Guid userId, Guid? folderId, bool favorite) { using (var scope = ServiceScopeFactory.CreateScope()) @@ -1107,16 +1076,6 @@ public class CipherRepository : Repository - /// - /// EF does not use the bulk resource creation service, so we need to use the regular update method. - /// - public UpdateEncryptedDataForKeyRotation UpdateForKeyRotation_vNext( - Guid userId, IEnumerable ciphers) - { - return UpdateForKeyRotation(userId, ciphers); - } - public async Task UpsertAsync(CipherDetails cipher) { if (cipher.Id.Equals(default)) diff --git a/src/Infrastructure.EntityFramework/Vault/Repositories/SecurityTaskRepository.cs b/src/Infrastructure.EntityFramework/Vault/Repositories/SecurityTaskRepository.cs index d4f9424d40..9967f18a3e 100644 --- a/src/Infrastructure.EntityFramework/Vault/Repositories/SecurityTaskRepository.cs +++ b/src/Infrastructure.EntityFramework/Vault/Repositories/SecurityTaskRepository.cs @@ -96,4 +96,24 @@ public class SecurityTaskRepository : Repository + public async Task MarkAsCompleteByCipherIds(IEnumerable cipherIds) + { + if (!cipherIds.Any()) + { + return; + } + + using var scope = ServiceScopeFactory.CreateScope(); + var dbContext = GetDatabaseContext(scope); + + var cipherIdsList = cipherIds.ToList(); + + await dbContext.SecurityTasks + .Where(st => st.CipherId.HasValue && cipherIdsList.Contains(st.CipherId.Value) && st.Status != SecurityTaskStatus.Completed) + .ExecuteUpdateAsync(st => st + .SetProperty(s => s.Status, SecurityTaskStatus.Completed) + .SetProperty(s => s.RevisionDate, DateTime.UtcNow)); + } } diff --git a/src/Notifications/AzureQueueHostedService.cs b/src/Notifications/AzureQueueHostedService.cs index 94aa14eaf6..40dd8d22d4 100644 --- a/src/Notifications/AzureQueueHostedService.cs +++ b/src/Notifications/AzureQueueHostedService.cs @@ -1,34 +1,26 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Azure.Storage.Queues; +using Azure.Storage.Queues; using Bit.Core.Settings; using Bit.Core.Utilities; -using Microsoft.AspNetCore.SignalR; namespace Bit.Notifications; public class AzureQueueHostedService : IHostedService, IDisposable { private readonly ILogger _logger; - private readonly IHubContext _hubContext; - private readonly IHubContext _anonymousHubContext; + private readonly HubHelpers _hubHelpers; private readonly GlobalSettings _globalSettings; - private Task _executingTask; - private CancellationTokenSource _cts; - private QueueClient _queueClient; + private Task? _executingTask; + private CancellationTokenSource? _cts; public AzureQueueHostedService( ILogger logger, - IHubContext hubContext, - IHubContext anonymousHubContext, + HubHelpers hubHelpers, GlobalSettings globalSettings) { _logger = logger; - _hubContext = hubContext; + _hubHelpers = hubHelpers; _globalSettings = globalSettings; - _anonymousHubContext = anonymousHubContext; } public Task StartAsync(CancellationToken cancellationToken) @@ -44,32 +36,39 @@ public class AzureQueueHostedService : IHostedService, IDisposable { return; } + _logger.LogWarning("Stopping service."); - _cts.Cancel(); + _cts?.Cancel(); await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); cancellationToken.ThrowIfCancellationRequested(); } public void Dispose() - { } + { + } private async Task ExecuteAsync(CancellationToken cancellationToken) { - _queueClient = new QueueClient(_globalSettings.Notifications.ConnectionString, "notifications"); + var queueClient = new QueueClient(_globalSettings.Notifications.ConnectionString, "notifications"); while (!cancellationToken.IsCancellationRequested) { try { - var messages = await _queueClient.ReceiveMessagesAsync(32); + var messages = await queueClient.ReceiveMessagesAsync(32, cancellationToken: cancellationToken); if (messages.Value?.Any() ?? false) { foreach (var message in messages.Value) { try { - await HubHelpers.SendNotificationToHubAsync( - message.DecodeMessageText(), _hubContext, _anonymousHubContext, _logger, cancellationToken); - await _queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); + var decodedMessage = message.DecodeMessageText(); + if (!string.IsNullOrWhiteSpace(decodedMessage)) + { + await _hubHelpers.SendNotificationToHubAsync(decodedMessage, cancellationToken); + } + + await queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt, + cancellationToken); } catch (Exception e) { @@ -77,7 +76,8 @@ public class AzureQueueHostedService : IHostedService, IDisposable message.MessageId, message.DequeueCount); if (message.DequeueCount > 2) { - await _queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); + await queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt, + cancellationToken); } } } diff --git a/src/Notifications/Controllers/SendController.cs b/src/Notifications/Controllers/SendController.cs index 7debd51df7..c663102b56 100644 --- a/src/Notifications/Controllers/SendController.cs +++ b/src/Notifications/Controllers/SendController.cs @@ -1,36 +1,30 @@ -using System.Text; +#nullable enable +using System.Text; using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -using Microsoft.AspNetCore.SignalR; -namespace Bit.Notifications; +namespace Bit.Notifications.Controllers; [Authorize("Internal")] public class SendController : Controller { - private readonly IHubContext _hubContext; - private readonly IHubContext _anonymousHubContext; - private readonly ILogger _logger; + private readonly HubHelpers _hubHelpers; - public SendController(IHubContext hubContext, IHubContext anonymousHubContext, ILogger logger) + public SendController(HubHelpers hubHelpers) { - _hubContext = hubContext; - _anonymousHubContext = anonymousHubContext; - _logger = logger; + _hubHelpers = hubHelpers; } [HttpPost("~/send")] [SelfHosted(SelfHostedOnly = true)] - public async Task PostSend() + public async Task PostSendAsync() { - using (var reader = new StreamReader(Request.Body, Encoding.UTF8)) + using var reader = new StreamReader(Request.Body, Encoding.UTF8); + var notificationJson = await reader.ReadToEndAsync(); + if (!string.IsNullOrWhiteSpace(notificationJson)) { - var notificationJson = await reader.ReadToEndAsync(); - if (!string.IsNullOrWhiteSpace(notificationJson)) - { - await HubHelpers.SendNotificationToHubAsync(notificationJson, _hubContext, _anonymousHubContext, _logger); - } + await _hubHelpers.SendNotificationToHubAsync(notificationJson); } } } diff --git a/src/Notifications/Dockerfile b/src/Notifications/Dockerfile index 1b0b507606..031df0b1b6 100644 --- a/src/Notifications/Dockerfile +++ b/src/Notifications/Dockerfile @@ -57,6 +57,8 @@ RUN apk add --no-cache curl \ WORKDIR /app COPY --from=build /source/src/Notifications/out /app COPY ./src/Notifications/entrypoint.sh /entrypoint.sh +RUN echo "net.ipv4.ip_local_port_range = 5024 65000" >> /etc/sysctl.d/99-sysctl.conf +RUN echo "net.ipv4.tcp_fin_timeout = 30" >> /etc/sysctl.d/99-sysctl.conf RUN chmod +x /entrypoint.sh HEALTHCHECK CMD curl -f http://localhost:5000/alive || exit 1 diff --git a/src/Notifications/HubHelpers.cs b/src/Notifications/HubHelpers.cs index 69d5bdc958..b0dec8b415 100644 --- a/src/Notifications/HubHelpers.cs +++ b/src/Notifications/HubHelpers.cs @@ -1,31 +1,39 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Text.Json; +using System.Text.Json; using Bit.Core.Enums; using Bit.Core.Models; using Microsoft.AspNetCore.SignalR; namespace Bit.Notifications; -public static class HubHelpers +public class HubHelpers { - private static JsonSerializerOptions _deserializerOptions = - new JsonSerializerOptions { PropertyNameCaseInsensitive = true }; + private static readonly JsonSerializerOptions _deserializerOptions = new() { PropertyNameCaseInsensitive = true }; private static readonly string _receiveMessageMethod = "ReceiveMessage"; - public static async Task SendNotificationToHubAsync( - string notificationJson, - IHubContext hubContext, + private readonly IHubContext _hubContext; + private readonly IHubContext _anonymousHubContext; + private readonly ILogger _logger; + + public HubHelpers(IHubContext hubContext, IHubContext anonymousHubContext, - ILogger logger, - CancellationToken cancellationToken = default(CancellationToken) - ) + ILogger logger) + { + _hubContext = hubContext; + _anonymousHubContext = anonymousHubContext; + _logger = logger; + } + + public async Task SendNotificationToHubAsync(string notificationJson, CancellationToken cancellationToken = default) { var notification = JsonSerializer.Deserialize>(notificationJson, _deserializerOptions); - logger.LogInformation("Sending notification: {NotificationType}", notification.Type); + if (notification is null) + { + return; + } + + _logger.LogInformation("Sending notification: {NotificationType}", notification.Type); switch (notification.Type) { case PushType.SyncCipherUpdate: @@ -35,14 +43,19 @@ public static class HubHelpers var cipherNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); + if (cipherNotification is null) + { + break; + } + if (cipherNotification.Payload.UserId.HasValue) { - await hubContext.Clients.User(cipherNotification.Payload.UserId.ToString()) + await _hubContext.Clients.User(cipherNotification.Payload.UserId.Value.ToString()) .SendAsync(_receiveMessageMethod, cipherNotification, cancellationToken); } else if (cipherNotification.Payload.OrganizationId.HasValue) { - await hubContext.Clients + await _hubContext.Clients .Group(NotificationsHub.GetOrganizationGroup(cipherNotification.Payload.OrganizationId.Value)) .SendAsync(_receiveMessageMethod, cipherNotification, cancellationToken); } @@ -54,7 +67,12 @@ public static class HubHelpers var folderNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.User(folderNotification.Payload.UserId.ToString()) + if (folderNotification is null) + { + break; + } + + await _hubContext.Clients.User(folderNotification.Payload.UserId.ToString()) .SendAsync(_receiveMessageMethod, folderNotification, cancellationToken); break; case PushType.SyncCiphers: @@ -64,9 +82,14 @@ public static class HubHelpers case PushType.SyncSettings: case PushType.LogOut: var userNotification = - JsonSerializer.Deserialize>( + JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.User(userNotification.Payload.UserId.ToString()) + if (userNotification is null) + { + break; + } + + await _hubContext.Clients.User(userNotification.Payload.UserId.ToString()) .SendAsync(_receiveMessageMethod, userNotification, cancellationToken); break; case PushType.SyncSendCreate: @@ -75,58 +98,102 @@ public static class HubHelpers var sendNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.User(sendNotification.Payload.UserId.ToString()) + if (sendNotification is null) + { + break; + } + + await _hubContext.Clients.User(sendNotification.Payload.UserId.ToString()) .SendAsync(_receiveMessageMethod, sendNotification, cancellationToken); break; case PushType.AuthRequestResponse: var authRequestResponseNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await anonymousHubContext.Clients.Group(authRequestResponseNotification.Payload.Id.ToString()) + if (authRequestResponseNotification is null) + { + break; + } + + await _anonymousHubContext.Clients.Group(authRequestResponseNotification.Payload.Id.ToString()) .SendAsync("AuthRequestResponseRecieved", authRequestResponseNotification, cancellationToken); break; case PushType.AuthRequest: var authRequestNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.User(authRequestNotification.Payload.UserId.ToString()) + if (authRequestNotification is null) + { + break; + } + + await _hubContext.Clients.User(authRequestNotification.Payload.UserId.ToString()) .SendAsync(_receiveMessageMethod, authRequestNotification, cancellationToken); break; case PushType.SyncOrganizationStatusChanged: var orgStatusNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.Group(NotificationsHub.GetOrganizationGroup(orgStatusNotification.Payload.OrganizationId)) + if (orgStatusNotification is null) + { + break; + } + + await _hubContext.Clients + .Group(NotificationsHub.GetOrganizationGroup(orgStatusNotification.Payload.OrganizationId)) .SendAsync(_receiveMessageMethod, orgStatusNotification, cancellationToken); break; case PushType.SyncOrganizationCollectionSettingChanged: var organizationCollectionSettingsChangedNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.Group(NotificationsHub.GetOrganizationGroup(organizationCollectionSettingsChangedNotification.Payload.OrganizationId)) - .SendAsync(_receiveMessageMethod, organizationCollectionSettingsChangedNotification, cancellationToken); + if (organizationCollectionSettingsChangedNotification is null) + { + break; + } + + await _hubContext.Clients + .Group(NotificationsHub.GetOrganizationGroup(organizationCollectionSettingsChangedNotification + .Payload.OrganizationId)) + .SendAsync(_receiveMessageMethod, organizationCollectionSettingsChangedNotification, + cancellationToken); break; case PushType.OrganizationBankAccountVerified: var organizationBankAccountVerifiedNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.Group(NotificationsHub.GetOrganizationGroup(organizationBankAccountVerifiedNotification.Payload.OrganizationId)) + if (organizationBankAccountVerifiedNotification is null) + { + break; + } + + await _hubContext.Clients.Group(NotificationsHub.GetOrganizationGroup(organizationBankAccountVerifiedNotification.Payload.OrganizationId)) .SendAsync(_receiveMessageMethod, organizationBankAccountVerifiedNotification, cancellationToken); break; case PushType.ProviderBankAccountVerified: var providerBankAccountVerifiedNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.User(providerBankAccountVerifiedNotification.Payload.AdminId.ToString()) + if (providerBankAccountVerifiedNotification is null) + { + break; + } + + await _hubContext.Clients.User(providerBankAccountVerifiedNotification.Payload.AdminId.ToString()) .SendAsync(_receiveMessageMethod, providerBankAccountVerifiedNotification, cancellationToken); break; case PushType.Notification: case PushType.NotificationStatus: var notificationData = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); + if (notificationData is null) + { + break; + } + if (notificationData.Payload.InstallationId.HasValue) { - await hubContext.Clients.Group(NotificationsHub.GetInstallationGroup( + await _hubContext.Clients.Group(NotificationsHub.GetInstallationGroup( notificationData.Payload.InstallationId.Value, notificationData.Payload.ClientType)) .SendAsync(_receiveMessageMethod, notificationData, cancellationToken); } @@ -134,31 +201,38 @@ public static class HubHelpers { if (notificationData.Payload.ClientType == ClientType.All) { - await hubContext.Clients.User(notificationData.Payload.UserId.ToString()) + await _hubContext.Clients.User(notificationData.Payload.UserId.Value.ToString()) .SendAsync(_receiveMessageMethod, notificationData, cancellationToken); } else { - await hubContext.Clients.Group(NotificationsHub.GetUserGroup( + await _hubContext.Clients.Group(NotificationsHub.GetUserGroup( notificationData.Payload.UserId.Value, notificationData.Payload.ClientType)) .SendAsync(_receiveMessageMethod, notificationData, cancellationToken); } } else if (notificationData.Payload.OrganizationId.HasValue) { - await hubContext.Clients.Group(NotificationsHub.GetOrganizationGroup( + await _hubContext.Clients.Group(NotificationsHub.GetOrganizationGroup( notificationData.Payload.OrganizationId.Value, notificationData.Payload.ClientType)) .SendAsync(_receiveMessageMethod, notificationData, cancellationToken); } break; case PushType.RefreshSecurityTasks: - var pendingTasksData = JsonSerializer.Deserialize>(notificationJson, _deserializerOptions); - await hubContext.Clients.User(pendingTasksData.Payload.UserId.ToString()) + var pendingTasksData = + JsonSerializer.Deserialize>(notificationJson, + _deserializerOptions); + if (pendingTasksData is null) + { + break; + } + + await _hubContext.Clients.User(pendingTasksData.Payload.UserId.ToString()) .SendAsync(_receiveMessageMethod, pendingTasksData, cancellationToken); break; default: - logger.LogWarning("Notification type '{NotificationType}' has not been registered in HubHelpers and will not be pushed as as result", notification.Type); + _logger.LogWarning("Notification type '{NotificationType}' has not been registered in HubHelpers and will not be pushed as as result", notification.Type); break; } } diff --git a/src/Notifications/Startup.cs b/src/Notifications/Startup.cs index eb3c3f8682..2889e90d3b 100644 --- a/src/Notifications/Startup.cs +++ b/src/Notifications/Startup.cs @@ -61,6 +61,7 @@ public class Startup } services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); // Mvc services.AddMvc(); diff --git a/src/SharedWeb/SharedWeb.csproj b/src/SharedWeb/SharedWeb.csproj index 8bffa285fc..d8dc61178d 100644 --- a/src/SharedWeb/SharedWeb.csproj +++ b/src/SharedWeb/SharedWeb.csproj @@ -7,6 +7,7 @@ + diff --git a/src/SharedWeb/Utilities/ExceptionHandlerFilterAttribute.cs b/src/SharedWeb/Utilities/ExceptionHandlerFilterAttribute.cs index 332aa6838c..aba1a6a8dc 100644 --- a/src/SharedWeb/Utilities/ExceptionHandlerFilterAttribute.cs +++ b/src/SharedWeb/Utilities/ExceptionHandlerFilterAttribute.cs @@ -75,7 +75,7 @@ public class ExceptionHandlerFilterAttribute : ExceptionFilterAttribute else { var logger = context.HttpContext.RequestServices.GetRequiredService>(); - logger.LogError(0, exception, exception.Message); + logger.LogError(0, exception, "Unhandled exception"); errorMessage = "An unhandled server error has occurred."; context.HttpContext.Response.StatusCode = 500; } diff --git a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs index d87f9ab97f..78b8a61015 100644 --- a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs +++ b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs @@ -6,9 +6,12 @@ using System.Reflection; using System.Security.Claims; using System.Security.Cryptography.X509Certificates; using AspNetCoreRateLimit; +using Azure.Messaging.ServiceBus; +using Bit.Core; using Bit.Core.AdminConsole.AbilitiesCache; using Bit.Core.AdminConsole.Models.Business.Tokenables; using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.AdminConsole.Models.Teams; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.Services; using Bit.Core.AdminConsole.Services.Implementations; @@ -35,6 +38,9 @@ using Bit.Core.KeyManagement; using Bit.Core.NotificationCenter; using Bit.Core.OrganizationFeatures; using Bit.Core.Platform; +using Bit.Core.Platform.Mail.Delivery; +using Bit.Core.Platform.Mail.Enqueuing; +using Bit.Core.Platform.Mail.Mailer; using Bit.Core.Platform.Push; using Bit.Core.Platform.PushRegistration.Internal; using Bit.Core.Repositories; @@ -43,6 +49,7 @@ using Bit.Core.SecretsManager.Repositories; using Bit.Core.SecretsManager.Repositories.Noop; using Bit.Core.Services; using Bit.Core.Services.Implementations; +using Bit.Core.Services.Mail; using Bit.Core.Settings; using Bit.Core.Tokens; using Bit.Core.Tools.ImportFeatures; @@ -67,6 +74,8 @@ using Microsoft.AspNetCore.HttpOverrides; using Microsoft.AspNetCore.Identity; using Microsoft.AspNetCore.Mvc.Localization; using Microsoft.Azure.Cosmos.Fluent; +using Microsoft.Bot.Builder; +using Microsoft.Bot.Builder.Integration.AspNet.Core; using Microsoft.Extensions.Caching.Cosmos; using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.Configuration; @@ -237,8 +246,11 @@ public static class ServiceCollectionExtensions services.AddScoped(); services.AddScoped(); services.AddScoped(); + // Legacy mailer service services.AddSingleton(); services.AddSingleton(); + // Modern mailers + services.AddMailer(); services.AddSingleton(); services.AddSingleton(_ => { @@ -512,42 +524,33 @@ public static class ServiceCollectionExtensions public static IServiceCollection AddEventWriteServices(this IServiceCollection services, GlobalSettings globalSettings) { - if (!globalSettings.SelfHosted && CoreHelpers.SettingHasValue(globalSettings.Events.ConnectionString)) + if (IsAzureServiceBusEnabled(globalSettings)) { - services.TryAddKeyedSingleton("storage"); - - if (CoreHelpers.SettingHasValue(globalSettings.EventLogging.AzureServiceBus.ConnectionString) && - CoreHelpers.SettingHasValue(globalSettings.EventLogging.AzureServiceBus.EventTopicName)) - { - services.TryAddSingleton(); - services.TryAddKeyedSingleton("broadcast"); - } - else - { - services.TryAddKeyedSingleton("broadcast"); - } - } - else if (globalSettings.SelfHosted) - { - services.TryAddKeyedSingleton("storage"); - - if (IsRabbitMqEnabled(globalSettings)) - { - services.TryAddSingleton(); - services.TryAddKeyedSingleton("broadcast"); - } - else - { - services.TryAddKeyedSingleton("broadcast"); - } - } - else - { - services.TryAddKeyedSingleton("storage"); - services.TryAddKeyedSingleton("broadcast"); + services.TryAddSingleton(); + services.TryAddSingleton(); + return services; } - services.TryAddScoped(); + if (IsRabbitMqEnabled(globalSettings)) + { + services.TryAddSingleton(); + services.TryAddSingleton(); + return services; + } + + if (CoreHelpers.SettingHasValue(globalSettings.Events.ConnectionString)) + { + services.TryAddSingleton(); + return services; + } + + if (globalSettings.SelfHosted) + { + services.TryAddSingleton(); + return services; + } + + services.TryAddSingleton(); return services; } @@ -602,6 +605,33 @@ public static class ServiceCollectionExtensions return services; } + public static IServiceCollection AddTeamsService(this IServiceCollection services, GlobalSettings globalSettings) + { + if (CoreHelpers.SettingHasValue(globalSettings.Teams.ClientId) && + CoreHelpers.SettingHasValue(globalSettings.Teams.ClientSecret) && + CoreHelpers.SettingHasValue(globalSettings.Teams.Scopes)) + { + services.AddHttpClient(TeamsService.HttpClientName); + services.TryAddSingleton(); + services.TryAddSingleton(sp => sp.GetRequiredService()); + services.TryAddSingleton(sp => sp.GetRequiredService()); + services.TryAddSingleton(sp => + new BotFrameworkHttpAdapter( + new TeamsBotCredentialProvider( + clientId: globalSettings.Teams.ClientId, + clientSecret: globalSettings.Teams.ClientSecret + ) + ) + ); + } + else + { + services.TryAddSingleton(); + } + + return services; + } + public static void UseDefaultMiddleware(this IApplicationBuilder app, IWebHostEnvironment env, GlobalSettings globalSettings) { @@ -694,8 +724,23 @@ public static class ServiceCollectionExtensions { options.ServerDomain = new Uri(globalSettings.BaseServiceUri.Vault).Host; options.ServerName = "Bitwarden"; - options.Origins = new HashSet { globalSettings.BaseServiceUri.Vault, }; options.TimestampDriftTolerance = 300000; + + if (globalSettings.Fido2?.Origins?.Any() == true) + { + options.Origins = new HashSet(globalSettings.Fido2.Origins); + } + else + { + // Default to allowing the vault domain and chromium browser extension IDs + options.Origins = new HashSet { + globalSettings.BaseServiceUri.Vault, + Constants.BrowserExtensions.ChromeId, + Constants.BrowserExtensions.EdgeId, + Constants.BrowserExtensions.OperaId + }; + } + }); } @@ -855,6 +900,11 @@ public static class ServiceCollectionExtensions configuration: listenerConfiguration, handler: provider.GetRequiredKeyedService(serviceKey: listenerConfiguration.RoutingKey), serviceBusService: provider.GetRequiredService(), + serviceBusOptions: new ServiceBusProcessorOptions() + { + PrefetchCount = listenerConfiguration.EventPrefetchCount, + MaxConcurrentCalls = listenerConfiguration.EventMaxConcurrentCalls + }, loggerFactory: provider.GetRequiredService() ) ) @@ -865,6 +915,11 @@ public static class ServiceCollectionExtensions configuration: listenerConfiguration, handler: provider.GetRequiredService>(), serviceBusService: provider.GetRequiredService(), + serviceBusOptions: new ServiceBusProcessorOptions() + { + PrefetchCount = listenerConfiguration.IntegrationPrefetchCount, + MaxConcurrentCalls = listenerConfiguration.IntegrationMaxConcurrentCalls + }, loggerFactory: provider.GetRequiredService() ) ) @@ -886,6 +941,7 @@ public static class ServiceCollectionExtensions // Add services in support of handlers services.AddSlackService(globalSettings); + services.AddTeamsService(globalSettings); services.TryAddSingleton(TimeProvider.System); services.AddHttpClient(WebhookIntegrationHandler.HttpClientName); services.AddHttpClient(DatadogIntegrationHandler.HttpClientName); @@ -894,12 +950,14 @@ public static class ServiceCollectionExtensions services.TryAddSingleton, SlackIntegrationHandler>(); services.TryAddSingleton, WebhookIntegrationHandler>(); services.TryAddSingleton, DatadogIntegrationHandler>(); + services.TryAddSingleton, TeamsIntegrationHandler>(); var repositoryConfiguration = new RepositoryListenerConfiguration(globalSettings); var slackConfiguration = new SlackListenerConfiguration(globalSettings); var webhookConfiguration = new WebhookListenerConfiguration(globalSettings); var hecConfiguration = new HecListenerConfiguration(globalSettings); var datadogConfiguration = new DatadogListenerConfiguration(globalSettings); + var teamsConfiguration = new TeamsListenerConfiguration(globalSettings); if (IsRabbitMqEnabled(globalSettings)) { @@ -917,6 +975,7 @@ public static class ServiceCollectionExtensions services.AddRabbitMqIntegration(webhookConfiguration); services.AddRabbitMqIntegration(hecConfiguration); services.AddRabbitMqIntegration(datadogConfiguration); + services.AddRabbitMqIntegration(teamsConfiguration); } if (IsAzureServiceBusEnabled(globalSettings)) @@ -927,6 +986,11 @@ public static class ServiceCollectionExtensions configuration: repositoryConfiguration, handler: provider.GetRequiredService(), serviceBusService: provider.GetRequiredService(), + serviceBusOptions: new ServiceBusProcessorOptions() + { + PrefetchCount = repositoryConfiguration.EventPrefetchCount, + MaxConcurrentCalls = repositoryConfiguration.EventMaxConcurrentCalls + }, loggerFactory: provider.GetRequiredService() ) ) @@ -935,6 +999,7 @@ public static class ServiceCollectionExtensions services.AddAzureServiceBusIntegration(webhookConfiguration); services.AddAzureServiceBusIntegration(hecConfiguration); services.AddAzureServiceBusIntegration(datadogConfiguration); + services.AddAzureServiceBusIntegration(teamsConfiguration); } return services; diff --git a/src/Sql/Sql.sqlproj b/src/Sql/Sql.sqlproj index 1a7530321e..0622c5cbb2 100644 --- a/src/Sql/Sql.sqlproj +++ b/src/Sql/Sql.sqlproj @@ -17,7 +17,4 @@ 71502 - - - diff --git a/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_Create.sql b/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_Create.sql index d6cd206558..397911549c 100644 --- a/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_Create.sql +++ b/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_Create.sql @@ -6,7 +6,19 @@ CREATE PROCEDURE [dbo].[OrganizationReport_Create] @ContentEncryptionKey VARCHAR(MAX), @SummaryData NVARCHAR(MAX), @ApplicationData NVARCHAR(MAX), - @RevisionDate DATETIME2(7) + @RevisionDate DATETIME2(7), + @ApplicationCount INT = NULL, + @ApplicationAtRiskCount INT = NULL, + @CriticalApplicationCount INT = NULL, + @CriticalApplicationAtRiskCount INT = NULL, + @MemberCount INT = NULL, + @MemberAtRiskCount INT = NULL, + @CriticalMemberCount INT = NULL, + @CriticalMemberAtRiskCount INT = NULL, + @PasswordCount INT = NULL, + @PasswordAtRiskCount INT = NULL, + @CriticalPasswordCount INT = NULL, + @CriticalPasswordAtRiskCount INT = NULL AS BEGIN SET NOCOUNT ON; @@ -20,7 +32,19 @@ INSERT INTO [dbo].[OrganizationReport]( [ContentEncryptionKey], [SummaryData], [ApplicationData], - [RevisionDate] + [RevisionDate], + [ApplicationCount], + [ApplicationAtRiskCount], + [CriticalApplicationCount], + [CriticalApplicationAtRiskCount], + [MemberCount], + [MemberAtRiskCount], + [CriticalMemberCount], + [CriticalMemberAtRiskCount], + [PasswordCount], + [PasswordAtRiskCount], + [CriticalPasswordCount], + [CriticalPasswordAtRiskCount] ) VALUES ( @Id, @@ -30,6 +54,18 @@ VALUES ( @ContentEncryptionKey, @SummaryData, @ApplicationData, - @RevisionDate + @RevisionDate, + @ApplicationCount, + @ApplicationAtRiskCount, + @CriticalApplicationCount, + @CriticalApplicationAtRiskCount, + @MemberCount, + @MemberAtRiskCount, + @CriticalMemberCount, + @CriticalMemberAtRiskCount, + @PasswordCount, + @PasswordAtRiskCount, + @CriticalPasswordCount, + @CriticalPasswordAtRiskCount ); END diff --git a/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_GetLatestByOrganizationId.sql b/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_GetLatestByOrganizationId.sql index 1312369fa8..fca8788ce6 100644 --- a/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_GetLatestByOrganizationId.sql +++ b/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_GetLatestByOrganizationId.sql @@ -5,14 +5,7 @@ BEGIN SET NOCOUNT ON SELECT TOP 1 - [Id], - [OrganizationId], - [ReportData], - [CreationDate], - [ContentEncryptionKey], - [SummaryData], - [ApplicationData], - [RevisionDate] + * FROM [dbo].[OrganizationReportView] WHERE [OrganizationId] = @OrganizationId ORDER BY [RevisionDate] DESC diff --git a/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_Update.sql b/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_Update.sql index 4732fb8ef4..e78d25267d 100644 --- a/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_Update.sql +++ b/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_Update.sql @@ -6,7 +6,19 @@ CREATE PROCEDURE [dbo].[OrganizationReport_Update] @ContentEncryptionKey VARCHAR(MAX), @SummaryData NVARCHAR(MAX), @ApplicationData NVARCHAR(MAX), - @RevisionDate DATETIME2(7) + @RevisionDate DATETIME2(7), + @ApplicationCount INT = NULL, + @ApplicationAtRiskCount INT = NULL, + @CriticalApplicationCount INT = NULL, + @CriticalApplicationAtRiskCount INT = NULL, + @MemberCount INT = NULL, + @MemberAtRiskCount INT = NULL, + @CriticalMemberCount INT = NULL, + @CriticalMemberAtRiskCount INT = NULL, + @PasswordCount INT = NULL, + @PasswordAtRiskCount INT = NULL, + @CriticalPasswordCount INT = NULL, + @CriticalPasswordAtRiskCount INT = NULL AS BEGIN SET NOCOUNT ON; @@ -18,6 +30,18 @@ BEGIN [ContentEncryptionKey] = @ContentEncryptionKey, [SummaryData] = @SummaryData, [ApplicationData] = @ApplicationData, - [RevisionDate] = @RevisionDate + [RevisionDate] = @RevisionDate, + [ApplicationCount] = @ApplicationCount, + [ApplicationAtRiskCount] = @ApplicationAtRiskCount, + [CriticalApplicationCount] = @CriticalApplicationCount, + [CriticalApplicationAtRiskCount] = @CriticalApplicationAtRiskCount, + [MemberCount] = @MemberCount, + [MemberAtRiskCount] = @MemberAtRiskCount, + [CriticalMemberCount] = @CriticalMemberCount, + [CriticalMemberAtRiskCount] = @CriticalMemberAtRiskCount, + [PasswordCount] = @PasswordCount, + [PasswordAtRiskCount] = @PasswordAtRiskCount, + [CriticalPasswordCount] = @CriticalPasswordCount, + [CriticalPasswordAtRiskCount] = @CriticalPasswordAtRiskCount WHERE [Id] = @Id; END; diff --git a/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_UpdateMetrics.sql b/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_UpdateMetrics.sql new file mode 100644 index 0000000000..8b06c90fe1 --- /dev/null +++ b/src/Sql/dbo/Dirt/Stored Procedures/OrganizationReport_UpdateMetrics.sql @@ -0,0 +1,39 @@ +CREATE PROCEDURE [dbo].[OrganizationReport_UpdateMetrics] + @Id UNIQUEIDENTIFIER, + @ApplicationCount INT, + @ApplicationAtRiskCount INT, + @CriticalApplicationCount INT, + @CriticalApplicationAtRiskCount INT, + @MemberCount INT, + @MemberAtRiskCount INT, + @CriticalMemberCount INT, + @CriticalMemberAtRiskCount INT, + @PasswordCount INT, + @PasswordAtRiskCount INT, + @CriticalPasswordCount INT, + @CriticalPasswordAtRiskCount INT, + @RevisionDate DATETIME2(7) +AS +BEGIN + SET NOCOUNT ON; + + UPDATE + [dbo].[OrganizationReport] + SET + [ApplicationCount] = @ApplicationCount, + [ApplicationAtRiskCount] = @ApplicationAtRiskCount, + [CriticalApplicationCount] = @CriticalApplicationCount, + [CriticalApplicationAtRiskCount] = @CriticalApplicationAtRiskCount, + [MemberCount] = @MemberCount, + [MemberAtRiskCount] = @MemberAtRiskCount, + [CriticalMemberCount] = @CriticalMemberCount, + [CriticalMemberAtRiskCount] = @CriticalMemberAtRiskCount, + [PasswordCount] = @PasswordCount, + [PasswordAtRiskCount] = @PasswordAtRiskCount, + [CriticalPasswordCount] = @CriticalPasswordCount, + [CriticalPasswordAtRiskCount] = @CriticalPasswordAtRiskCount, + [RevisionDate] = @RevisionDate + WHERE + [Id] = @Id + +END diff --git a/src/Sql/dbo/Dirt/Tables/OrganizationReport.sql b/src/Sql/dbo/Dirt/Tables/OrganizationReport.sql index 4c47eafad8..2ffedc3f41 100644 --- a/src/Sql/dbo/Dirt/Tables/OrganizationReport.sql +++ b/src/Sql/dbo/Dirt/Tables/OrganizationReport.sql @@ -1,12 +1,24 @@ CREATE TABLE [dbo].[OrganizationReport] ( - [Id] UNIQUEIDENTIFIER NOT NULL, - [OrganizationId] UNIQUEIDENTIFIER NOT NULL, - [ReportData] NVARCHAR(MAX) NOT NULL, - [CreationDate] DATETIME2 (7) NOT NULL, - [ContentEncryptionKey] VARCHAR(MAX) NOT NULL, - [SummaryData] NVARCHAR(MAX) NULL, - [ApplicationData] NVARCHAR(MAX) NULL, - [RevisionDate] DATETIME2 (7) NULL, + [Id] UNIQUEIDENTIFIER NOT NULL, + [OrganizationId] UNIQUEIDENTIFIER NOT NULL, + [ReportData] NVARCHAR(MAX) NOT NULL, + [CreationDate] DATETIME2 (7) NOT NULL, + [ContentEncryptionKey] VARCHAR(MAX) NOT NULL, + [SummaryData] NVARCHAR(MAX) NULL, + [ApplicationData] NVARCHAR(MAX) NULL, + [RevisionDate] DATETIME2 (7) NULL, + [ApplicationCount] INT NULL, + [ApplicationAtRiskCount] INT NULL, + [CriticalApplicationCount] INT NULL, + [CriticalApplicationAtRiskCount] INT NULL, + [MemberCount] INT NULL, + [MemberAtRiskCount] INT NULL, + [CriticalMemberCount] INT NULL, + [CriticalMemberAtRiskCount] INT NULL, + [PasswordCount] INT NULL, + [PasswordAtRiskCount] INT NULL, + [CriticalPasswordCount] INT NULL, + [CriticalPasswordAtRiskCount] INT NULL, CONSTRAINT [PK_OrganizationReport] PRIMARY KEY CLUSTERED ([Id] ASC), CONSTRAINT [FK_OrganizationReport_Organization] FOREIGN KEY ([OrganizationId]) REFERENCES [dbo].[Organization] ([Id]) ); diff --git a/src/Sql/dbo/KeyManagement/Stored Procedures/UserSignatureKeyPair_ReadByUserId.sql b/src/Sql/dbo/KeyManagement/Stored Procedures/UserSignatureKeyPair_ReadByUserId.sql new file mode 100644 index 0000000000..8bfa0156af --- /dev/null +++ b/src/Sql/dbo/KeyManagement/Stored Procedures/UserSignatureKeyPair_ReadByUserId.sql @@ -0,0 +1,13 @@ +CREATE PROCEDURE [dbo].[UserSignatureKeyPair_ReadByUserId] + @UserId UNIQUEIDENTIFIER +AS +BEGIN + SET NOCOUNT ON; + + SELECT + * + FROM + [dbo].[UserSignatureKeyPairView] + WHERE + [UserId] = @UserId; +END diff --git a/src/Sql/dbo/KeyManagement/Stored Procedures/UserSignatureKeyPair_SetForRotation.sql b/src/Sql/dbo/KeyManagement/Stored Procedures/UserSignatureKeyPair_SetForRotation.sql new file mode 100644 index 0000000000..6ee33e2a40 --- /dev/null +++ b/src/Sql/dbo/KeyManagement/Stored Procedures/UserSignatureKeyPair_SetForRotation.sql @@ -0,0 +1,33 @@ +CREATE PROCEDURE [dbo].[UserSignatureKeyPair_SetForRotation] + @Id UNIQUEIDENTIFIER, + @UserId UNIQUEIDENTIFIER, + @SignatureAlgorithm TINYINT, + @SigningKey VARCHAR(MAX), + @VerifyingKey VARCHAR(MAX), + @CreationDate DATETIME2(7), + @RevisionDate DATETIME2(7) +AS +BEGIN + SET NOCOUNT ON; + + INSERT INTO [dbo].[UserSignatureKeyPair] + ( + [Id], + [UserId], + [SignatureAlgorithm], + [SigningKey], + [VerifyingKey], + [CreationDate], + [RevisionDate] + ) + VALUES + ( + @Id, + @UserId, + @SignatureAlgorithm, + @SigningKey, + @VerifyingKey, + @CreationDate, + @RevisionDate + ) +END diff --git a/src/Sql/dbo/KeyManagement/Stored Procedures/UserSignatureKeyPair_UpdateForRotation.sql b/src/Sql/dbo/KeyManagement/Stored Procedures/UserSignatureKeyPair_UpdateForRotation.sql new file mode 100644 index 0000000000..4f673019fc --- /dev/null +++ b/src/Sql/dbo/KeyManagement/Stored Procedures/UserSignatureKeyPair_UpdateForRotation.sql @@ -0,0 +1,19 @@ +CREATE PROCEDURE [dbo].[UserSignatureKeyPair_UpdateForRotation] + @UserId UNIQUEIDENTIFIER, + @SignatureAlgorithm TINYINT, + @SigningKey VARCHAR(MAX), + @VerifyingKey VARCHAR(MAX), + @RevisionDate DATETIME2(7) +AS +BEGIN + SET NOCOUNT ON; + UPDATE + [dbo].[UserSignatureKeyPair] + SET + [SignatureAlgorithm] = @SignatureAlgorithm, + [SigningKey] = @SigningKey, + [VerifyingKey] = @VerifyingKey, + [RevisionDate] = @RevisionDate + WHERE + [UserId] = @UserId; +END diff --git a/src/Sql/dbo/KeyManagement/Tables/UserSignatureKeyPair.sql b/src/Sql/dbo/KeyManagement/Tables/UserSignatureKeyPair.sql new file mode 100644 index 0000000000..94d4e48a0b --- /dev/null +++ b/src/Sql/dbo/KeyManagement/Tables/UserSignatureKeyPair.sql @@ -0,0 +1,16 @@ +CREATE TABLE [dbo].[UserSignatureKeyPair] ( + [Id] UNIQUEIDENTIFIER NOT NULL, + [UserId] UNIQUEIDENTIFIER NOT NULL, + [SignatureAlgorithm] TINYINT NOT NULL, + [SigningKey] VARCHAR(MAX) NOT NULL, + [VerifyingKey] VARCHAR(MAX) NOT NULL, + [CreationDate] DATETIME2 (7) NOT NULL, + [RevisionDate] DATETIME2 (7) NOT NULL, + CONSTRAINT [PK_UserSignatureKeyPair] PRIMARY KEY CLUSTERED ([Id] ASC), + CONSTRAINT [FK_UserSignatureKeyPair_User] FOREIGN KEY ([UserId]) REFERENCES [dbo].[User] ([Id]) ON DELETE CASCADE +); +GO + +CREATE UNIQUE NONCLUSTERED INDEX [IX_UserSignatureKeyPair_UserId] + ON [dbo].[UserSignatureKeyPair]([UserId] ASC); +GO diff --git a/src/Sql/dbo/KeyManagement/Views/UserSignatureKeyPairView.sql b/src/Sql/dbo/KeyManagement/Views/UserSignatureKeyPairView.sql new file mode 100644 index 0000000000..959305a3e7 --- /dev/null +++ b/src/Sql/dbo/KeyManagement/Views/UserSignatureKeyPairView.sql @@ -0,0 +1,6 @@ +CREATE VIEW [dbo].[UserSignatureKeyPairView] +AS +SELECT + * +FROM + [dbo].[UserSignatureKeyPair] diff --git a/src/Sql/dbo/SecretsManager/Stored Procedures/Event/Event_ReadPageByOrganizationIdServiceAccountId.sql b/src/Sql/dbo/SecretsManager/Stored Procedures/Event/Event_ReadPageByOrganizationIdServiceAccountId.sql index 5dc950ffff..831c9f70ee 100644 --- a/src/Sql/dbo/SecretsManager/Stored Procedures/Event/Event_ReadPageByOrganizationIdServiceAccountId.sql +++ b/src/Sql/dbo/SecretsManager/Stored Procedures/Event/Event_ReadPageByOrganizationIdServiceAccountId.sql @@ -18,7 +18,7 @@ BEGIN AND (@BeforeDate IS NOT NULL OR [Date] <= @EndDate) AND (@BeforeDate IS NULL OR [Date] < @BeforeDate) AND [OrganizationId] = @OrganizationId - AND [ServiceAccountId] = @ServiceAccountId + AND ([ServiceAccountId] = @ServiceAccountId OR [GrantedServiceAccountId] = @ServiceAccountId) ORDER BY [Date] DESC OFFSET 0 ROWS FETCH NEXT @PageSize ROWS ONLY diff --git a/src/Sql/dbo/SecretsManager/Stored Procedures/Event/Event_ReadPageByServiceAccountId.sql b/src/Sql/dbo/SecretsManager/Stored Procedures/Event/Event_ReadPageByServiceAccountId.sql new file mode 100644 index 0000000000..c429a4a064 --- /dev/null +++ b/src/Sql/dbo/SecretsManager/Stored Procedures/Event/Event_ReadPageByServiceAccountId.sql @@ -0,0 +1,45 @@ +CREATE PROCEDURE [dbo].[Event_ReadPageByServiceAccountId] + @GrantedServiceAccountId UNIQUEIDENTIFIER, + @StartDate DATETIME2(7), + @EndDate DATETIME2(7), + @BeforeDate DATETIME2(7), + @PageSize INT +AS +BEGIN + SET NOCOUNT ON + + SELECT + e.Id, + e.Date, + e.Type, + e.UserId, + e.OrganizationId, + e.InstallationId, + e.ProviderId, + e.CipherId, + e.CollectionId, + e.PolicyId, + e.GroupId, + e.OrganizationUserId, + e.ProviderUserId, + e.ProviderOrganizationId, + e.DeviceType, + e.IpAddress, + e.ActingUserId, + e.SystemUser, + e.DomainName, + e.SecretId, + e.ServiceAccountId, + e.ProjectId, + e.GrantedServiceAccountId + FROM + [dbo].[EventView] e + WHERE + [Date] >= @StartDate + AND (@BeforeDate IS NOT NULL OR [Date] <= @EndDate) + AND (@BeforeDate IS NULL OR [Date] < @BeforeDate) + AND [GrantedServiceAccountId] = @GrantedServiceAccountId + ORDER BY [Date] DESC + OFFSET 0 ROWS + FETCH NEXT @PageSize ROWS ONLY +END diff --git a/src/Sql/dbo/SecretsManager/Tables/SecretVersion.sql b/src/Sql/dbo/SecretsManager/Tables/SecretVersion.sql new file mode 100644 index 0000000000..31ab443f56 --- /dev/null +++ b/src/Sql/dbo/SecretsManager/Tables/SecretVersion.sql @@ -0,0 +1,27 @@ +CREATE TABLE [dbo].[SecretVersion] ( + [Id] UNIQUEIDENTIFIER NOT NULL, + [SecretId] UNIQUEIDENTIFIER NOT NULL, + [Value] NVARCHAR (MAX) NOT NULL, + [VersionDate] DATETIME2 (7) NOT NULL, + [EditorServiceAccountId] UNIQUEIDENTIFIER NULL, + [EditorOrganizationUserId] UNIQUEIDENTIFIER NULL, + CONSTRAINT [PK_SecretVersion] PRIMARY KEY CLUSTERED ([Id] ASC), + CONSTRAINT [FK_SecretVersion_OrganizationUser] FOREIGN KEY ([EditorOrganizationUserId]) REFERENCES [dbo].[OrganizationUser] ([Id]) ON DELETE SET NULL, + CONSTRAINT [FK_SecretVersion_Secret] FOREIGN KEY ([SecretId]) REFERENCES [dbo].[Secret] ([Id]) ON DELETE CASCADE, + CONSTRAINT [FK_SecretVersion_ServiceAccount] FOREIGN KEY ([EditorServiceAccountId]) REFERENCES [dbo].[ServiceAccount] ([Id]) ON DELETE SET NULL +); + +GO +CREATE NONCLUSTERED INDEX [IX_SecretVersion_SecretId] + ON [dbo].[SecretVersion]([SecretId] ASC); + +GO +CREATE NONCLUSTERED INDEX [IX_SecretVersion_EditorServiceAccountId] + ON [dbo].[SecretVersion]([EditorServiceAccountId] ASC) + WHERE [EditorServiceAccountId] IS NOT NULL; + +GO +CREATE NONCLUSTERED INDEX [IX_SecretVersion_EditorOrganizationUserId] + ON [dbo].[SecretVersion]([EditorOrganizationUserId] ASC) + WHERE [EditorOrganizationUserId] IS NOT NULL; +GO \ No newline at end of file diff --git a/src/Sql/dbo/Stored Procedures/CollectionCipher_UpdateCollections.sql b/src/Sql/dbo/Stored Procedures/CollectionCipher_UpdateCollections.sql index f3a1d964b5..2282524228 100644 --- a/src/Sql/dbo/Stored Procedures/CollectionCipher_UpdateCollections.sql +++ b/src/Sql/dbo/Stored Procedures/CollectionCipher_UpdateCollections.sql @@ -44,13 +44,13 @@ BEGIN [CollectionId], [CipherId] ) - SELECT + SELECT [Id], @CipherId FROM @CollectionIds WHERE [Id] IN (SELECT [Id] FROM [#TempAvailableCollections]) AND NOT EXISTS ( - SELECT 1 + SELECT 1 FROM [dbo].[CollectionCipher] WHERE [CollectionId] = [@CollectionIds].[Id] AND [CipherId] = @CipherId diff --git a/src/Sql/dbo/Stored Procedures/CollectionCipher_UpdateCollectionsAdmin.sql b/src/Sql/dbo/Stored Procedures/CollectionCipher_UpdateCollectionsAdmin.sql index 5f7b0215d9..1486709f09 100644 --- a/src/Sql/dbo/Stored Procedures/CollectionCipher_UpdateCollectionsAdmin.sql +++ b/src/Sql/dbo/Stored Procedures/CollectionCipher_UpdateCollectionsAdmin.sql @@ -4,46 +4,52 @@ @CollectionIds AS [dbo].[GuidIdArray] READONLY AS BEGIN - SET NOCOUNT ON + SET NOCOUNT ON; - ;WITH [AvailableCollectionsCTE] AS( - SELECT - Id - FROM - [dbo].[Collection] - WHERE - OrganizationId = @OrganizationId - ), - [CollectionCiphersCTE] AS( - SELECT - [CollectionId], - [CipherId] - FROM - [dbo].[CollectionCipher] - WHERE - [CipherId] = @CipherId + -- Available collections for this org, excluding default collections + SELECT + C.[Id] + INTO #TempAvailableCollections + FROM [dbo].[Collection] AS C + WHERE + C.[OrganizationId] = @OrganizationId + AND C.[Type] <> 1; -- exclude DefaultUserCollection + + -- Insert new collection assignments + INSERT INTO [dbo].[CollectionCipher] ( + [CollectionId], + [CipherId] ) - MERGE - [CollectionCiphersCTE] AS [Target] - USING - @CollectionIds AS [Source] - ON - [Target].[CollectionId] = [Source].[Id] - AND [Target].[CipherId] = @CipherId - WHEN NOT MATCHED BY TARGET - AND [Source].[Id] IN (SELECT [Id] FROM [AvailableCollectionsCTE]) THEN - INSERT VALUES - ( - [Source].[Id], - @CipherId - ) - WHEN NOT MATCHED BY SOURCE - AND [Target].[CipherId] = @CipherId THEN - DELETE - ; + SELECT + S.[Id], + @CipherId + FROM @CollectionIds AS S + INNER JOIN #TempAvailableCollections AS A + ON A.[Id] = S.[Id] + WHERE NOT EXISTS ( + SELECT 1 + FROM [dbo].[CollectionCipher] AS CC + WHERE CC.[CollectionId] = S.[Id] + AND CC.[CipherId] = @CipherId + ); + + -- Delete removed collection assignments + DELETE CC + FROM [dbo].[CollectionCipher] AS CC + INNER JOIN #TempAvailableCollections AS A + ON A.[Id] = CC.[CollectionId] + WHERE CC.[CipherId] = @CipherId + AND NOT EXISTS ( + SELECT 1 + FROM @CollectionIds AS S + WHERE S.[Id] = CC.[CollectionId] + ); IF @OrganizationId IS NOT NULL BEGIN - EXEC [dbo].[User_BumpAccountRevisionDateByOrganizationId] @OrganizationId + EXEC [dbo].[User_BumpAccountRevisionDateByOrganizationId] @OrganizationId; END -END \ No newline at end of file + + DROP TABLE #TempAvailableCollections; +END +GO diff --git a/src/Sql/dbo/Stored Procedures/Event_Create.sql b/src/Sql/dbo/Stored Procedures/Event_Create.sql index 89971bd56f..0466bc1a69 100644 --- a/src/Sql/dbo/Stored Procedures/Event_Create.sql +++ b/src/Sql/dbo/Stored Procedures/Event_Create.sql @@ -20,7 +20,8 @@ @DomainName VARCHAR(256), @SecretId UNIQUEIDENTIFIER = null, @ServiceAccountId UNIQUEIDENTIFIER = null, - @ProjectId UNIQUEIDENTIFIER = null + @ProjectId UNIQUEIDENTIFIER = null, + @GrantedServiceAccountId UNIQUEIDENTIFIER = null AS BEGIN SET NOCOUNT ON @@ -48,7 +49,8 @@ BEGIN [DomainName], [SecretId], [ServiceAccountId], - [ProjectId] + [ProjectId], + [GrantedServiceAccountId] ) VALUES ( @@ -73,6 +75,7 @@ BEGIN @DomainName, @SecretId, @ServiceAccountId, - @ProjectId + @ProjectId, + @GrantedServiceAccountId ) END diff --git a/src/Sql/dbo/Stored Procedures/OrganizationIntegration_ReadByTeamsConfigurationTenantIdTeamId.sql b/src/Sql/dbo/Stored Procedures/OrganizationIntegration_ReadByTeamsConfigurationTenantIdTeamId.sql new file mode 100644 index 0000000000..8e2102772b --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/OrganizationIntegration_ReadByTeamsConfigurationTenantIdTeamId.sql @@ -0,0 +1,17 @@ +CREATE PROCEDURE [dbo].[OrganizationIntegration_ReadByTeamsConfigurationTenantIdTeamId] + @TenantId NVARCHAR(200), + @TeamId NVARCHAR(200) +AS +BEGIN + SET NOCOUNT ON; + +SELECT TOP 1 * +FROM [dbo].[OrganizationIntegrationView] + CROSS APPLY OPENJSON([Configuration], '$.Teams') + WITH ( TeamId NVARCHAR(MAX) '$.id' ) t +WHERE [Type] = 7 + AND JSON_VALUE([Configuration], '$.TenantId') = @TenantId + AND t.TeamId = @TeamId + AND JSON_VALUE([Configuration], '$.ChannelId') IS NULL + AND JSON_VALUE([Configuration], '$.ServiceUrl') IS NULL; +END diff --git a/src/Sql/dbo/Stored Procedures/OrganizationUser_ConfirmById.sql b/src/Sql/dbo/Stored Procedures/OrganizationUser_ConfirmById.sql new file mode 100644 index 0000000000..7a1cd78a51 --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/OrganizationUser_ConfirmById.sql @@ -0,0 +1,30 @@ +CREATE PROCEDURE [dbo].[OrganizationUser_ConfirmById] + @Id UNIQUEIDENTIFIER, + @UserId UNIQUEIDENTIFIER, + @RevisionDate DATETIME2(7), + @Key NVARCHAR(MAX) = NULL +AS +BEGIN + SET NOCOUNT ON + + DECLARE @RowCount INT; + + UPDATE + [dbo].[OrganizationUser] + SET + [Status] = 2, -- Set to Confirmed + [RevisionDate] = @RevisionDate, + [Key] = @Key + WHERE + [Id] = @Id + AND [Status] = 1 -- Only update if status is Accepted + + SET @RowCount = @@ROWCOUNT; + + IF @RowCount > 0 + BEGIN + EXEC [dbo].[User_BumpAccountRevisionDate] @UserId + END + + SELECT @RowCount; +END diff --git a/src/Sql/dbo/Stored Procedures/OrganizationUser_ReadByUserIdWithPolicyDetails.sql b/src/Sql/dbo/Stored Procedures/OrganizationUser_ReadByUserIdWithPolicyDetails.sql index c2bc690a27..105170cd27 100644 --- a/src/Sql/dbo/Stored Procedures/OrganizationUser_ReadByUserIdWithPolicyDetails.sql +++ b/src/Sql/dbo/Stored Procedures/OrganizationUser_ReadByUserIdWithPolicyDetails.sql @@ -4,31 +4,70 @@ AS BEGIN SET NOCOUNT ON -SELECT - OU.[Id] AS OrganizationUserId, - P.[OrganizationId], - P.[Type] AS PolicyType, - P.[Enabled] AS PolicyEnabled, - P.[Data] AS PolicyData, - OU.[Type] AS OrganizationUserType, - OU.[Status] AS OrganizationUserStatus, - OU.[Permissions] AS OrganizationUserPermissionsData, - CASE WHEN EXISTS ( - SELECT 1 - FROM [dbo].[ProviderUserView] PU - INNER JOIN [dbo].[ProviderOrganizationView] PO ON PO.[ProviderId] = PU.[ProviderId] - WHERE PU.[UserId] = OU.[UserId] AND PO.[OrganizationId] = P.[OrganizationId] - ) THEN 1 ELSE 0 END AS IsProvider -FROM [dbo].[PolicyView] P -INNER JOIN [dbo].[OrganizationUserView] OU - ON P.[OrganizationId] = OU.[OrganizationId] -WHERE P.[Type] = @PolicyType AND + + DECLARE @UserEmail NVARCHAR(256) + SELECT @UserEmail = Email + FROM + [dbo].[UserView] + WHERE + Id = @UserId + + ;WITH OrgUsers AS ( - (OU.[Status] != 0 AND OU.[UserId] = @UserId) -- OrgUsers who have accepted their invite and are linked to a UserId - OR EXISTS ( - SELECT 1 - FROM [dbo].[UserView] U - WHERE U.[Id] = @UserId AND OU.[Email] = U.[Email] AND OU.[Status] = 0 -- 'Invited' OrgUsers are not linked to a UserId yet, so we have to look up their email - ) + -- All users except invited (Status <> 0): direct UserId match + SELECT + OU.[Id], + OU.[OrganizationId], + OU.[Type], + OU.[Status], + OU.[Permissions] + FROM + [dbo].[OrganizationUserView] OU + WHERE + OU.[Status] <> 0 + AND OU.[UserId] = @UserId + + UNION ALL + + -- Invited users: email match + SELECT + OU.[Id], + OU.[OrganizationId], + OU.[Type], + OU.[Status], + OU.[Permissions] + FROM + [dbo].[OrganizationUserView] OU + WHERE + OU.[Status] = 0 + AND OU.[Email] = @UserEmail + AND @UserEmail IS NOT NULL + ), + Providers AS + ( + SELECT + OrganizationId + FROM + [dbo].[UserProviderAccessView] + WHERE + UserId = @UserId ) -END \ No newline at end of file + SELECT + OU.[Id] AS [OrganizationUserId], + P.[OrganizationId], + P.[Type] AS [PolicyType], + P.[Enabled] AS [PolicyEnabled], + P.[Data] AS [PolicyData], + OU.[Type] AS [OrganizationUserType], + OU.[Status] AS [OrganizationUserStatus], + OU.[Permissions] AS [OrganizationUserPermissionsData], + CASE WHEN PR.[OrganizationId] IS NULL THEN 0 ELSE 1 END AS [IsProvider] + FROM + [dbo].[PolicyView] P + INNER JOIN + OrgUsers OU ON P.[OrganizationId] = OU.[OrganizationId] + LEFT JOIN + Providers PR ON PR.[OrganizationId] = OU.[OrganizationId] + WHERE + P.[Type] = @PolicyType +END diff --git a/src/Sql/dbo/Stored Procedures/Organization_Create.sql b/src/Sql/dbo/Stored Procedures/Organization_Create.sql index 295ebb51a8..e37fa0e940 100644 --- a/src/Sql/dbo/Stored Procedures/Organization_Create.sql +++ b/src/Sql/dbo/Stored Procedures/Organization_Create.sql @@ -58,7 +58,8 @@ CREATE PROCEDURE [dbo].[Organization_Create] @LimitItemDeletion BIT = 0, @UseOrganizationDomains BIT = 0, @UseAdminSponsoredFamilies BIT = 0, - @SyncSeats BIT = 0 + @SyncSeats BIT = 0, + @UseAutomaticUserConfirmation BIT = 0 AS BEGIN SET NOCOUNT ON @@ -124,69 +125,71 @@ BEGIN [LimitItemDeletion], [UseOrganizationDomains], [UseAdminSponsoredFamilies], - [SyncSeats] + [SyncSeats], + [UseAutomaticUserConfirmation] ) VALUES - ( - @Id, - @Identifier, - @Name, - @BusinessName, - @BusinessAddress1, - @BusinessAddress2, - @BusinessAddress3, - @BusinessCountry, - @BusinessTaxNumber, - @BillingEmail, - @Plan, - @PlanType, - @Seats, - @MaxCollections, - @UsePolicies, - @UseSso, - @UseGroups, - @UseDirectory, - @UseEvents, - @UseTotp, - @Use2fa, - @UseApi, - @UseResetPassword, - @SelfHost, - @UsersGetPremium, - @Storage, - @MaxStorageGb, - @Gateway, - @GatewayCustomerId, - @GatewaySubscriptionId, - @ReferenceData, - @Enabled, - @LicenseKey, - @PublicKey, - @PrivateKey, - @TwoFactorProviders, - @ExpirationDate, - @CreationDate, - @RevisionDate, - @OwnersNotifiedOfAutoscaling, - @MaxAutoscaleSeats, - @UseKeyConnector, - @UseScim, - @UseCustomPermissions, - @UseSecretsManager, - @Status, - @UsePasswordManager, - @SmSeats, - @SmServiceAccounts, - @MaxAutoscaleSmSeats, - @MaxAutoscaleSmServiceAccounts, - @SecretsManagerBeta, - @LimitCollectionCreation, - @LimitCollectionDeletion, - @AllowAdminAccessToAllCollectionItems, - @UseRiskInsights, - @LimitItemDeletion, - @UseOrganizationDomains, - @UseAdminSponsoredFamilies, - @SyncSeats - ) + ( + @Id, + @Identifier, + @Name, + @BusinessName, + @BusinessAddress1, + @BusinessAddress2, + @BusinessAddress3, + @BusinessCountry, + @BusinessTaxNumber, + @BillingEmail, + @Plan, + @PlanType, + @Seats, + @MaxCollections, + @UsePolicies, + @UseSso, + @UseGroups, + @UseDirectory, + @UseEvents, + @UseTotp, + @Use2fa, + @UseApi, + @UseResetPassword, + @SelfHost, + @UsersGetPremium, + @Storage, + @MaxStorageGb, + @Gateway, + @GatewayCustomerId, + @GatewaySubscriptionId, + @ReferenceData, + @Enabled, + @LicenseKey, + @PublicKey, + @PrivateKey, + @TwoFactorProviders, + @ExpirationDate, + @CreationDate, + @RevisionDate, + @OwnersNotifiedOfAutoscaling, + @MaxAutoscaleSeats, + @UseKeyConnector, + @UseScim, + @UseCustomPermissions, + @UseSecretsManager, + @Status, + @UsePasswordManager, + @SmSeats, + @SmServiceAccounts, + @MaxAutoscaleSmSeats, + @MaxAutoscaleSmServiceAccounts, + @SecretsManagerBeta, + @LimitCollectionCreation, + @LimitCollectionDeletion, + @AllowAdminAccessToAllCollectionItems, + @UseRiskInsights, + @LimitItemDeletion, + @UseOrganizationDomains, + @UseAdminSponsoredFamilies, + @SyncSeats, + @UseAutomaticUserConfirmation + ); END diff --git a/src/Sql/dbo/Stored Procedures/Organization_ReadAbilities.sql b/src/Sql/dbo/Stored Procedures/Organization_ReadAbilities.sql index 6a8ed9e0d0..59226e59db 100644 --- a/src/Sql/dbo/Stored Procedures/Organization_ReadAbilities.sql +++ b/src/Sql/dbo/Stored Procedures/Organization_ReadAbilities.sql @@ -27,7 +27,8 @@ BEGIN [UseRiskInsights], [LimitItemDeletion], [UseOrganizationDomains], - [UseAdminSponsoredFamilies] + [UseAdminSponsoredFamilies], + [UseAutomaticUserConfirmation] FROM [dbo].[Organization] END diff --git a/src/Sql/dbo/Stored Procedures/Organization_Update.sql b/src/Sql/dbo/Stored Procedures/Organization_Update.sql index d60852bab6..4807c7bb50 100644 --- a/src/Sql/dbo/Stored Procedures/Organization_Update.sql +++ b/src/Sql/dbo/Stored Procedures/Organization_Update.sql @@ -58,7 +58,8 @@ CREATE PROCEDURE [dbo].[Organization_Update] @LimitItemDeletion BIT = 0, @UseOrganizationDomains BIT = 0, @UseAdminSponsoredFamilies BIT = 0, - @SyncSeats BIT = 0 + @SyncSeats BIT = 0, + @UseAutomaticUserConfirmation BIT = 0 AS BEGIN SET NOCOUNT ON @@ -124,7 +125,8 @@ BEGIN [LimitItemDeletion] = @LimitItemDeletion, [UseOrganizationDomains] = @UseOrganizationDomains, [UseAdminSponsoredFamilies] = @UseAdminSponsoredFamilies, - [SyncSeats] = @SyncSeats + [SyncSeats] = @SyncSeats, + [UseAutomaticUserConfirmation] = @UseAutomaticUserConfirmation WHERE - [Id] = @Id + [Id] = @Id; END diff --git a/src/Sql/dbo/Stored Procedures/User_Create.sql b/src/Sql/dbo/Stored Procedures/User_Create.sql index 60d9b5eb32..2573bf1a0a 100644 --- a/src/Sql/dbo/Stored Procedures/User_Create.sql +++ b/src/Sql/dbo/Stored Procedures/User_Create.sql @@ -41,7 +41,10 @@ @LastKdfChangeDate DATETIME2(7) = NULL, @LastKeyRotationDate DATETIME2(7) = NULL, @LastEmailChangeDate DATETIME2(7) = NULL, - @VerifyDevices BIT = 1 + @VerifyDevices BIT = 1, + @SecurityState VARCHAR(MAX) = NULL, + @SecurityVersion INT = NULL, + @SignedPublicKey VARCHAR(MAX) = NULL AS BEGIN SET NOCOUNT ON @@ -90,7 +93,10 @@ BEGIN [LastKdfChangeDate], [LastKeyRotationDate], [LastEmailChangeDate], - [VerifyDevices] + [VerifyDevices], + [SecurityState], + [SecurityVersion], + [SignedPublicKey] ) VALUES ( @@ -136,6 +142,9 @@ BEGIN @LastKdfChangeDate, @LastKeyRotationDate, @LastEmailChangeDate, - @VerifyDevices + @VerifyDevices, + @SecurityState, + @SecurityVersion, + @SignedPublicKey ) END diff --git a/src/Sql/dbo/Stored Procedures/User_DeleteById.sql b/src/Sql/dbo/Stored Procedures/User_DeleteById.sql index 0608982e37..6377166e17 100644 --- a/src/Sql/dbo/Stored Procedures/User_DeleteById.sql +++ b/src/Sql/dbo/Stored Procedures/User_DeleteById.sql @@ -52,6 +52,16 @@ BEGIN WHERE [UserId] = @Id + -- Migrate DefaultUserCollection to SharedCollection before deleting CollectionUser records + DECLARE @OrgUserIds [dbo].[GuidIdArray] + INSERT INTO @OrgUserIds (Id) + SELECT [Id] FROM [dbo].[OrganizationUser] WHERE [UserId] = @Id + + IF EXISTS (SELECT 1 FROM @OrgUserIds) + BEGIN + EXEC [dbo].[OrganizationUser_MigrateDefaultCollection] @OrgUserIds + END + -- Delete collection users DELETE CU diff --git a/src/Sql/dbo/Stored Procedures/User_DeleteByIds.sql b/src/Sql/dbo/Stored Procedures/User_DeleteByIds.sql index 97ab955f83..cdf3dd7d3a 100644 --- a/src/Sql/dbo/Stored Procedures/User_DeleteByIds.sql +++ b/src/Sql/dbo/Stored Procedures/User_DeleteByIds.sql @@ -66,6 +66,16 @@ BEGIN WHERE [UserId] IN (SELECT * FROM @ParsedIds) + -- Migrate DefaultUserCollection to SharedCollection before deleting CollectionUser records + DECLARE @OrgUserIds [dbo].[GuidIdArray] + INSERT INTO @OrgUserIds (Id) + SELECT [Id] FROM [dbo].[OrganizationUser] WHERE [UserId] IN (SELECT * FROM @ParsedIds) + + IF EXISTS (SELECT 1 FROM @OrgUserIds) + BEGIN + EXEC [dbo].[OrganizationUser_MigrateDefaultCollection] @OrgUserIds + END + -- Delete collection users DELETE CU diff --git a/src/Sql/dbo/Stored Procedures/User_Update.sql b/src/Sql/dbo/Stored Procedures/User_Update.sql index 15d04d72f6..5097bc538e 100644 --- a/src/Sql/dbo/Stored Procedures/User_Update.sql +++ b/src/Sql/dbo/Stored Procedures/User_Update.sql @@ -41,7 +41,10 @@ @LastKdfChangeDate DATETIME2(7) = NULL, @LastKeyRotationDate DATETIME2(7) = NULL, @LastEmailChangeDate DATETIME2(7) = NULL, - @VerifyDevices BIT = 1 + @VerifyDevices BIT = 1, + @SecurityState VARCHAR(MAX) = NULL, + @SecurityVersion INT = NULL, + @SignedPublicKey VARCHAR(MAX) = NULL AS BEGIN SET NOCOUNT ON @@ -90,7 +93,10 @@ BEGIN [LastKdfChangeDate] = @LastKdfChangeDate, [LastKeyRotationDate] = @LastKeyRotationDate, [LastEmailChangeDate] = @LastEmailChangeDate, - [VerifyDevices] = @VerifyDevices + [VerifyDevices] = @VerifyDevices, + [SecurityState] = @SecurityState, + [SecurityVersion] = @SecurityVersion, + [SignedPublicKey] = @SignedPublicKey WHERE [Id] = @Id END diff --git a/src/Sql/dbo/Tables/Event.sql b/src/Sql/dbo/Tables/Event.sql index 6dfb4392a0..ea0dda5661 100644 --- a/src/Sql/dbo/Tables/Event.sql +++ b/src/Sql/dbo/Tables/Event.sql @@ -21,11 +21,12 @@ [SecretId] UNIQUEIDENTIFIER NULL, [ServiceAccountId] UNIQUEIDENTIFIER NULL, [ProjectId] UNIQUEIDENTIFIER NULL, + [GrantedServiceAccountId] UNIQUEIDENTIFIER NULL, CONSTRAINT [PK_Event] PRIMARY KEY CLUSTERED ([Id] ASC) ); GO CREATE NONCLUSTERED INDEX [IX_Event_DateOrganizationIdUserId] - ON [dbo].[Event]([Date] DESC, [OrganizationId] ASC, [ActingUserId] ASC, [CipherId] ASC); + ON [dbo].[Event]([Date] DESC, [OrganizationId] ASC, [ActingUserId] ASC, [CipherId] ASC) INCLUDE ([ServiceAccountId], [GrantedServiceAccountId]); diff --git a/src/Sql/dbo/Tables/Organization.sql b/src/Sql/dbo/Tables/Organization.sql index 897abef1cf..e1ad6863af 100644 --- a/src/Sql/dbo/Tables/Organization.sql +++ b/src/Sql/dbo/Tables/Organization.sql @@ -59,6 +59,7 @@ CREATE TABLE [dbo].[Organization] ( [UseOrganizationDomains] BIT NOT NULL CONSTRAINT [DF_Organization_UseOrganizationDomains] DEFAULT (0), [UseAdminSponsoredFamilies] BIT NOT NULL CONSTRAINT [DF_Organization_UseAdminSponsoredFamilies] DEFAULT (0), [SyncSeats] BIT NOT NULL CONSTRAINT [DF_Organization_SyncSeats] DEFAULT (0), + [UseAutomaticUserConfirmation] BIT NOT NULL CONSTRAINT [DF_Organization_UseAutomaticUserConfirmation] DEFAULT (0), CONSTRAINT [PK_Organization] PRIMARY KEY CLUSTERED ([Id] ASC) ); diff --git a/src/Sql/dbo/Tables/User.sql b/src/Sql/dbo/Tables/User.sql index 239ee67f11..dc772ff1a7 100644 --- a/src/Sql/dbo/Tables/User.sql +++ b/src/Sql/dbo/Tables/User.sql @@ -42,6 +42,9 @@ [LastKeyRotationDate] DATETIME2 (7) NULL, [LastEmailChangeDate] DATETIME2 (7) NULL, [VerifyDevices] BIT DEFAULT ((1)) NOT NULL, + [SecurityState] VARCHAR (MAX) NULL, + [SecurityVersion] INT NULL, + [SignedPublicKey] VARCHAR (MAX) NULL, CONSTRAINT [PK_User] PRIMARY KEY CLUSTERED ([Id] ASC) ); diff --git a/src/Sql/dbo/Vault/Stored Procedures/Cipher/Cipher_CreateWithCollections.sql b/src/Sql/dbo/Vault/Stored Procedures/Cipher/Cipher_CreateWithCollections.sql index ac7be1bbae..c6816a1226 100644 --- a/src/Sql/dbo/Vault/Stored Procedures/Cipher/Cipher_CreateWithCollections.sql +++ b/src/Sql/dbo/Vault/Stored Procedures/Cipher/Cipher_CreateWithCollections.sql @@ -23,4 +23,10 @@ BEGIN DECLARE @UpdateCollectionsSuccess INT EXEC @UpdateCollectionsSuccess = [dbo].[Cipher_UpdateCollections] @Id, @UserId, @OrganizationId, @CollectionIds + + -- Bump the account revision date AFTER collections are assigned. + IF @UpdateCollectionsSuccess = 0 + BEGIN + EXEC [dbo].[User_BumpAccountRevisionDateByCipherId] @Id, @OrganizationId + END END diff --git a/src/Sql/dbo/Vault/Stored Procedures/SecurityTask/SecurityTask_MarkCompleteByCipherIds.sql b/src/Sql/dbo/Vault/Stored Procedures/SecurityTask/SecurityTask_MarkCompleteByCipherIds.sql new file mode 100644 index 0000000000..8e00d06e43 --- /dev/null +++ b/src/Sql/dbo/Vault/Stored Procedures/SecurityTask/SecurityTask_MarkCompleteByCipherIds.sql @@ -0,0 +1,15 @@ +CREATE PROCEDURE [dbo].[SecurityTask_MarkCompleteByCipherIds] + @CipherIds AS [dbo].[GuidIdArray] READONLY +AS +BEGIN + SET NOCOUNT ON + + UPDATE + [dbo].[SecurityTask] + SET + [Status] = 1, -- completed + [RevisionDate] = SYSUTCDATETIME() + WHERE + [CipherId] IN (SELECT [Id] FROM @CipherIds) + AND [Status] <> 1 -- Not already completed +END diff --git a/src/Sql/dbo/Views/OrganizationUserOrganizationDetailsView.sql b/src/Sql/dbo/Views/OrganizationUserOrganizationDetailsView.sql index ba7e765569..a7e1db6e81 100644 --- a/src/Sql/dbo/Views/OrganizationUserOrganizationDetailsView.sql +++ b/src/Sql/dbo/Views/OrganizationUserOrganizationDetailsView.sql @@ -54,7 +54,8 @@ SELECT O.[LimitItemDeletion], O.[UseAdminSponsoredFamilies], O.[UseOrganizationDomains], - OS.[IsAdminInitiated] + OS.[IsAdminInitiated], + O.[UseAutomaticUserConfirmation] FROM [dbo].[OrganizationUser] OU LEFT JOIN diff --git a/src/Sql/dbo/Views/ProviderUserProviderOrganizationDetailsView.sql b/src/Sql/dbo/Views/ProviderUserProviderOrganizationDetailsView.sql index bd2485b411..42e877ab15 100644 --- a/src/Sql/dbo/Views/ProviderUserProviderOrganizationDetailsView.sql +++ b/src/Sql/dbo/Views/ProviderUserProviderOrganizationDetailsView.sql @@ -16,6 +16,8 @@ SELECT O.[Use2fa], O.[UseApi], O.[UseResetPassword], + O.[UseSecretsManager], + O.[UsePasswordManager], O.[SelfHost], O.[UsersGetPremium], O.[UseCustomPermissions], @@ -39,7 +41,10 @@ SELECT O.[UseAdminSponsoredFamilies], P.[Type] ProviderType, O.[LimitItemDeletion], - O.[UseOrganizationDomains] + O.[UseOrganizationDomains], + O.[UseAutomaticUserConfirmation], + SS.[Enabled] SsoEnabled, + SS.[Data] SsoConfig FROM [dbo].[ProviderUser] PU INNER JOIN @@ -48,3 +53,5 @@ INNER JOIN [dbo].[Organization] O ON O.[Id] = PO.[OrganizationId] INNER JOIN [dbo].[Provider] P ON P.[Id] = PU.[ProviderId] +LEFT JOIN + [dbo].[SsoConfig] SS ON SS.[OrganizationId] = O.[Id] diff --git a/src/Sql/dbo/Views/UserProviderAccessView.sql b/src/Sql/dbo/Views/UserProviderAccessView.sql new file mode 100644 index 0000000000..dedc380311 --- /dev/null +++ b/src/Sql/dbo/Views/UserProviderAccessView.sql @@ -0,0 +1,9 @@ +CREATE VIEW [dbo].[UserProviderAccessView] +AS +SELECT DISTINCT + PU.[UserId], + PO.[OrganizationId] +FROM + [dbo].[ProviderUserView] PU +INNER JOIN + [dbo].[ProviderOrganizationView] PO ON PO.[ProviderId] = PU.[ProviderId] diff --git a/test/Admin.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs b/test/Admin.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs index 44ad5088cd..84ef5c7f3d 100644 --- a/test/Admin.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs +++ b/test/Admin.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs @@ -1,5 +1,7 @@ using Bit.Admin.AdminConsole.Controllers; using Bit.Admin.AdminConsole.Models; +using Bit.Admin.Enums; +using Bit.Admin.Services; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.AdminConsole.Enums.Provider; @@ -276,5 +278,40 @@ public class OrganizationsControllerTests await providerBillingService.Received(1).ScaleSeats(provider, update.PlanType!.Value, update.Seats!.Value - organization.Seats.Value + organization.Seats.Value); } + [BitAutoData] + [SutProviderCustomize] + [Theory] + public async Task Edit_UseAutomaticUserConfirmation_FullUpdate_SavesFeatureCorrectly( + Organization organization, + SutProvider sutProvider) + { + // Arrange + var update = new OrganizationEditModel + { + PlanType = PlanType.TeamsMonthly, + UseAutomaticUserConfirmation = true + }; + + organization.UseAutomaticUserConfirmation = false; + + sutProvider.GetDependency() + .UserHasPermission(Permission.Org_Plan_Edit) + .Returns(true); + + var organizationRepository = sutProvider.GetDependency(); + + organizationRepository.GetByIdAsync(organization.Id).Returns(organization); + + // Act + _ = await sutProvider.Sut.Edit(organization.Id, update); + + // Assert + await organizationRepository.Received(1).ReplaceAsync(Arg.Is(o => o.Id == organization.Id + && o.UseAutomaticUserConfirmation == true)); + + // Annul + await organizationRepository.DeleteAsync(organization); + } + #endregion } diff --git a/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUsersControllerPutResetPasswordTests.cs b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUsersControllerPutResetPasswordTests.cs new file mode 100644 index 0000000000..cf842d1568 --- /dev/null +++ b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUsersControllerPutResetPasswordTests.cs @@ -0,0 +1,197 @@ +using System.Net; +using Bit.Api.AdminConsole.Authorization; +using Bit.Api.IntegrationTest.Factories; +using Bit.Api.IntegrationTest.Helpers; +using Bit.Api.Models.Request.Organizations; +using Bit.Core; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Entities.Provider; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Enums; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Models.Api; +using Bit.Core.Repositories; +using Bit.Core.Services; +using NSubstitute; +using Xunit; + +namespace Bit.Api.IntegrationTest.AdminConsole.Controllers; + +public class OrganizationUsersControllerPutResetPasswordTests : IClassFixture, IAsyncLifetime +{ + private readonly HttpClient _client; + private readonly ApiApplicationFactory _factory; + private readonly LoginHelper _loginHelper; + + private Organization _organization = null!; + private string _ownerEmail = null!; + + public OrganizationUsersControllerPutResetPasswordTests(ApiApplicationFactory apiFactory) + { + _factory = apiFactory; + _factory.SubstituteService(featureService => + { + featureService + .IsEnabled(FeatureFlagKeys.AccountRecoveryCommand) + .Returns(true); + }); + _client = _factory.CreateClient(); + _loginHelper = new LoginHelper(_factory, _client); + } + + public async Task InitializeAsync() + { + _ownerEmail = $"reset-password-test-{Guid.NewGuid()}@example.com"; + await _factory.LoginWithNewAccount(_ownerEmail); + + (_organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, plan: PlanType.EnterpriseAnnually2023, + ownerEmail: _ownerEmail, passwordManagerSeats: 5, paymentMethod: PaymentMethodType.Card); + + // Enable reset password and policies for the organization + var organizationRepository = _factory.GetService(); + _organization.UseResetPassword = true; + _organization.UsePolicies = true; + await organizationRepository.ReplaceAsync(_organization); + + // Enable the ResetPassword policy + var policyRepository = _factory.GetService(); + await policyRepository.CreateAsync(new Policy + { + OrganizationId = _organization.Id, + Type = PolicyType.ResetPassword, + Enabled = true, + Data = "{}" + }); + } + + public Task DisposeAsync() + { + _client.Dispose(); + return Task.CompletedTask; + } + + /// + /// Helper method to set the ResetPasswordKey on an organization user, which is required for account recovery + /// + private async Task SetResetPasswordKeyAsync(OrganizationUser orgUser) + { + var organizationUserRepository = _factory.GetService(); + orgUser.ResetPasswordKey = "encrypted-reset-password-key"; + await organizationUserRepository.ReplaceAsync(orgUser); + } + + [Fact] + public async Task PutResetPassword_AsHigherRole_CanRecoverLowerRole() + { + // Arrange + var (ownerEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Owner); + await _loginHelper.LoginAsync(ownerEmail); + + var (_, targetOrgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync( + _factory, _organization.Id, OrganizationUserType.User); + await SetResetPasswordKeyAsync(targetOrgUser); + + var resetPasswordRequest = new OrganizationUserResetPasswordRequestModel + { + NewMasterPasswordHash = "new-master-password-hash", + Key = "encrypted-recovery-key" + }; + + // Act + var response = await _client.PutAsJsonAsync( + $"organizations/{_organization.Id}/users/{targetOrgUser.Id}/reset-password", + resetPasswordRequest); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + [Fact] + public async Task PutResetPassword_AsLowerRole_CannotRecoverHigherRole() + { + // Arrange + var (adminEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Admin); + await _loginHelper.LoginAsync(adminEmail); + + var (_, targetOwnerOrgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync( + _factory, _organization.Id, OrganizationUserType.Owner); + await SetResetPasswordKeyAsync(targetOwnerOrgUser); + + var resetPasswordRequest = new OrganizationUserResetPasswordRequestModel + { + NewMasterPasswordHash = "new-master-password-hash", + Key = "encrypted-recovery-key" + }; + + // Act + var response = await _client.PutAsJsonAsync( + $"organizations/{_organization.Id}/users/{targetOwnerOrgUser.Id}/reset-password", + resetPasswordRequest); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var model = await response.Content.ReadFromJsonAsync(); + Assert.Contains(RecoverAccountAuthorizationHandler.FailureReason, model.Message); + } + + [Fact] + public async Task PutResetPassword_CannotRecoverProviderAccount() + { + // Arrange - Create owner who will try to recover the provider account + var (ownerEmail, _) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, + _organization.Id, OrganizationUserType.Owner); + await _loginHelper.LoginAsync(ownerEmail); + + // Create a user who is also a provider user + var (targetUserEmail, targetOrgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync( + _factory, _organization.Id, OrganizationUserType.User); + await SetResetPasswordKeyAsync(targetOrgUser); + + // Add the target user as a provider user to a different provider + var providerRepository = _factory.GetService(); + var providerUserRepository = _factory.GetService(); + var userRepository = _factory.GetService(); + + var provider = await providerRepository.CreateAsync(new Provider + { + Name = "Test Provider", + BusinessName = "Test Provider Business", + BillingEmail = "provider@example.com", + Type = ProviderType.Msp, + Status = ProviderStatusType.Created, + Enabled = true + }); + + var targetUser = await userRepository.GetByEmailAsync(targetUserEmail); + Assert.NotNull(targetUser); + + await providerUserRepository.CreateAsync(new ProviderUser + { + ProviderId = provider.Id, + UserId = targetUser.Id, + Status = ProviderUserStatusType.Confirmed, + Type = ProviderUserType.ProviderAdmin + }); + + var resetPasswordRequest = new OrganizationUserResetPasswordRequestModel + { + NewMasterPasswordHash = "new-master-password-hash", + Key = "encrypted-recovery-key" + }; + + // Act + var response = await _client.PutAsJsonAsync( + $"organizations/{_organization.Id}/users/{targetOrgUser.Id}/reset-password", + resetPasswordRequest); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var model = await response.Content.ReadFromJsonAsync(); + Assert.Equal(RecoverAccountAuthorizationHandler.ProviderFailureReason, model.Message); + } +} diff --git a/test/Api.IntegrationTest/AdminConsole/Controllers/PoliciesControllerTests.cs b/test/Api.IntegrationTest/AdminConsole/Controllers/PoliciesControllerTests.cs index 1efc2f843d..e4098ce9a9 100644 --- a/test/Api.IntegrationTest/AdminConsole/Controllers/PoliciesControllerTests.cs +++ b/test/Api.IntegrationTest/AdminConsole/Controllers/PoliciesControllerTests.cs @@ -67,7 +67,6 @@ public class PoliciesControllerTests : IClassFixture, IAs { Policy = new PolicyRequestModel { - Type = policyType, Enabled = true, }, Metadata = new Dictionary @@ -148,7 +147,6 @@ public class PoliciesControllerTests : IClassFixture, IAs { Policy = new PolicyRequestModel { - Type = policyType, Enabled = true, Data = new Dictionary { @@ -211,4 +209,192 @@ public class PoliciesControllerTests : IClassFixture, IAs } } + [Fact] + public async Task Put_MasterPasswordPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.MasterPassword; + var request = new PolicyRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "minLength", "not a number" }, // Wrong type - should be int + { "requireUpper", true } + } + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains("minLength", content); // Verify field name is in error message + } + + [Fact] + public async Task Put_SendOptionsPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.SendOptions; + var request = new PolicyRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "disableHideEmail", "not a boolean" } // Wrong type - should be bool + } + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task Put_ResetPasswordPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.ResetPassword; + var request = new PolicyRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "autoEnrollEnabled", 123 } // Wrong type - should be bool + } + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task PutVNext_MasterPasswordPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.MasterPassword; + var request = new SavePolicyRequest + { + Policy = new PolicyRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "minComplexity", "not a number" }, // Wrong type - should be int + { "minLength", 12 } + } + } + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}/vnext", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains("minComplexity", content); // Verify field name is in error message + } + + [Fact] + public async Task PutVNext_SendOptionsPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.SendOptions; + var request = new SavePolicyRequest + { + Policy = new PolicyRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "disableHideEmail", "not a boolean" } // Wrong type - should be bool + } + } + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}/vnext", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task PutVNext_ResetPasswordPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.ResetPassword; + var request = new SavePolicyRequest + { + Policy = new PolicyRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "autoEnrollEnabled", 123 } // Wrong type - should be bool + } + } + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}/vnext", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task Put_PolicyWithNullData_Success() + { + // Arrange + var policyType = PolicyType.SingleOrg; + var request = new PolicyRequestModel + { + Enabled = true, + Data = null + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + [Fact] + public async Task PutVNext_PolicyWithNullData_Success() + { + // Arrange + var policyType = PolicyType.TwoFactorAuthentication; + var request = new SavePolicyRequest + { + Policy = new PolicyRequestModel + { + Enabled = true, + Data = null + }, + Metadata = null + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}/vnext", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } } diff --git a/test/Api.IntegrationTest/AdminConsole/Public/Controllers/MembersControllerTests.cs b/test/Api.IntegrationTest/AdminConsole/Public/Controllers/MembersControllerTests.cs index 11c60ad57c..2eeba5d47e 100644 --- a/test/Api.IntegrationTest/AdminConsole/Public/Controllers/MembersControllerTests.cs +++ b/test/Api.IntegrationTest/AdminConsole/Public/Controllers/MembersControllerTests.cs @@ -64,6 +64,17 @@ public class MembersControllerTests : IClassFixture, IAsy var (userEmail4, orgUser4) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync(_factory, _organization.Id, OrganizationUserType.Admin); + var collection1 = await OrganizationTestHelpers.CreateCollectionAsync(_factory, _organization.Id, "Test Collection 1", users: + [ + new CollectionAccessSelection { Id = orgUser1.Id, ReadOnly = false, HidePasswords = false, Manage = true }, + new CollectionAccessSelection { Id = orgUser3.Id, ReadOnly = true, HidePasswords = false, Manage = false } + ]); + + var collection2 = await OrganizationTestHelpers.CreateCollectionAsync(_factory, _organization.Id, "Test Collection 2", users: + [ + new CollectionAccessSelection { Id = orgUser1.Id, ReadOnly = false, HidePasswords = true, Manage = false } + ]); + var response = await _client.GetAsync($"/public/members"); Assert.Equal(HttpStatusCode.OK, response.StatusCode); var result = await response.Content.ReadFromJsonAsync>(); @@ -71,23 +82,47 @@ public class MembersControllerTests : IClassFixture, IAsy Assert.Equal(5, result.Data.Count()); // The owner - Assert.NotNull(result.Data.SingleOrDefault(m => - m.Email == _ownerEmail && m.Type == OrganizationUserType.Owner)); + var ownerResult = result.Data.SingleOrDefault(m => m.Email == _ownerEmail && m.Type == OrganizationUserType.Owner); + Assert.NotNull(ownerResult); + Assert.Empty(ownerResult.Collections); - // The custom user + // The custom user with collections var user1Result = result.Data.Single(m => m.Email == userEmail1); Assert.Equal(OrganizationUserType.Custom, user1Result.Type); AssertHelper.AssertPropertyEqual( new PermissionsModel { AccessImportExport = true, ManagePolicies = true, AccessReports = true }, user1Result.Permissions); + // Verify collections + Assert.NotNull(user1Result.Collections); + Assert.Equal(2, user1Result.Collections.Count()); + var user1Collection1 = user1Result.Collections.Single(c => c.Id == collection1.Id); + Assert.False(user1Collection1.ReadOnly); + Assert.False(user1Collection1.HidePasswords); + Assert.True(user1Collection1.Manage); + var user1Collection2 = user1Result.Collections.Single(c => c.Id == collection2.Id); + Assert.False(user1Collection2.ReadOnly); + Assert.True(user1Collection2.HidePasswords); + Assert.False(user1Collection2.Manage); - // Everyone else - Assert.NotNull(result.Data.SingleOrDefault(m => - m.Email == userEmail2 && m.Type == OrganizationUserType.Owner)); - Assert.NotNull(result.Data.SingleOrDefault(m => - m.Email == userEmail3 && m.Type == OrganizationUserType.User)); - Assert.NotNull(result.Data.SingleOrDefault(m => - m.Email == userEmail4 && m.Type == OrganizationUserType.Admin)); + // The other owner + var user2Result = result.Data.SingleOrDefault(m => m.Email == userEmail2 && m.Type == OrganizationUserType.Owner); + Assert.NotNull(user2Result); + Assert.Empty(user2Result.Collections); + + // The user with one collection + var user3Result = result.Data.SingleOrDefault(m => m.Email == userEmail3 && m.Type == OrganizationUserType.User); + Assert.NotNull(user3Result); + Assert.NotNull(user3Result.Collections); + Assert.Single(user3Result.Collections); + var user3Collection1 = user3Result.Collections.Single(c => c.Id == collection1.Id); + Assert.True(user3Collection1.ReadOnly); + Assert.False(user3Collection1.HidePasswords); + Assert.False(user3Collection1.Manage); + + // The admin with no collections + var user4Result = result.Data.SingleOrDefault(m => m.Email == userEmail4 && m.Type == OrganizationUserType.Admin); + Assert.NotNull(user4Result); + Assert.Empty(user4Result.Collections); } [Fact] diff --git a/test/Api.IntegrationTest/AdminConsole/Public/Controllers/PoliciesControllerTests.cs b/test/Api.IntegrationTest/AdminConsole/Public/Controllers/PoliciesControllerTests.cs index f034426f98..0b5ab660b9 100644 --- a/test/Api.IntegrationTest/AdminConsole/Public/Controllers/PoliciesControllerTests.cs +++ b/test/Api.IntegrationTest/AdminConsole/Public/Controllers/PoliciesControllerTests.cs @@ -160,4 +160,86 @@ public class PoliciesControllerTests : IClassFixture, IAs Assert.Equal(15, data.MinLength); Assert.Equal(true, data.RequireUpper); } + + [Fact] + public async Task Put_MasterPasswordPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.MasterPassword; + var request = new PolicyUpdateRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "minLength", "not a number" }, // Wrong type - should be int + { "requireUpper", true } + } + }; + + // Act + var response = await _client.PutAsync($"/public/policies/{policyType}", JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task Put_SendOptionsPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.SendOptions; + var request = new PolicyUpdateRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "disableHideEmail", "not a boolean" } // Wrong type - should be bool + } + }; + + // Act + var response = await _client.PutAsync($"/public/policies/{policyType}", JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task Put_ResetPasswordPolicy_InvalidDataType_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.ResetPassword; + var request = new PolicyUpdateRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "autoEnrollEnabled", 123 } // Wrong type - should be bool + } + }; + + // Act + var response = await _client.PutAsync($"/public/policies/{policyType}", JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task Put_PolicyWithNullData_Success() + { + // Arrange + var policyType = PolicyType.DisableSend; + var request = new PolicyUpdateRequestModel + { + Enabled = true, + Data = null + }; + + // Act + var response = await _client.PutAsync($"/public/policies/{policyType}", JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } } diff --git a/test/Api.IntegrationTest/Controllers/AccountsControllerTest.cs b/test/Api.IntegrationTest/Controllers/AccountsControllerTest.cs index 4e5a6850e7..09ec5b010f 100644 --- a/test/Api.IntegrationTest/Controllers/AccountsControllerTest.cs +++ b/test/Api.IntegrationTest/Controllers/AccountsControllerTest.cs @@ -1,31 +1,81 @@ -using System.Net.Http.Headers; +using System.Net; +using Bit.Api.Auth.Models.Request.Accounts; using Bit.Api.IntegrationTest.Factories; +using Bit.Api.IntegrationTest.Helpers; +using Bit.Api.KeyManagement.Models.Requests; using Bit.Api.Models.Response; +using Bit.Core; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Platform.Push; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.AspNetCore.Identity; +using NSubstitute; using Xunit; namespace Bit.Api.IntegrationTest.Controllers; -public class AccountsControllerTest : IClassFixture +public class AccountsControllerTest : IClassFixture, IAsyncLifetime { - private readonly ApiApplicationFactory _factory; + private static readonly string _masterKeyWrappedUserKey = + "2.AOs41Hd8OQiCPXjyJKCiDA==|O6OHgt2U2hJGBSNGnimJmg==|iD33s8B69C8JhYYhSa4V1tArjvLr8eEaGqOV7BRo5Jk="; - public AccountsControllerTest(ApiApplicationFactory factory) => _factory = factory; + private static readonly string _masterPasswordHash = "master_password_hash"; + private static readonly string _newMasterPasswordHash = "new_master_password_hash"; + + private static readonly KdfRequestModel _defaultKdfRequest = + new() { KdfType = KdfType.PBKDF2_SHA256, Iterations = 600_000 }; + + private readonly HttpClient _client; + private readonly ApiApplicationFactory _factory; + private readonly LoginHelper _loginHelper; + private readonly IUserRepository _userRepository; + private readonly IPushNotificationService _pushNotificationService; + private readonly IFeatureService _featureService; + private readonly IPasswordHasher _passwordHasher; + + private string _ownerEmail = null!; + + public AccountsControllerTest(ApiApplicationFactory factory) + { + _factory = factory; + _factory.SubstituteService(_ => { }); + _factory.SubstituteService(_ => { }); + _client = factory.CreateClient(); + _loginHelper = new LoginHelper(_factory, _client); + _userRepository = _factory.GetService(); + _pushNotificationService = _factory.GetService(); + _featureService = _factory.GetService(); + _passwordHasher = _factory.GetService>(); + } + + public async Task InitializeAsync() + { + _ownerEmail = $"integration-test{Guid.NewGuid()}@bitwarden.com"; + await _factory.LoginWithNewAccount(_ownerEmail); + } + + public Task DisposeAsync() + { + _client.Dispose(); + return Task.CompletedTask; + } [Fact] public async Task GetAccountsProfile_success() { - var tokens = await _factory.LoginWithNewAccount(); - var client = _factory.CreateClient(); + await _loginHelper.LoginAsync(_ownerEmail); using var message = new HttpRequestMessage(HttpMethod.Get, "/accounts/profile"); - message.Headers.Authorization = new AuthenticationHeaderValue("Bearer", tokens.Token); - var response = await client.SendAsync(message); + var response = await _client.SendAsync(message); response.EnsureSuccessStatusCode(); var content = await response.Content.ReadFromJsonAsync(); Assert.NotNull(content); - Assert.Equal("integration-test@bitwarden.com", content.Email); + Assert.Equal(_ownerEmail, content.Email); Assert.NotNull(content.Name); Assert.True(content.EmailVerified); Assert.False(content.Premium); @@ -35,4 +85,354 @@ public class AccountsControllerTest : IClassFixture Assert.NotNull(content.PrivateKey); Assert.NotNull(content.SecurityStamp); } + + [Theory] + [BitAutoData(KdfType.PBKDF2_SHA256, 600001, null, null)] + [BitAutoData(KdfType.Argon2id, 4, 65, 5)] + public async Task PostKdf_ValidRequestLogoutOnKdfChangeFeatureFlagOff_SuccessLogout(KdfType kdf, + int kdfIterations, int? kdfMemory, int? kdfParallelism) + { + var userBeforeKdfChange = await _userRepository.GetByEmailAsync(_ownerEmail); + Assert.NotNull(userBeforeKdfChange); + + _featureService.IsEnabled(FeatureFlagKeys.NoLogoutOnKdfChange).Returns(false); + + await _loginHelper.LoginAsync(_ownerEmail); + + var kdfRequest = new KdfRequestModel + { + KdfType = kdf, + Iterations = kdfIterations, + Memory = kdfMemory, + Parallelism = kdfParallelism, + }; + + var response = await PostKdfWithKdfRequestAsync(kdfRequest); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + // Validate that the user fields were updated correctly + var user = await _userRepository.GetByEmailAsync(_ownerEmail); + Assert.NotNull(user); + Assert.Equal(kdfRequest.KdfType, user.Kdf); + Assert.Equal(kdfRequest.Iterations, user.KdfIterations); + Assert.Equal(kdfRequest.Memory, user.KdfMemory); + Assert.Equal(kdfRequest.Parallelism, user.KdfParallelism); + Assert.Equal(_masterKeyWrappedUserKey, user.Key); + Assert.NotNull(user.LastKdfChangeDate); + Assert.True(user.LastKdfChangeDate > DateTime.UtcNow.AddMinutes(-1)); + Assert.True(user.RevisionDate > DateTime.UtcNow.AddMinutes(-1)); + Assert.True(user.AccountRevisionDate > DateTime.UtcNow.AddMinutes(-1)); + Assert.NotEqual(userBeforeKdfChange.SecurityStamp, user.SecurityStamp); + Assert.Equal(PasswordVerificationResult.Success, + _passwordHasher.VerifyHashedPassword(user, user.MasterPassword!, _newMasterPasswordHash)); + + // Validate push notification + await _pushNotificationService.Received(1).PushLogOutAsync(user.Id); + } + + [Theory] + [BitAutoData(KdfType.PBKDF2_SHA256, 600001, null, null)] + [BitAutoData(KdfType.Argon2id, 4, 65, 5)] + public async Task PostKdf_ValidRequestLogoutOnKdfChangeFeatureFlagOn_SuccessSyncAndLogoutWithReason(KdfType kdf, + int kdfIterations, int? kdfMemory, int? kdfParallelism) + { + var userBeforeKdfChange = await _userRepository.GetByEmailAsync(_ownerEmail); + Assert.NotNull(userBeforeKdfChange); + + _featureService.IsEnabled(FeatureFlagKeys.NoLogoutOnKdfChange).Returns(true); + + await _loginHelper.LoginAsync(_ownerEmail); + + var kdfRequest = new KdfRequestModel + { + KdfType = kdf, + Iterations = kdfIterations, + Memory = kdfMemory, + Parallelism = kdfParallelism, + }; + + var response = await PostKdfWithKdfRequestAsync(kdfRequest); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + // Validate that the user fields were updated correctly + var user = await _userRepository.GetByEmailAsync(_ownerEmail); + Assert.NotNull(user); + Assert.Equal(kdfRequest.KdfType, user.Kdf); + Assert.Equal(kdfRequest.Iterations, user.KdfIterations); + Assert.Equal(kdfRequest.Memory, user.KdfMemory); + Assert.Equal(kdfRequest.Parallelism, user.KdfParallelism); + Assert.Equal(_masterKeyWrappedUserKey, user.Key); + Assert.NotNull(user.LastKdfChangeDate); + Assert.True(user.LastKdfChangeDate > DateTime.UtcNow.AddMinutes(-1)); + Assert.True(user.RevisionDate > DateTime.UtcNow.AddMinutes(-1)); + Assert.True(user.AccountRevisionDate > DateTime.UtcNow.AddMinutes(-1)); + Assert.Equal(userBeforeKdfChange.SecurityStamp, user.SecurityStamp); + Assert.Equal(PasswordVerificationResult.Success, + _passwordHasher.VerifyHashedPassword(user, user.MasterPassword!, _newMasterPasswordHash)); + + // Validate push notification + await _pushNotificationService.Received(1) + .PushLogOutAsync(user.Id, false, PushNotificationLogOutReason.KdfChange); + await _pushNotificationService.Received(1).PushSyncSettingsAsync(user.Id); + } + + [Fact] + public async Task PostKdf_Unauthorized_ReturnsUnauthorized() + { + // Don't call LoginAsync to test unauthorized access + + var response = await PostKdfWithKdfRequestAsync(_defaultKdfRequest); + + Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode); + } + + [Theory] + [InlineData(false, true)] + [InlineData(true, false)] + [InlineData(true, true)] + public async Task PostKdf_AuthenticationDataOrUnlockDataNull_BadRequest(bool authenticationDataNull, + bool unlockDataNull) + { + await _loginHelper.LoginAsync(_ownerEmail); + + var authenticationData = authenticationDataNull + ? null + : new MasterPasswordAuthenticationDataRequestModel + { + Kdf = _defaultKdfRequest, + MasterPasswordAuthenticationHash = _newMasterPasswordHash, + Salt = _ownerEmail + }; + + var unlockData = unlockDataNull + ? null + : new MasterPasswordUnlockDataRequestModel + { + Kdf = _defaultKdfRequest, + MasterKeyWrappedUserKey = _masterKeyWrappedUserKey, + Salt = _ownerEmail + }; + + var response = await PostKdfAsync(authenticationData, unlockData); + + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains("AuthenticationData and UnlockData must be provided.", content); + } + + [Fact] + public async Task PostKdf_InvalidMasterPasswordHash_BadRequest() + { + await _loginHelper.LoginAsync(_ownerEmail); + + var authenticationData = new MasterPasswordAuthenticationDataRequestModel + { + Kdf = _defaultKdfRequest, + MasterPasswordAuthenticationHash = _newMasterPasswordHash, + Salt = _ownerEmail + }; + + var unlockData = new MasterPasswordUnlockDataRequestModel + { + Kdf = _defaultKdfRequest, + MasterKeyWrappedUserKey = _masterKeyWrappedUserKey, + Salt = _ownerEmail + }; + + var requestModel = new PasswordRequestModel + { + MasterPasswordHash = "wrong-master-password-hash", + NewMasterPasswordHash = _newMasterPasswordHash, + Key = _masterKeyWrappedUserKey, + AuthenticationData = authenticationData, + UnlockData = unlockData + }; + + using var message = new HttpRequestMessage(HttpMethod.Post, "/accounts/kdf"); + message.Content = JsonContent.Create(requestModel); + var response = await _client.SendAsync(message); + + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains("Incorrect password", content); + } + + [Fact] + public async Task PostKdf_ChangedSaltInAuthenticationData_BadRequest() + { + await _loginHelper.LoginAsync(_ownerEmail); + + var authenticationData = new MasterPasswordAuthenticationDataRequestModel + { + Kdf = _defaultKdfRequest, + MasterPasswordAuthenticationHash = _newMasterPasswordHash, + Salt = "wrong-salt@bitwarden.com" + }; + + var unlockData = new MasterPasswordUnlockDataRequestModel + { + Kdf = _defaultKdfRequest, + MasterKeyWrappedUserKey = _masterKeyWrappedUserKey, + Salt = _ownerEmail + }; + + var response = await PostKdfAsync(authenticationData, unlockData); + + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains("Invalid master password salt.", content); + } + + [Fact] + public async Task PostKdf_ChangedSaltInUnlockData_BadRequest() + { + await _loginHelper.LoginAsync(_ownerEmail); + + var authenticationData = new MasterPasswordAuthenticationDataRequestModel + { + Kdf = _defaultKdfRequest, + MasterPasswordAuthenticationHash = _newMasterPasswordHash, + Salt = _ownerEmail + }; + + var unlockData = new MasterPasswordUnlockDataRequestModel + { + Kdf = _defaultKdfRequest, + MasterKeyWrappedUserKey = _masterKeyWrappedUserKey, + Salt = "wrong-salt@bitwarden.com" + }; + + var response = await PostKdfAsync(authenticationData, unlockData); + + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains("Invalid master password salt.", content); + } + + [Fact] + public async Task PostKdf_KdfNotMatching_BadRequest() + { + await _loginHelper.LoginAsync(_ownerEmail); + + var authenticationData = new MasterPasswordAuthenticationDataRequestModel + { + Kdf = new KdfRequestModel { KdfType = KdfType.PBKDF2_SHA256, Iterations = 600_000 }, + MasterPasswordAuthenticationHash = _newMasterPasswordHash, + Salt = _ownerEmail + }; + + var unlockData = new MasterPasswordUnlockDataRequestModel + { + Kdf = new KdfRequestModel { KdfType = KdfType.PBKDF2_SHA256, Iterations = 600_001 }, + MasterKeyWrappedUserKey = _masterKeyWrappedUserKey, + Salt = _ownerEmail + }; + + var response = await PostKdfAsync(authenticationData, unlockData); + + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains("KDF settings must be equal for authentication and unlock.", content); + } + + [Theory] + [InlineData(KdfType.PBKDF2_SHA256, 1, null, null)] + [InlineData(KdfType.Argon2id, 4, null, 5)] + [InlineData(KdfType.Argon2id, 4, 65, null)] + public async Task PostKdf_InvalidKdf_BadRequest(KdfType kdf, int kdfIterations, int? kdfMemory, int? kdfParallelism) + { + await _loginHelper.LoginAsync(_ownerEmail); + + var kdfRequest = new KdfRequestModel + { + KdfType = kdf, + Iterations = kdfIterations, + Memory = kdfMemory, + Parallelism = kdfParallelism + }; + + var response = await PostKdfWithKdfRequestAsync(kdfRequest); + + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains("KDF settings are invalid", content); + } + + [Fact] + public async Task PostKdf_InvalidNewMasterPassword_BadRequest() + { + var newMasterPasswordHash = "too-short"; + + await _loginHelper.LoginAsync(_ownerEmail); + + var authenticationData = new MasterPasswordAuthenticationDataRequestModel + { + Kdf = _defaultKdfRequest, + MasterPasswordAuthenticationHash = newMasterPasswordHash, + Salt = _ownerEmail + }; + + var unlockData = new MasterPasswordUnlockDataRequestModel + { + Kdf = _defaultKdfRequest, + MasterKeyWrappedUserKey = _masterKeyWrappedUserKey, + Salt = _ownerEmail + }; + + var requestModel = new PasswordRequestModel + { + MasterPasswordHash = _masterPasswordHash, + NewMasterPasswordHash = newMasterPasswordHash, + Key = _masterKeyWrappedUserKey, + AuthenticationData = authenticationData, + UnlockData = unlockData + }; + + using var message = new HttpRequestMessage(HttpMethod.Post, "/accounts/kdf"); + message.Content = JsonContent.Create(requestModel); + var response = await _client.SendAsync(message); + + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var content = await response.Content.ReadAsStringAsync(); + Assert.Contains("Passwords must be at least", content); + } + + private async Task PostKdfWithKdfRequestAsync(KdfRequestModel kdfRequest) + { + var authenticationData = new MasterPasswordAuthenticationDataRequestModel + { + Kdf = kdfRequest, + MasterPasswordAuthenticationHash = _newMasterPasswordHash, + Salt = _ownerEmail + }; + + var unlockData = new MasterPasswordUnlockDataRequestModel + { + Kdf = kdfRequest, + MasterKeyWrappedUserKey = _masterKeyWrappedUserKey, + Salt = _ownerEmail + }; + + return await PostKdfAsync(authenticationData, unlockData); + } + + private async Task PostKdfAsync( + MasterPasswordAuthenticationDataRequestModel? authenticationDataRequest, + MasterPasswordUnlockDataRequestModel? unlockDataRequest) + { + var requestModel = new PasswordRequestModel + { + MasterPasswordHash = _masterPasswordHash, + NewMasterPasswordHash = _newMasterPasswordHash, + Key = _masterKeyWrappedUserKey, + AuthenticationData = authenticationDataRequest, + UnlockData = unlockDataRequest + }; + + using var message = new HttpRequestMessage(HttpMethod.Post, "/accounts/kdf"); + message.Content = JsonContent.Create(requestModel); + return await _client.SendAsync(message); + } } diff --git a/test/Api.IntegrationTest/Helpers/OrganizationTestHelpers.cs b/test/Api.IntegrationTest/Helpers/OrganizationTestHelpers.cs index 3cd73c4b1c..c23ebff736 100644 --- a/test/Api.IntegrationTest/Helpers/OrganizationTestHelpers.cs +++ b/test/Api.IntegrationTest/Helpers/OrganizationTestHelpers.cs @@ -151,6 +151,28 @@ public static class OrganizationTestHelpers return group; } + /// + /// Creates a collection with optional user and group associations. + /// + public static async Task CreateCollectionAsync( + ApiApplicationFactory factory, + Guid organizationId, + string name, + IEnumerable? users = null, + IEnumerable? groups = null) + { + var collectionRepository = factory.GetService(); + var collection = new Collection + { + OrganizationId = organizationId, + Name = name, + Type = CollectionType.SharedCollection + }; + + await collectionRepository.CreateAsync(collection, groups, users); + return collection; + } + /// /// Enables the Organization Data Ownership policy for the specified organization. /// diff --git a/test/Api.IntegrationTest/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs b/test/Api.IntegrationTest/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs index bf27d7f0d1..1630bc0dc0 100644 --- a/test/Api.IntegrationTest/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs +++ b/test/Api.IntegrationTest/KeyManagement/Controllers/AccountsKeyManagementControllerTests.cs @@ -12,6 +12,10 @@ using Bit.Core.Auth.Models.Api.Request.Accounts; using Bit.Core.Billing.Enums; using Bit.Core.Entities; using Bit.Core.Enums; +using Bit.Core.KeyManagement.Entities; +using Bit.Core.KeyManagement.Enums; +using Bit.Core.KeyManagement.Models.Api.Request; +using Bit.Core.KeyManagement.Repositories; using Bit.Core.Repositories; using Bit.Core.Vault.Enums; using Bit.Test.Common.AutoFixture.Attributes; @@ -24,6 +28,7 @@ public class AccountsKeyManagementControllerTests : IClassFixture _passwordHasher; private readonly IOrganizationRepository _organizationRepository; + private readonly IUserSignatureKeyPairRepository _userSignatureKeyPairRepository; private string _ownerEmail = null!; public AccountsKeyManagementControllerTests(ApiApplicationFactory factory) @@ -49,6 +55,7 @@ public class AccountsKeyManagementControllerTests : IClassFixture(); _passwordHasher = _factory.GetService>(); _organizationRepository = _factory.GetService(); + _userSignatureKeyPairRepository = _factory.GetService(); } public async Task InitializeAsync() @@ -200,6 +207,7 @@ public class AccountsKeyManagementControllerTests : IClassFixture sutProvider, + [OrganizationUser] OrganizationUser targetOrganizationUser, + ClaimsPrincipal claimsPrincipal) + { + // Arrange + var context = new AuthorizationHandlerContext( + [new RecoverAccountAuthorizationRequirement()], + claimsPrincipal, + targetOrganizationUser); + + MockOrganizationClaims(sutProvider, claimsPrincipal, targetOrganizationUser, null); + MockCurrentUserIsProvider(sutProvider, claimsPrincipal, targetOrganizationUser); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + Assert.True(context.HasSucceeded); + } + + [Theory, BitAutoData] + public async Task HandleRequirementAsync_CurrentUserIsNotMemberOrProvider_NotAuthorized( + SutProvider sutProvider, + [OrganizationUser] OrganizationUser targetOrganizationUser, + ClaimsPrincipal claimsPrincipal) + { + // Arrange + var context = new AuthorizationHandlerContext( + [new RecoverAccountAuthorizationRequirement()], + claimsPrincipal, + targetOrganizationUser); + + MockOrganizationClaims(sutProvider, claimsPrincipal, targetOrganizationUser, null); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + AssertFailed(context, RecoverAccountAuthorizationHandler.FailureReason); + } + + // Pairing of CurrentContextOrganization (current user permissions) and target user role + // Read this as: a ___ can recover the account for a ___ + public static IEnumerable AuthorizedRoleCombinations => new object[][] + { + [new CurrentContextOrganization { Type = OrganizationUserType.Owner }, OrganizationUserType.Owner], + [new CurrentContextOrganization { Type = OrganizationUserType.Owner }, OrganizationUserType.Admin], + [new CurrentContextOrganization { Type = OrganizationUserType.Owner }, OrganizationUserType.Custom], + [new CurrentContextOrganization { Type = OrganizationUserType.Owner }, OrganizationUserType.User], + [new CurrentContextOrganization { Type = OrganizationUserType.Admin }, OrganizationUserType.Admin], + [new CurrentContextOrganization { Type = OrganizationUserType.Admin }, OrganizationUserType.Custom], + [new CurrentContextOrganization { Type = OrganizationUserType.Admin }, OrganizationUserType.User], + [new CurrentContextOrganization { Type = OrganizationUserType.Custom, Permissions = new Permissions { ManageResetPassword = true}}, OrganizationUserType.Custom], + [new CurrentContextOrganization { Type = OrganizationUserType.Custom, Permissions = new Permissions { ManageResetPassword = true}}, OrganizationUserType.User], + }; + + [Theory, BitMemberAutoData(nameof(AuthorizedRoleCombinations))] + public async Task AuthorizeMemberAsync_RecoverEqualOrLesserRoles_TargetUserNotProvider_Authorized( + CurrentContextOrganization currentContextOrganization, + OrganizationUserType targetOrganizationUserType, + SutProvider sutProvider, + [OrganizationUser] OrganizationUser targetOrganizationUser, + ClaimsPrincipal claimsPrincipal) + { + // Arrange + targetOrganizationUser.Type = targetOrganizationUserType; + currentContextOrganization.Id = targetOrganizationUser.OrganizationId; + + var context = new AuthorizationHandlerContext( + [new RecoverAccountAuthorizationRequirement()], + claimsPrincipal, + targetOrganizationUser); + + MockOrganizationClaims(sutProvider, claimsPrincipal, targetOrganizationUser, currentContextOrganization); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + Assert.True(context.HasSucceeded); + } + + // Pairing of CurrentContextOrganization (current user permissions) and target user role + // Read this as: a ___ cannot recover the account for a ___ + public static IEnumerable UnauthorizedRoleCombinations => new object[][] + { + // These roles should fail because you cannot recover a greater role + [new CurrentContextOrganization { Type = OrganizationUserType.Admin }, OrganizationUserType.Owner], + [new CurrentContextOrganization { Type = OrganizationUserType.Custom, Permissions = new Permissions { ManageResetPassword = true}}, OrganizationUserType.Owner], + [new CurrentContextOrganization { Type = OrganizationUserType.Custom, Permissions = new Permissions { ManageResetPassword = true} }, OrganizationUserType.Admin], + + // These roles are never authorized to recover any account + [new CurrentContextOrganization { Type = OrganizationUserType.User }, OrganizationUserType.Owner], + [new CurrentContextOrganization { Type = OrganizationUserType.User }, OrganizationUserType.Admin], + [new CurrentContextOrganization { Type = OrganizationUserType.User }, OrganizationUserType.Custom], + [new CurrentContextOrganization { Type = OrganizationUserType.User }, OrganizationUserType.User], + [new CurrentContextOrganization { Type = OrganizationUserType.Custom }, OrganizationUserType.Owner], + [new CurrentContextOrganization { Type = OrganizationUserType.Custom }, OrganizationUserType.Admin], + [new CurrentContextOrganization { Type = OrganizationUserType.Custom }, OrganizationUserType.Custom], + [new CurrentContextOrganization { Type = OrganizationUserType.Custom }, OrganizationUserType.User], + }; + + [Theory, BitMemberAutoData(nameof(UnauthorizedRoleCombinations))] + public async Task AuthorizeMemberAsync_InvalidRoles_TargetUserNotProvider_Unauthorized( + CurrentContextOrganization currentContextOrganization, + OrganizationUserType targetOrganizationUserType, + SutProvider sutProvider, + [OrganizationUser] OrganizationUser targetOrganizationUser, + ClaimsPrincipal claimsPrincipal) + { + // Arrange + targetOrganizationUser.Type = targetOrganizationUserType; + currentContextOrganization.Id = targetOrganizationUser.OrganizationId; + + var context = new AuthorizationHandlerContext( + [new RecoverAccountAuthorizationRequirement()], + claimsPrincipal, + targetOrganizationUser); + + MockOrganizationClaims(sutProvider, claimsPrincipal, targetOrganizationUser, currentContextOrganization); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + AssertFailed(context, RecoverAccountAuthorizationHandler.FailureReason); + } + + [Theory, BitAutoData] + public async Task HandleRequirementAsync_TargetUserIdIsNull_DoesNotBlock( + SutProvider sutProvider, + OrganizationUser targetOrganizationUser, + ClaimsPrincipal claimsPrincipal) + { + // Arrange + targetOrganizationUser.UserId = null; + MockCurrentUserIsOwner(sutProvider, claimsPrincipal, targetOrganizationUser); + + var context = new AuthorizationHandlerContext( + [new RecoverAccountAuthorizationRequirement()], + claimsPrincipal, + targetOrganizationUser); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + Assert.True(context.HasSucceeded); + // This should shortcut the provider escalation check + await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() + .GetManyByUserAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task HandleRequirementAsync_CurrentUserIsMemberOfAllTargetUserProviders_DoesNotBlock( + SutProvider sutProvider, + [OrganizationUser] OrganizationUser targetOrganizationUser, + ClaimsPrincipal claimsPrincipal, + Guid providerId1, + Guid providerId2) + { + // Arrange + var targetUserProviders = new List + { + new() { ProviderId = providerId1, UserId = targetOrganizationUser.UserId }, + new() { ProviderId = providerId2, UserId = targetOrganizationUser.UserId } + }; + + var context = new AuthorizationHandlerContext( + [new RecoverAccountAuthorizationRequirement()], + claimsPrincipal, + targetOrganizationUser); + + MockCurrentUserIsProvider(sutProvider, claimsPrincipal, targetOrganizationUser); + + sutProvider.GetDependency() + .GetManyByUserAsync(targetOrganizationUser.UserId!.Value) + .Returns(targetUserProviders); + + sutProvider.GetDependency() + .ProviderUser(providerId1) + .Returns(true); + + sutProvider.GetDependency() + .ProviderUser(providerId2) + .Returns(true); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + Assert.True(context.HasSucceeded); + } + + [Theory, BitAutoData] + public async Task HandleRequirementAsync_CurrentUserMissingProviderMembership_Blocks( + SutProvider sutProvider, + [OrganizationUser] OrganizationUser targetOrganizationUser, + ClaimsPrincipal claimsPrincipal, + Guid providerId1, + Guid providerId2) + { + // Arrange + var targetUserProviders = new List + { + new() { ProviderId = providerId1, UserId = targetOrganizationUser.UserId }, + new() { ProviderId = providerId2, UserId = targetOrganizationUser.UserId } + }; + + var context = new AuthorizationHandlerContext( + [new RecoverAccountAuthorizationRequirement()], + claimsPrincipal, + targetOrganizationUser); + + MockCurrentUserIsOwner(sutProvider, claimsPrincipal, targetOrganizationUser); + + sutProvider.GetDependency() + .GetManyByUserAsync(targetOrganizationUser.UserId!.Value) + .Returns(targetUserProviders); + + sutProvider.GetDependency() + .ProviderUser(providerId1) + .Returns(true); + + // Not a member of this provider + sutProvider.GetDependency() + .ProviderUser(providerId2) + .Returns(false); + + // Act + await sutProvider.Sut.HandleAsync(context); + + // Assert + AssertFailed(context, RecoverAccountAuthorizationHandler.ProviderFailureReason); + } + + private static void MockOrganizationClaims(SutProvider sutProvider, + ClaimsPrincipal currentUser, OrganizationUser targetOrganizationUser, + CurrentContextOrganization? currentContextOrganization) + { + sutProvider.GetDependency() + .GetOrganizationClaims(currentUser, targetOrganizationUser.OrganizationId) + .Returns(currentContextOrganization); + } + + private static void MockCurrentUserIsProvider(SutProvider sutProvider, + ClaimsPrincipal currentUser, OrganizationUser targetOrganizationUser) + { + sutProvider.GetDependency() + .IsProviderUserForOrganization(currentUser, targetOrganizationUser.OrganizationId) + .Returns(true); + } + + private static void MockCurrentUserIsOwner(SutProvider sutProvider, + ClaimsPrincipal currentUser, OrganizationUser targetOrganizationUser) + { + var currentContextOrganization = new CurrentContextOrganization + { + Id = targetOrganizationUser.OrganizationId, + Type = OrganizationUserType.Owner + }; + + sutProvider.GetDependency() + .GetOrganizationClaims(currentUser, targetOrganizationUser.OrganizationId) + .Returns(currentContextOrganization); + } + + private static void AssertFailed(AuthorizationHandlerContext context, string expectedMessage) + { + Assert.True(context.HasFailed); + var failureReason = Assert.Single(context.FailureReasons); + Assert.Equal(expectedMessage, failureReason.Message); + } +} diff --git a/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationControllerTests.cs b/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationControllerTests.cs index 1dd0e86f39..335859e0c4 100644 --- a/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationControllerTests.cs +++ b/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationControllerTests.cs @@ -133,6 +133,29 @@ public class OrganizationIntegrationControllerTests .DeleteAsync(organizationIntegration); } + [Theory, BitAutoData] + public async Task PostDeleteAsync_AllParamsProvided_Succeeds( + SutProvider sutProvider, + Guid organizationId, + OrganizationIntegration organizationIntegration) + { + organizationIntegration.OrganizationId = organizationId; + sutProvider.Sut.Url = Substitute.For(); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(organizationIntegration); + + await sutProvider.Sut.PostDeleteAsync(organizationId, organizationIntegration.Id); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(organizationIntegration.Id); + await sutProvider.GetDependency().Received(1) + .DeleteAsync(organizationIntegration); + } + [Theory, BitAutoData] public async Task DeleteAsync_IntegrationDoesNotBelongToOrganization_ThrowsNotFound( SutProvider sutProvider, diff --git a/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationsConfigurationControllerTests.cs b/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationsConfigurationControllerTests.cs index 4ccfa70308..9ab626d3f0 100644 --- a/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationsConfigurationControllerTests.cs +++ b/test/Api.Test/AdminConsole/Controllers/OrganizationIntegrationsConfigurationControllerTests.cs @@ -51,6 +51,36 @@ public class OrganizationIntegrationsConfigurationControllerTests .DeleteAsync(organizationIntegrationConfiguration); } + [Theory, BitAutoData] + public async Task PostDeleteAsync_AllParamsProvided_Succeeds( + SutProvider sutProvider, + Guid organizationId, + OrganizationIntegration organizationIntegration, + OrganizationIntegrationConfiguration organizationIntegrationConfiguration) + { + organizationIntegration.OrganizationId = organizationId; + organizationIntegrationConfiguration.OrganizationIntegrationId = organizationIntegration.Id; + sutProvider.Sut.Url = Substitute.For(); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(organizationIntegration); + sutProvider.GetDependency() + .GetByIdAsync(Arg.Any()) + .Returns(organizationIntegrationConfiguration); + + await sutProvider.Sut.PostDeleteAsync(organizationId, organizationIntegration.Id, organizationIntegrationConfiguration.Id); + + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(organizationIntegration.Id); + await sutProvider.GetDependency().Received(1) + .GetByIdAsync(organizationIntegrationConfiguration.Id); + await sutProvider.GetDependency().Received(1) + .DeleteAsync(organizationIntegrationConfiguration); + } + [Theory, BitAutoData] public async Task DeleteAsync_IntegrationConfigurationDoesNotExist_ThrowsNotFound( SutProvider sutProvider, @@ -199,27 +229,6 @@ public class OrganizationIntegrationsConfigurationControllerTests .GetManyByIntegrationAsync(organizationIntegration.Id); } - // [Theory, BitAutoData] - // public async Task GetAsync_IntegrationConfigurationDoesNotExist_ThrowsNotFound( - // SutProvider sutProvider, - // Guid organizationId, - // OrganizationIntegration organizationIntegration) - // { - // organizationIntegration.OrganizationId = organizationId; - // sutProvider.Sut.Url = Substitute.For(); - // sutProvider.GetDependency() - // .OrganizationOwner(organizationId) - // .Returns(true); - // sutProvider.GetDependency() - // .GetByIdAsync(Arg.Any()) - // .Returns(organizationIntegration); - // sutProvider.GetDependency() - // .GetByIdAsync(Arg.Any()) - // .ReturnsNull(); - // - // await Assert.ThrowsAsync(async () => await sutProvider.Sut.GetAsync(organizationId, Guid.Empty, Guid.Empty)); - // } - // [Theory, BitAutoData] public async Task GetAsync_IntegrationDoesNotExist_ThrowsNotFound( SutProvider sutProvider, @@ -293,15 +302,16 @@ public class OrganizationIntegrationsConfigurationControllerTests sutProvider.GetDependency() .CreateAsync(Arg.Any()) .Returns(organizationIntegrationConfiguration); - var requestAction = await sutProvider.Sut.CreateAsync(organizationId, organizationIntegration.Id, model); + var createResponse = await sutProvider.Sut.CreateAsync(organizationId, organizationIntegration.Id, model); await sutProvider.GetDependency().Received(1) .CreateAsync(Arg.Any()); - Assert.IsType(requestAction); - Assert.Equal(expected.Id, requestAction.Id); - Assert.Equal(expected.Configuration, requestAction.Configuration); - Assert.Equal(expected.EventType, requestAction.EventType); - Assert.Equal(expected.Template, requestAction.Template); + Assert.IsType(createResponse); + Assert.Equal(expected.Id, createResponse.Id); + Assert.Equal(expected.Configuration, createResponse.Configuration); + Assert.Equal(expected.EventType, createResponse.EventType); + Assert.Equal(expected.Filters, createResponse.Filters); + Assert.Equal(expected.Template, createResponse.Template); } [Theory, BitAutoData] @@ -331,15 +341,16 @@ public class OrganizationIntegrationsConfigurationControllerTests sutProvider.GetDependency() .CreateAsync(Arg.Any()) .Returns(organizationIntegrationConfiguration); - var requestAction = await sutProvider.Sut.CreateAsync(organizationId, organizationIntegration.Id, model); + var createResponse = await sutProvider.Sut.CreateAsync(organizationId, organizationIntegration.Id, model); await sutProvider.GetDependency().Received(1) .CreateAsync(Arg.Any()); - Assert.IsType(requestAction); - Assert.Equal(expected.Id, requestAction.Id); - Assert.Equal(expected.Configuration, requestAction.Configuration); - Assert.Equal(expected.EventType, requestAction.EventType); - Assert.Equal(expected.Template, requestAction.Template); + Assert.IsType(createResponse); + Assert.Equal(expected.Id, createResponse.Id); + Assert.Equal(expected.Configuration, createResponse.Configuration); + Assert.Equal(expected.EventType, createResponse.EventType); + Assert.Equal(expected.Filters, createResponse.Filters); + Assert.Equal(expected.Template, createResponse.Template); } [Theory, BitAutoData] @@ -369,15 +380,16 @@ public class OrganizationIntegrationsConfigurationControllerTests sutProvider.GetDependency() .CreateAsync(Arg.Any()) .Returns(organizationIntegrationConfiguration); - var requestAction = await sutProvider.Sut.CreateAsync(organizationId, organizationIntegration.Id, model); + var createResponse = await sutProvider.Sut.CreateAsync(organizationId, organizationIntegration.Id, model); await sutProvider.GetDependency().Received(1) .CreateAsync(Arg.Any()); - Assert.IsType(requestAction); - Assert.Equal(expected.Id, requestAction.Id); - Assert.Equal(expected.Configuration, requestAction.Configuration); - Assert.Equal(expected.EventType, requestAction.EventType); - Assert.Equal(expected.Template, requestAction.Template); + Assert.IsType(createResponse); + Assert.Equal(expected.Id, createResponse.Id); + Assert.Equal(expected.Configuration, createResponse.Configuration); + Assert.Equal(expected.EventType, createResponse.EventType); + Assert.Equal(expected.Filters, createResponse.Filters); + Assert.Equal(expected.Template, createResponse.Template); } [Theory, BitAutoData] @@ -575,7 +587,7 @@ public class OrganizationIntegrationsConfigurationControllerTests sutProvider.GetDependency() .GetByIdAsync(Arg.Any()) .Returns(organizationIntegrationConfiguration); - var requestAction = await sutProvider.Sut.UpdateAsync( + var updateResponse = await sutProvider.Sut.UpdateAsync( organizationId, organizationIntegration.Id, organizationIntegrationConfiguration.Id, @@ -583,11 +595,12 @@ public class OrganizationIntegrationsConfigurationControllerTests await sutProvider.GetDependency().Received(1) .ReplaceAsync(Arg.Any()); - Assert.IsType(requestAction); - Assert.Equal(expected.Id, requestAction.Id); - Assert.Equal(expected.Configuration, requestAction.Configuration); - Assert.Equal(expected.EventType, requestAction.EventType); - Assert.Equal(expected.Template, requestAction.Template); + Assert.IsType(updateResponse); + Assert.Equal(expected.Id, updateResponse.Id); + Assert.Equal(expected.Configuration, updateResponse.Configuration); + Assert.Equal(expected.EventType, updateResponse.EventType); + Assert.Equal(expected.Filters, updateResponse.Filters); + Assert.Equal(expected.Template, updateResponse.Template); } @@ -619,7 +632,7 @@ public class OrganizationIntegrationsConfigurationControllerTests sutProvider.GetDependency() .GetByIdAsync(Arg.Any()) .Returns(organizationIntegrationConfiguration); - var requestAction = await sutProvider.Sut.UpdateAsync( + var updateResponse = await sutProvider.Sut.UpdateAsync( organizationId, organizationIntegration.Id, organizationIntegrationConfiguration.Id, @@ -627,11 +640,12 @@ public class OrganizationIntegrationsConfigurationControllerTests await sutProvider.GetDependency().Received(1) .ReplaceAsync(Arg.Any()); - Assert.IsType(requestAction); - Assert.Equal(expected.Id, requestAction.Id); - Assert.Equal(expected.Configuration, requestAction.Configuration); - Assert.Equal(expected.EventType, requestAction.EventType); - Assert.Equal(expected.Template, requestAction.Template); + Assert.IsType(updateResponse); + Assert.Equal(expected.Id, updateResponse.Id); + Assert.Equal(expected.Configuration, updateResponse.Configuration); + Assert.Equal(expected.EventType, updateResponse.EventType); + Assert.Equal(expected.Filters, updateResponse.Filters); + Assert.Equal(expected.Template, updateResponse.Template); } [Theory, BitAutoData] @@ -662,7 +676,7 @@ public class OrganizationIntegrationsConfigurationControllerTests sutProvider.GetDependency() .GetByIdAsync(Arg.Any()) .Returns(organizationIntegrationConfiguration); - var requestAction = await sutProvider.Sut.UpdateAsync( + var updateResponse = await sutProvider.Sut.UpdateAsync( organizationId, organizationIntegration.Id, organizationIntegrationConfiguration.Id, @@ -670,11 +684,12 @@ public class OrganizationIntegrationsConfigurationControllerTests await sutProvider.GetDependency().Received(1) .ReplaceAsync(Arg.Any()); - Assert.IsType(requestAction); - Assert.Equal(expected.Id, requestAction.Id); - Assert.Equal(expected.Configuration, requestAction.Configuration); - Assert.Equal(expected.EventType, requestAction.EventType); - Assert.Equal(expected.Template, requestAction.Template); + Assert.IsType(updateResponse); + Assert.Equal(expected.Id, updateResponse.Id); + Assert.Equal(expected.Configuration, updateResponse.Configuration); + Assert.Equal(expected.EventType, updateResponse.EventType); + Assert.Equal(expected.Filters, updateResponse.Filters); + Assert.Equal(expected.Template, updateResponse.Template); } [Theory, BitAutoData] diff --git a/test/Api.Test/AdminConsole/Controllers/OrganizationUsersControllerTests.cs b/test/Api.Test/AdminConsole/Controllers/OrganizationUsersControllerTests.cs index e5aa03f067..5875cda05a 100644 --- a/test/Api.Test/AdminConsole/Controllers/OrganizationUsersControllerTests.cs +++ b/test/Api.Test/AdminConsole/Controllers/OrganizationUsersControllerTests.cs @@ -1,11 +1,14 @@ using System.Security.Claims; +using Bit.Api.AdminConsole.Authorization; using Bit.Api.AdminConsole.Controllers; using Bit.Api.AdminConsole.Models.Request.Organizations; +using Bit.Api.Models.Request.Organizations; using Bit.Api.Vault.AuthorizationHandlers.Collections; using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; @@ -16,6 +19,7 @@ using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; +using Bit.Core.Models.Api; using Bit.Core.Models.Business; using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations; @@ -30,6 +34,7 @@ using Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Http.HttpResults; +using Microsoft.AspNetCore.Mvc.ModelBinding; using NSubstitute; using Xunit; @@ -440,4 +445,153 @@ public class OrganizationUsersControllerTests Assert.Equal("Master Password reset is required, but not provided.", exception.Message); } + + [Theory] + [BitAutoData] + public async Task PutResetPassword_WithFeatureFlagDisabled_CallsLegacyPath( + Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(false); + sutProvider.GetDependency().OrganizationOwner(orgId).Returns(true); + sutProvider.GetDependency().AdminResetPasswordAsync(Arg.Any(), orgId, orgUserId, model.NewMasterPasswordHash, model.Key) + .Returns(Microsoft.AspNetCore.Identity.IdentityResult.Success); + + var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); + + Assert.IsType(result); + await sutProvider.GetDependency().Received(1) + .AdminResetPasswordAsync(OrganizationUserType.Owner, orgId, orgUserId, model.NewMasterPasswordHash, model.Key); + } + + [Theory] + [BitAutoData] + public async Task PutResetPassword_WithFeatureFlagDisabled_WhenOrgUserTypeIsNull_ReturnsNotFound( + Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(false); + sutProvider.GetDependency().OrganizationOwner(orgId).Returns(false); + sutProvider.GetDependency().Organizations.Returns(new List()); + + var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); + + Assert.IsType(result); + } + + [Theory] + [BitAutoData] + public async Task PutResetPassword_WithFeatureFlagDisabled_WhenAdminResetPasswordFails_ReturnsBadRequest( + Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(false); + sutProvider.GetDependency().OrganizationOwner(orgId).Returns(true); + sutProvider.GetDependency().AdminResetPasswordAsync(Arg.Any(), orgId, orgUserId, model.NewMasterPasswordHash, model.Key) + .Returns(Microsoft.AspNetCore.Identity.IdentityResult.Failed(new Microsoft.AspNetCore.Identity.IdentityError { Description = "Error 1" })); + + var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); + + Assert.IsType>(result); + } + + [Theory] + [BitAutoData] + public async Task PutResetPassword_WithFeatureFlagEnabled_WhenOrganizationUserNotFound_ReturnsNotFound( + Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, + SutProvider sutProvider) + { + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true); + sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns((OrganizationUser)null); + + var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); + + Assert.IsType(result); + } + + [Theory] + [BitAutoData] + public async Task PutResetPassword_WithFeatureFlagEnabled_WhenOrganizationIdMismatch_ReturnsNotFound( + Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, OrganizationUser organizationUser, + SutProvider sutProvider) + { + organizationUser.OrganizationId = Guid.NewGuid(); + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true); + sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns(organizationUser); + + var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); + + Assert.IsType(result); + } + + [Theory] + [BitAutoData] + public async Task PutResetPassword_WithFeatureFlagEnabled_WhenAuthorizationFails_ReturnsBadRequest( + Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, OrganizationUser organizationUser, + SutProvider sutProvider) + { + organizationUser.OrganizationId = orgId; + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true); + sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns(organizationUser); + sutProvider.GetDependency() + .AuthorizeAsync( + Arg.Any(), + organizationUser, + Arg.Is>(x => x.SingleOrDefault() is RecoverAccountAuthorizationRequirement)) + .Returns(AuthorizationResult.Failed()); + + var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); + + Assert.IsType>(result); + } + + [Theory] + [BitAutoData] + public async Task PutResetPassword_WithFeatureFlagEnabled_WhenRecoverAccountSucceeds_ReturnsOk( + Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, OrganizationUser organizationUser, + SutProvider sutProvider) + { + organizationUser.OrganizationId = orgId; + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true); + sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns(organizationUser); + sutProvider.GetDependency() + .AuthorizeAsync( + Arg.Any(), + organizationUser, + Arg.Is>(x => x.SingleOrDefault() is RecoverAccountAuthorizationRequirement)) + .Returns(AuthorizationResult.Success()); + sutProvider.GetDependency() + .RecoverAccountAsync(orgId, organizationUser, model.NewMasterPasswordHash, model.Key) + .Returns(Microsoft.AspNetCore.Identity.IdentityResult.Success); + + var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); + + Assert.IsType(result); + await sutProvider.GetDependency().Received(1) + .RecoverAccountAsync(orgId, organizationUser, model.NewMasterPasswordHash, model.Key); + } + + [Theory] + [BitAutoData] + public async Task PutResetPassword_WithFeatureFlagEnabled_WhenRecoverAccountFails_ReturnsBadRequest( + Guid orgId, Guid orgUserId, OrganizationUserResetPasswordRequestModel model, OrganizationUser organizationUser, + SutProvider sutProvider) + { + organizationUser.OrganizationId = orgId; + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountRecoveryCommand).Returns(true); + sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns(organizationUser); + sutProvider.GetDependency() + .AuthorizeAsync( + Arg.Any(), + organizationUser, + Arg.Is>(x => x.SingleOrDefault() is RecoverAccountAuthorizationRequirement)) + .Returns(AuthorizationResult.Success()); + sutProvider.GetDependency() + .RecoverAccountAsync(orgId, organizationUser, model.NewMasterPasswordHash, model.Key) + .Returns(Microsoft.AspNetCore.Identity.IdentityResult.Failed(new Microsoft.AspNetCore.Identity.IdentityError { Description = "Error message" })); + + var result = await sutProvider.Sut.PutResetPassword(orgId, orgUserId, model); + + Assert.IsType>(result); + } } diff --git a/test/Api.Test/AdminConsole/Controllers/SlackIntegrationControllerTests.cs b/test/Api.Test/AdminConsole/Controllers/SlackIntegrationControllerTests.cs index 9bbc8a77c0..c079445559 100644 --- a/test/Api.Test/AdminConsole/Controllers/SlackIntegrationControllerTests.cs +++ b/test/Api.Test/AdminConsole/Controllers/SlackIntegrationControllerTests.cs @@ -1,12 +1,18 @@ -using Bit.Api.AdminConsole.Controllers; +#nullable enable + +using Bit.Api.AdminConsole.Controllers; using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Models.Data.EventIntegrations; using Bit.Core.Context; +using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.Routing; +using Microsoft.Extensions.Time.Testing; using NSubstitute; using Xunit; @@ -16,98 +22,350 @@ namespace Bit.Api.Test.AdminConsole.Controllers; [SutProviderCustomize] public class SlackIntegrationControllerTests { + private const string _slackToken = "xoxb-test-token"; + private const string _validSlackCode = "A_test_code"; + [Theory, BitAutoData] - public async Task CreateAsync_AllParamsProvided_Succeeds(SutProvider sutProvider, Guid organizationId) + public async Task CreateAsync_AllParamsProvided_Succeeds( + SutProvider sutProvider, + OrganizationIntegration integration) { - var token = "xoxb-test-token"; + integration.Type = IntegrationType.Slack; + integration.Configuration = null; sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) + .Returns("https://localhost"); sutProvider.GetDependency() - .ObtainTokenViaOAuth(Arg.Any(), Arg.Any()) - .Returns(token); + .ObtainTokenViaOAuth(_validSlackCode, Arg.Any()) + .Returns(_slackToken); sutProvider.GetDependency() - .CreateAsync(Arg.Any()) - .Returns(callInfo => callInfo.Arg()); - var requestAction = await sutProvider.Sut.CreateAsync(organizationId, "A_test_code"); + .GetByIdAsync(integration.Id) + .Returns(integration); + + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + var requestAction = await sutProvider.Sut.CreateAsync(_validSlackCode, state.ToString()); await sutProvider.GetDependency().Received(1) - .CreateAsync(Arg.Any()); + .UpsertAsync(Arg.Any()); Assert.IsType(requestAction); } [Theory, BitAutoData] - public async Task CreateAsync_CodeIsEmpty_ThrowsBadRequest(SutProvider sutProvider, Guid organizationId) + public async Task CreateAsync_CodeIsEmpty_ThrowsBadRequest( + SutProvider sutProvider, + OrganizationIntegration integration) { + integration.Type = IntegrationType.Slack; + integration.Configuration = null; sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(integration); + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(organizationId, string.Empty)); + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.CreateAsync(string.Empty, state.ToString())); } [Theory, BitAutoData] - public async Task CreateAsync_SlackServiceReturnsEmpty_ThrowsBadRequest(SutProvider sutProvider, Guid organizationId) + public async Task CreateAsync_CallbackUrlIsEmpty_ThrowsBadRequest( + SutProvider sutProvider, + OrganizationIntegration integration) { + integration.Type = IntegrationType.Slack; + integration.Configuration = null; sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(true); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) + .Returns((string?)null); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(integration); + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.CreateAsync(_validSlackCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_SlackServiceReturnsEmpty_ThrowsBadRequest( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Type = IntegrationType.Slack; + integration.Configuration = null; + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(integration); sutProvider.GetDependency() - .ObtainTokenViaOAuth(Arg.Any(), Arg.Any()) + .ObtainTokenViaOAuth(_validSlackCode, Arg.Any()) .Returns(string.Empty); + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(organizationId, "A_test_code")); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validSlackCode, state.ToString())); } [Theory, BitAutoData] - public async Task CreateAsync_UserIsNotOrganizationAdmin_ThrowsNotFound(SutProvider sutProvider, Guid organizationId) + public async Task CreateAsync_StateEmpty_ThrowsNotFound( + SutProvider sutProvider) { - var token = "xoxb-test-token"; sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency() - .OrganizationOwner(organizationId) - .Returns(false); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) + .Returns("https://localhost"); sutProvider.GetDependency() - .ObtainTokenViaOAuth(Arg.Any(), Arg.Any()) - .Returns(token); + .ObtainTokenViaOAuth(_validSlackCode, Arg.Any()) + .Returns(_slackToken); - await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(organizationId, "A_test_code")); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validSlackCode, string.Empty)); } [Theory, BitAutoData] - public async Task RedirectAsync_Success(SutProvider sutProvider, Guid organizationId) + public async Task CreateAsync_StateExpired_ThrowsNotFound( + SutProvider sutProvider, + OrganizationIntegration integration) { - var expectedUrl = $"https://localhost/{organizationId}"; + var timeProvider = new FakeTimeProvider(new DateTime(2024, 4, 3, 2, 1, 0, DateTimeKind.Utc)); + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validSlackCode, Arg.Any()) + .Returns(_slackToken); + var state = IntegrationOAuthState.FromIntegration(integration, timeProvider); + timeProvider.Advance(TimeSpan.FromMinutes(30)); + + sutProvider.SetDependency(timeProvider); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validSlackCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_StateHasNonexistentIntegration_ThrowsNotFound( + SutProvider sutProvider, + OrganizationIntegration integration) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validSlackCode, Arg.Any()) + .Returns(_slackToken); + + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validSlackCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_StateHasWrongOrganizationHash_ThrowsNotFound( + SutProvider sutProvider, + OrganizationIntegration integration, + OrganizationIntegration wrongOrgIntegration) + { + wrongOrgIntegration.Id = integration.Id; + wrongOrgIntegration.Type = IntegrationType.Slack; + wrongOrgIntegration.Configuration = null; sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency().GetRedirectUrl(Arg.Any()).Returns(expectedUrl); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validSlackCode, Arg.Any()) + .Returns(_slackToken); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(wrongOrgIntegration); + + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validSlackCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_StateHasNonEmptyIntegration_ThrowsNotFound( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Type = IntegrationType.Slack; + integration.Configuration = "{}"; + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validSlackCode, Arg.Any()) + .Returns(_slackToken); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(integration); + + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validSlackCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_StateHasNonSlackIntegration_ThrowsNotFound( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Type = IntegrationType.Hec; + integration.Configuration = null; + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validSlackCode, Arg.Any()) + .Returns(_slackToken); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(integration); + + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validSlackCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task RedirectAsync_Success( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Configuration = null; + var expectedUrl = "https://localhost/"; + + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) + .Returns(expectedUrl); + sutProvider.GetDependency() + .OrganizationOwner(integration.OrganizationId) + .Returns(true); + sutProvider.GetDependency() + .GetManyByOrganizationAsync(integration.OrganizationId) + .Returns([]); + sutProvider.GetDependency() + .CreateAsync(Arg.Any()) + .Returns(integration); + sutProvider.GetDependency().GetRedirectUrl(Arg.Any(), Arg.Any()).Returns(expectedUrl); + + var expectedState = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + var requestAction = await sutProvider.Sut.RedirectAsync(integration.OrganizationId); + + Assert.IsType(requestAction); + await sutProvider.GetDependency().Received(1) + .CreateAsync(Arg.Any()); + sutProvider.GetDependency().Received(1).GetRedirectUrl(Arg.Any(), expectedState.ToString()); + } + + [Theory, BitAutoData] + public async Task RedirectAsync_IntegrationAlreadyExistsWithNullConfig_Success( + SutProvider sutProvider, + Guid organizationId, + OrganizationIntegration integration) + { + integration.OrganizationId = organizationId; + integration.Configuration = null; + integration.Type = IntegrationType.Slack; + var expectedUrl = "https://localhost/"; + + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) + .Returns(expectedUrl); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(true); - sutProvider.GetDependency() - .HttpContext.Request.Scheme - .Returns("https"); + sutProvider.GetDependency() + .GetManyByOrganizationAsync(organizationId) + .Returns([integration]); + sutProvider.GetDependency().GetRedirectUrl(Arg.Any(), Arg.Any()).Returns(expectedUrl); var requestAction = await sutProvider.Sut.RedirectAsync(organizationId); - var redirectResult = Assert.IsType(requestAction); - Assert.Equal(expectedUrl, redirectResult.Url); + var expectedState = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + Assert.IsType(requestAction); + sutProvider.GetDependency().Received(1).GetRedirectUrl(Arg.Any(), expectedState.ToString()); } [Theory, BitAutoData] - public async Task RedirectAsync_SlackServiceReturnsEmpty_ThrowsNotFound(SutProvider sutProvider, Guid organizationId) + public async Task RedirectAsync_IntegrationAlreadyExistsWithConfig_ThrowsBadRequest( + SutProvider sutProvider, + Guid organizationId, + OrganizationIntegration integration) { + integration.OrganizationId = organizationId; + integration.Configuration = "{}"; + integration.Type = IntegrationType.Slack; + var expectedUrl = "https://localhost/"; + sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency().GetRedirectUrl(Arg.Any()).Returns(string.Empty); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) + .Returns(expectedUrl); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(true); + sutProvider.GetDependency() + .GetManyByOrganizationAsync(organizationId) + .Returns([integration]); + sutProvider.GetDependency().GetRedirectUrl(Arg.Any(), Arg.Any()).Returns(expectedUrl); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.RedirectAsync(organizationId)); + } + + [Theory, BitAutoData] + public async Task RedirectAsync_CallbackUrlReturnsEmpty_ThrowsBadRequest( + SutProvider sutProvider, + Guid organizationId) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) + .Returns((string?)null); sutProvider.GetDependency() - .HttpContext.Request.Scheme - .Returns("https"); + .OrganizationOwner(organizationId) + .Returns(true); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.RedirectAsync(organizationId)); + } + + [Theory, BitAutoData] + public async Task RedirectAsync_SlackServiceReturnsEmpty_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + OrganizationIntegration integration) + { + integration.OrganizationId = organizationId; + integration.Configuration = null; + var expectedUrl = "https://localhost/"; + + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "SlackIntegration_Create")) + .Returns(expectedUrl); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + sutProvider.GetDependency() + .GetManyByOrganizationAsync(organizationId) + .Returns([]); + sutProvider.GetDependency() + .CreateAsync(Arg.Any()) + .Returns(integration); + sutProvider.GetDependency().GetRedirectUrl(Arg.Any(), Arg.Any()).Returns(string.Empty); await Assert.ThrowsAsync(async () => await sutProvider.Sut.RedirectAsync(organizationId)); } @@ -116,14 +374,9 @@ public class SlackIntegrationControllerTests public async Task RedirectAsync_UserIsNotOrganizationAdmin_ThrowsNotFound(SutProvider sutProvider, Guid organizationId) { - sutProvider.Sut.Url = Substitute.For(); - sutProvider.GetDependency().GetRedirectUrl(Arg.Any()).Returns(string.Empty); sutProvider.GetDependency() .OrganizationOwner(organizationId) .Returns(false); - sutProvider.GetDependency() - .HttpContext.Request.Scheme - .Returns("https"); await Assert.ThrowsAsync(async () => await sutProvider.Sut.RedirectAsync(organizationId)); } diff --git a/test/Api.Test/AdminConsole/Controllers/TeamsIntegrationControllerTests.cs b/test/Api.Test/AdminConsole/Controllers/TeamsIntegrationControllerTests.cs new file mode 100644 index 0000000000..3302a87372 --- /dev/null +++ b/test/Api.Test/AdminConsole/Controllers/TeamsIntegrationControllerTests.cs @@ -0,0 +1,436 @@ +#nullable enable + +using Bit.Api.AdminConsole.Controllers; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.Context; +using Bit.Core.Enums; +using Bit.Core.Exceptions; +using Bit.Core.Models.Teams; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.Routing; +using Microsoft.Bot.Builder; +using Microsoft.Bot.Builder.Integration.AspNet.Core; +using Microsoft.Extensions.Time.Testing; +using NSubstitute; +using Xunit; + +namespace Bit.Api.Test.AdminConsole.Controllers; + +[ControllerCustomize(typeof(TeamsIntegrationController))] +[SutProviderCustomize] +public class TeamsIntegrationControllerTests +{ + private const string _teamsToken = "test-token"; + private const string _validTeamsCode = "A_test_code"; + + [Theory, BitAutoData] + public async Task CreateAsync_AllParamsProvided_Succeeds( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Type = IntegrationType.Teams; + integration.Configuration = null; + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validTeamsCode, Arg.Any()) + .Returns(_teamsToken); + sutProvider.GetDependency() + .GetJoinedTeamsAsync(_teamsToken) + .Returns([ + new TeamInfo() { DisplayName = "Test Team", Id = Guid.NewGuid().ToString(), TenantId = Guid.NewGuid().ToString() } + ]); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(integration); + + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + var requestAction = await sutProvider.Sut.CreateAsync(_validTeamsCode, state.ToString()); + + await sutProvider.GetDependency().Received(1) + .UpsertAsync(Arg.Any()); + Assert.IsType(requestAction); + } + + [Theory, BitAutoData] + public async Task CreateAsync_CallbackUrlIsEmpty_ThrowsBadRequest( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Type = IntegrationType.Teams; + integration.Configuration = null; + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns((string?)null); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(integration); + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.CreateAsync(_validTeamsCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_CodeIsEmpty_ThrowsBadRequest( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Type = IntegrationType.Teams; + integration.Configuration = null; + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(integration); + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.CreateAsync(string.Empty, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_NoTeamsFound_ThrowsBadRequest( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Type = IntegrationType.Teams; + integration.Configuration = null; + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validTeamsCode, Arg.Any()) + .Returns(_teamsToken); + sutProvider.GetDependency() + .GetJoinedTeamsAsync(_teamsToken) + .Returns([]); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(integration); + + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validTeamsCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_TeamsServiceReturnsEmptyToken_ThrowsBadRequest( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Type = IntegrationType.Teams; + integration.Configuration = null; + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(integration); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validTeamsCode, Arg.Any()) + .Returns(string.Empty); + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validTeamsCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_StateEmpty_ThrowsNotFound( + SutProvider sutProvider) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validTeamsCode, Arg.Any()) + .Returns(_teamsToken); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validTeamsCode, string.Empty)); + } + + [Theory, BitAutoData] + public async Task CreateAsync_StateExpired_ThrowsNotFound( + SutProvider sutProvider, + OrganizationIntegration integration) + { + var timeProvider = new FakeTimeProvider(new DateTime(2024, 4, 3, 2, 1, 0, DateTimeKind.Utc)); + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validTeamsCode, Arg.Any()) + .Returns(_teamsToken); + var state = IntegrationOAuthState.FromIntegration(integration, timeProvider); + timeProvider.Advance(TimeSpan.FromMinutes(30)); + + sutProvider.SetDependency(timeProvider); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validTeamsCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_StateHasNonexistentIntegration_ThrowsNotFound( + SutProvider sutProvider, + OrganizationIntegration integration) + { + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validTeamsCode, Arg.Any()) + .Returns(_teamsToken); + + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validTeamsCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_StateHasWrongOrganizationHash_ThrowsNotFound( + SutProvider sutProvider, + OrganizationIntegration integration, + OrganizationIntegration wrongOrgIntegration) + { + wrongOrgIntegration.Id = integration.Id; + wrongOrgIntegration.Type = IntegrationType.Teams; + wrongOrgIntegration.Configuration = null; + + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validTeamsCode, Arg.Any()) + .Returns(_teamsToken); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(wrongOrgIntegration); + + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validTeamsCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_StateHasNonEmptyIntegration_ThrowsNotFound( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Type = IntegrationType.Teams; + integration.Configuration = "{}"; + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validTeamsCode, Arg.Any()) + .Returns(_teamsToken); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(integration); + + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validTeamsCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task CreateAsync_StateHasNonTeamsIntegration_ThrowsNotFound( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Type = IntegrationType.Hec; + integration.Configuration = null; + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns("https://localhost"); + sutProvider.GetDependency() + .ObtainTokenViaOAuth(_validTeamsCode, Arg.Any()) + .Returns(_teamsToken); + sutProvider.GetDependency() + .GetByIdAsync(integration.Id) + .Returns(integration); + + var state = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + await Assert.ThrowsAsync(async () => await sutProvider.Sut.CreateAsync(_validTeamsCode, state.ToString())); + } + + [Theory, BitAutoData] + public async Task RedirectAsync_Success( + SutProvider sutProvider, + OrganizationIntegration integration) + { + integration.Configuration = null; + var expectedUrl = "https://localhost/"; + + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns(expectedUrl); + sutProvider.GetDependency() + .OrganizationOwner(integration.OrganizationId) + .Returns(true); + sutProvider.GetDependency() + .GetManyByOrganizationAsync(integration.OrganizationId) + .Returns([]); + sutProvider.GetDependency() + .CreateAsync(Arg.Any()) + .Returns(integration); + sutProvider.GetDependency().GetRedirectUrl(Arg.Any(), Arg.Any()).Returns(expectedUrl); + + var expectedState = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + var requestAction = await sutProvider.Sut.RedirectAsync(integration.OrganizationId); + + Assert.IsType(requestAction); + await sutProvider.GetDependency().Received(1) + .CreateAsync(Arg.Any()); + sutProvider.GetDependency().Received(1).GetRedirectUrl(Arg.Any(), expectedState.ToString()); + } + + [Theory, BitAutoData] + public async Task RedirectAsync_IntegrationAlreadyExistsWithNullConfig_Success( + SutProvider sutProvider, + Guid organizationId, + OrganizationIntegration integration) + { + integration.OrganizationId = organizationId; + integration.Configuration = null; + integration.Type = IntegrationType.Teams; + var expectedUrl = "https://localhost/"; + + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns(expectedUrl); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + sutProvider.GetDependency() + .GetManyByOrganizationAsync(organizationId) + .Returns([integration]); + sutProvider.GetDependency().GetRedirectUrl(Arg.Any(), Arg.Any()).Returns(expectedUrl); + + var requestAction = await sutProvider.Sut.RedirectAsync(organizationId); + + var expectedState = IntegrationOAuthState.FromIntegration(integration, sutProvider.GetDependency()); + + Assert.IsType(requestAction); + sutProvider.GetDependency().Received(1).GetRedirectUrl(Arg.Any(), expectedState.ToString()); + } + + [Theory, BitAutoData] + public async Task RedirectAsync_CallbackUrlIsEmpty_ThrowsBadRequest( + SutProvider sutProvider, + Guid organizationId, + OrganizationIntegration integration) + { + integration.OrganizationId = organizationId; + integration.Configuration = null; + integration.Type = IntegrationType.Teams; + + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns((string?)null); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + sutProvider.GetDependency() + .GetManyByOrganizationAsync(organizationId) + .Returns([integration]); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.RedirectAsync(organizationId)); + } + + [Theory, BitAutoData] + public async Task RedirectAsync_IntegrationAlreadyExistsWithConfig_ThrowsBadRequest( + SutProvider sutProvider, + Guid organizationId, + OrganizationIntegration integration) + { + integration.OrganizationId = organizationId; + integration.Configuration = "{}"; + integration.Type = IntegrationType.Teams; + var expectedUrl = "https://localhost/"; + + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns(expectedUrl); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + sutProvider.GetDependency() + .GetManyByOrganizationAsync(organizationId) + .Returns([integration]); + sutProvider.GetDependency().GetRedirectUrl(Arg.Any(), Arg.Any()).Returns(expectedUrl); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.RedirectAsync(organizationId)); + } + + [Theory, BitAutoData] + public async Task RedirectAsync_TeamsServiceReturnsEmpty_ThrowsNotFound( + SutProvider sutProvider, + Guid organizationId, + OrganizationIntegration integration) + { + integration.OrganizationId = organizationId; + integration.Configuration = null; + var expectedUrl = "https://localhost/"; + + sutProvider.Sut.Url = Substitute.For(); + sutProvider.Sut.Url + .RouteUrl(Arg.Is(c => c.RouteName == "TeamsIntegration_Create")) + .Returns(expectedUrl); + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(true); + sutProvider.GetDependency() + .GetManyByOrganizationAsync(organizationId) + .Returns([]); + sutProvider.GetDependency() + .CreateAsync(Arg.Any()) + .Returns(integration); + sutProvider.GetDependency().GetRedirectUrl(Arg.Any(), Arg.Any()).Returns(string.Empty); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.RedirectAsync(organizationId)); + } + + [Theory, BitAutoData] + public async Task RedirectAsync_UserIsNotOrganizationAdmin_ThrowsNotFound(SutProvider sutProvider, + Guid organizationId) + { + sutProvider.GetDependency() + .OrganizationOwner(organizationId) + .Returns(false); + + await Assert.ThrowsAsync(async () => await sutProvider.Sut.RedirectAsync(organizationId)); + } + + [Theory, BitAutoData] + public async Task IncomingPostAsync_ForwardsToBot(SutProvider sutProvider) + { + var adapter = sutProvider.GetDependency(); + var bot = sutProvider.GetDependency(); + + await sutProvider.Sut.IncomingPostAsync(); + await adapter.Received(1).ProcessAsync(Arg.Any(), Arg.Any(), bot); + } +} diff --git a/test/Api.Test/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModelTests.cs b/test/Api.Test/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModelTests.cs index 74fe75a9d7..8a75db9da8 100644 --- a/test/Api.Test/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModelTests.cs +++ b/test/Api.Test/AdminConsole/Models/Request/Organizations/OrganizationIntegrationConfigurationRequestModelTests.cs @@ -39,7 +39,7 @@ public class OrganizationIntegrationConfigurationRequestModelTests [Theory] [InlineData(data: "")] [InlineData(data: " ")] - public void IsValidForType_EmptyNonNullHecConfiguration_ReturnsFalse(string? config) + public void IsValidForType_EmptyNonNullConfiguration_ReturnsFalse(string? config) { var model = new OrganizationIntegrationConfigurationRequestModel { @@ -48,10 +48,12 @@ public class OrganizationIntegrationConfigurationRequestModelTests }; Assert.False(condition: model.IsValidForType(IntegrationType.Hec)); + Assert.False(condition: model.IsValidForType(IntegrationType.Datadog)); + Assert.False(condition: model.IsValidForType(IntegrationType.Teams)); } [Fact] - public void IsValidForType_NullHecConfiguration_ReturnsTrue() + public void IsValidForType_NullConfiguration_ReturnsTrue() { var model = new OrganizationIntegrationConfigurationRequestModel { @@ -60,32 +62,8 @@ public class OrganizationIntegrationConfigurationRequestModelTests }; Assert.True(condition: model.IsValidForType(IntegrationType.Hec)); - } - - [Theory] - [InlineData(data: "")] - [InlineData(data: " ")] - public void IsValidForType_EmptyNonNullDatadogConfiguration_ReturnsFalse(string? config) - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = config, - Template = "template" - }; - - Assert.False(condition: model.IsValidForType(IntegrationType.Datadog)); - } - - [Fact] - public void IsValidForType_NullDatadogConfiguration_ReturnsTrue() - { - var model = new OrganizationIntegrationConfigurationRequestModel - { - Configuration = null, - Template = "template" - }; - Assert.True(condition: model.IsValidForType(IntegrationType.Datadog)); + Assert.True(condition: model.IsValidForType(IntegrationType.Teams)); } [Theory] @@ -107,6 +85,8 @@ public class OrganizationIntegrationConfigurationRequestModelTests Assert.False(condition: model.IsValidForType(IntegrationType.Slack)); Assert.False(condition: model.IsValidForType(IntegrationType.Webhook)); Assert.False(condition: model.IsValidForType(IntegrationType.Hec)); + Assert.False(condition: model.IsValidForType(IntegrationType.Datadog)); + Assert.False(condition: model.IsValidForType(IntegrationType.Teams)); } [Fact] @@ -121,6 +101,8 @@ public class OrganizationIntegrationConfigurationRequestModelTests Assert.False(condition: model.IsValidForType(IntegrationType.Slack)); Assert.False(condition: model.IsValidForType(IntegrationType.Webhook)); Assert.False(condition: model.IsValidForType(IntegrationType.Hec)); + Assert.False(condition: model.IsValidForType(IntegrationType.Datadog)); + Assert.False(condition: model.IsValidForType(IntegrationType.Teams)); } diff --git a/test/Api.Test/AdminConsole/Models/Request/Organizations/OrganizationIntegrationRequestModelTests.cs b/test/Api.Test/AdminConsole/Models/Request/Organizations/OrganizationIntegrationRequestModelTests.cs index 81927a1bfe..76e206abf4 100644 --- a/test/Api.Test/AdminConsole/Models/Request/Organizations/OrganizationIntegrationRequestModelTests.cs +++ b/test/Api.Test/AdminConsole/Models/Request/Organizations/OrganizationIntegrationRequestModelTests.cs @@ -1,14 +1,47 @@ using System.ComponentModel.DataAnnotations; using System.Text.Json; using Bit.Api.AdminConsole.Models.Request.Organizations; +using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Models.Data.EventIntegrations; using Bit.Core.Enums; +using Bit.Test.Common.AutoFixture.Attributes; using Xunit; namespace Bit.Api.Test.AdminConsole.Models.Request.Organizations; public class OrganizationIntegrationRequestModelTests { + [Fact] + public void ToOrganizationIntegration_CreatesNewOrganizationIntegration() + { + var model = new OrganizationIntegrationRequestModel + { + Type = IntegrationType.Hec, + Configuration = JsonSerializer.Serialize(new HecIntegration(Uri: new Uri("http://localhost"), Scheme: "Bearer", Token: "Token")) + }; + + var organizationId = Guid.NewGuid(); + var organizationIntegration = model.ToOrganizationIntegration(organizationId); + + Assert.Equal(organizationIntegration.Type, model.Type); + Assert.Equal(organizationIntegration.Configuration, model.Configuration); + Assert.Equal(organizationIntegration.OrganizationId, organizationId); + } + + [Theory, BitAutoData] + public void ToOrganizationIntegration_UpdatesExistingOrganizationIntegration(OrganizationIntegration integration) + { + var model = new OrganizationIntegrationRequestModel + { + Type = IntegrationType.Hec, + Configuration = JsonSerializer.Serialize(new HecIntegration(Uri: new Uri("http://localhost"), Scheme: "Bearer", Token: "Token")) + }; + + var organizationIntegration = model.ToOrganizationIntegration(integration); + + Assert.Equal(organizationIntegration.Configuration, model.Configuration); + } + [Fact] public void Validate_CloudBillingSync_ReturnsNotYetSupportedError() { @@ -57,6 +90,22 @@ public class OrganizationIntegrationRequestModelTests Assert.Contains("cannot be created directly", results[0].ErrorMessage); } + [Fact] + public void Validate_Teams_ReturnsCannotBeCreatedDirectlyError() + { + var model = new OrganizationIntegrationRequestModel + { + Type = IntegrationType.Teams, + Configuration = null + }; + + var results = model.Validate(new ValidationContext(model)).ToList(); + + Assert.Single(results); + Assert.Contains(nameof(model.Type), results[0].MemberNames); + Assert.Contains("cannot be created directly", results[0].ErrorMessage); + } + [Fact] public void Validate_Webhook_WithNullConfiguration_ReturnsNoErrors() { diff --git a/test/Api.Test/AdminConsole/Models/Request/SavePolicyRequestTests.cs b/test/Api.Test/AdminConsole/Models/Request/SavePolicyRequestTests.cs index 057680425a..163d66aeb4 100644 --- a/test/Api.Test/AdminConsole/Models/Request/SavePolicyRequestTests.cs +++ b/test/Api.Test/AdminConsole/Models/Request/SavePolicyRequestTests.cs @@ -24,11 +24,11 @@ public class SavePolicyRequestTests currentContext.OrganizationOwner(organizationId).Returns(true); var testData = new Dictionary { { "test", "value" } }; + var policyType = PolicyType.TwoFactorAuthentication; var model = new SavePolicyRequest { Policy = new PolicyRequestModel { - Type = PolicyType.TwoFactorAuthentication, Enabled = true, Data = testData }, @@ -36,7 +36,7 @@ public class SavePolicyRequestTests }; // Act - var result = await model.ToSavePolicyModelAsync(organizationId, currentContext); + var result = await model.ToSavePolicyModelAsync(organizationId, policyType, currentContext); // Assert Assert.Equal(PolicyType.TwoFactorAuthentication, result.PolicyUpdate.Type); @@ -54,7 +54,7 @@ public class SavePolicyRequestTests } [Theory, BitAutoData] - public async Task ToSavePolicyModelAsync_WithNullData_HandlesCorrectly( + public async Task ToSavePolicyModelAsync_WithEmptyData_HandlesCorrectly( Guid organizationId, Guid userId) { @@ -63,19 +63,17 @@ public class SavePolicyRequestTests currentContext.UserId.Returns(userId); currentContext.OrganizationOwner(organizationId).Returns(false); + var policyType = PolicyType.SingleOrg; var model = new SavePolicyRequest { Policy = new PolicyRequestModel { - Type = PolicyType.SingleOrg, - Enabled = false, - Data = null - }, - Metadata = null + Enabled = false + } }; // Act - var result = await model.ToSavePolicyModelAsync(organizationId, currentContext); + var result = await model.ToSavePolicyModelAsync(organizationId, policyType, currentContext); // Assert Assert.Null(result.PolicyUpdate.Data); @@ -95,19 +93,17 @@ public class SavePolicyRequestTests currentContext.UserId.Returns(userId); currentContext.OrganizationOwner(organizationId).Returns(true); + var policyType = PolicyType.SingleOrg; var model = new SavePolicyRequest { Policy = new PolicyRequestModel { - Type = PolicyType.SingleOrg, - Enabled = false, - Data = null - }, - Metadata = null + Enabled = false + } }; // Act - var result = await model.ToSavePolicyModelAsync(organizationId, currentContext); + var result = await model.ToSavePolicyModelAsync(organizationId, policyType, currentContext); // Assert Assert.Null(result.PolicyUpdate.Data); @@ -128,13 +124,12 @@ public class SavePolicyRequestTests currentContext.UserId.Returns(userId); currentContext.OrganizationOwner(organizationId).Returns(true); + var policyType = PolicyType.OrganizationDataOwnership; var model = new SavePolicyRequest { Policy = new PolicyRequestModel { - Type = PolicyType.OrganizationDataOwnership, - Enabled = true, - Data = null + Enabled = true }, Metadata = new Dictionary { @@ -143,7 +138,7 @@ public class SavePolicyRequestTests }; // Act - var result = await model.ToSavePolicyModelAsync(organizationId, currentContext); + var result = await model.ToSavePolicyModelAsync(organizationId, policyType, currentContext); // Assert Assert.IsType(result.Metadata); @@ -152,7 +147,7 @@ public class SavePolicyRequestTests } [Theory, BitAutoData] - public async Task ToSavePolicyModelAsync_OrganizationDataOwnership_WithNullMetadata_ReturnsEmptyMetadata( + public async Task ToSavePolicyModelAsync_OrganizationDataOwnership_WithEmptyMetadata_ReturnsEmptyMetadata( Guid organizationId, Guid userId) { @@ -161,19 +156,17 @@ public class SavePolicyRequestTests currentContext.UserId.Returns(userId); currentContext.OrganizationOwner(organizationId).Returns(true); + var policyType = PolicyType.OrganizationDataOwnership; var model = new SavePolicyRequest { Policy = new PolicyRequestModel { - Type = PolicyType.OrganizationDataOwnership, - Enabled = true, - Data = null - }, - Metadata = null + Enabled = true + } }; // Act - var result = await model.ToSavePolicyModelAsync(organizationId, currentContext); + var result = await model.ToSavePolicyModelAsync(organizationId, policyType, currentContext); // Assert Assert.NotNull(result); @@ -200,12 +193,11 @@ public class SavePolicyRequestTests currentContext.UserId.Returns(userId); currentContext.OrganizationOwner(organizationId).Returns(true); - + var policyType = PolicyType.ResetPassword; var model = new SavePolicyRequest { Policy = new PolicyRequestModel { - Type = PolicyType.ResetPassword, Enabled = true, Data = _complexData }, @@ -213,7 +205,7 @@ public class SavePolicyRequestTests }; // Act - var result = await model.ToSavePolicyModelAsync(organizationId, currentContext); + var result = await model.ToSavePolicyModelAsync(organizationId, policyType, currentContext); // Assert var deserializedData = JsonSerializer.Deserialize>(result.PolicyUpdate.Data); @@ -241,13 +233,12 @@ public class SavePolicyRequestTests currentContext.UserId.Returns(userId); currentContext.OrganizationOwner(organizationId).Returns(true); + var policyType = PolicyType.MaximumVaultTimeout; var model = new SavePolicyRequest { Policy = new PolicyRequestModel { - Type = PolicyType.MaximumVaultTimeout, - Enabled = true, - Data = null + Enabled = true }, Metadata = new Dictionary { @@ -256,7 +247,7 @@ public class SavePolicyRequestTests }; // Act - var result = await model.ToSavePolicyModelAsync(organizationId, currentContext); + var result = await model.ToSavePolicyModelAsync(organizationId, policyType, currentContext); // Assert Assert.NotNull(result); @@ -274,20 +265,18 @@ public class SavePolicyRequestTests currentContext.OrganizationOwner(organizationId).Returns(true); var errorDictionary = BuildErrorDictionary(); - + var policyType = PolicyType.OrganizationDataOwnership; var model = new SavePolicyRequest { Policy = new PolicyRequestModel { - Type = PolicyType.OrganizationDataOwnership, - Enabled = true, - Data = null + Enabled = true }, Metadata = errorDictionary }; // Act - var result = await model.ToSavePolicyModelAsync(organizationId, currentContext); + var result = await model.ToSavePolicyModelAsync(organizationId, policyType, currentContext); // Assert Assert.NotNull(result); diff --git a/test/Api.Test/AdminConsole/Models/Response/Organizations/OrganizationIntegrationResponseModelTests.cs b/test/Api.Test/AdminConsole/Models/Response/Organizations/OrganizationIntegrationResponseModelTests.cs new file mode 100644 index 0000000000..28bc07de38 --- /dev/null +++ b/test/Api.Test/AdminConsole/Models/Response/Organizations/OrganizationIntegrationResponseModelTests.cs @@ -0,0 +1,160 @@ +#nullable enable + +using System.Text.Json; +using Bit.Api.AdminConsole.Models.Response.Organizations; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Models.Data.EventIntegrations; +using Bit.Core.Enums; +using Bit.Core.Models.Teams; +using Bit.Test.Common.AutoFixture.Attributes; +using Xunit; + +namespace Bit.Api.Test.AdminConsole.Models.Response.Organizations; + +public class OrganizationIntegrationResponseModelTests +{ + [Theory, BitAutoData] + public void Status_CloudBillingSync_AlwaysNotApplicable(OrganizationIntegration oi) + { + oi.Type = IntegrationType.CloudBillingSync; + oi.Configuration = null; + + var model = new OrganizationIntegrationResponseModel(oi); + Assert.Equal(OrganizationIntegrationStatus.NotApplicable, model.Status); + + model.Configuration = "{}"; + Assert.Equal(OrganizationIntegrationStatus.NotApplicable, model.Status); + } + + [Theory, BitAutoData] + public void Status_Scim_AlwaysNotApplicable(OrganizationIntegration oi) + { + oi.Type = IntegrationType.Scim; + oi.Configuration = null; + + var model = new OrganizationIntegrationResponseModel(oi); + Assert.Equal(OrganizationIntegrationStatus.NotApplicable, model.Status); + + model.Configuration = "{}"; + Assert.Equal(OrganizationIntegrationStatus.NotApplicable, model.Status); + } + + [Theory, BitAutoData] + public void Status_Slack_NullConfig_ReturnsInitiated(OrganizationIntegration oi) + { + oi.Type = IntegrationType.Slack; + oi.Configuration = null; + + var model = new OrganizationIntegrationResponseModel(oi); + + Assert.Equal(OrganizationIntegrationStatus.Initiated, model.Status); + } + + [Theory, BitAutoData] + public void Status_Slack_WithConfig_ReturnsCompleted(OrganizationIntegration oi) + { + oi.Type = IntegrationType.Slack; + oi.Configuration = "{}"; + + var model = new OrganizationIntegrationResponseModel(oi); + + Assert.Equal(OrganizationIntegrationStatus.Completed, model.Status); + } + + [Theory, BitAutoData] + public void Status_Teams_NullConfig_ReturnsInitiated(OrganizationIntegration oi) + { + oi.Type = IntegrationType.Teams; + oi.Configuration = null; + + var model = new OrganizationIntegrationResponseModel(oi); + + Assert.Equal(OrganizationIntegrationStatus.Initiated, model.Status); + } + + [Theory, BitAutoData] + public void Status_Teams_WithTenantAndTeamsConfig_ReturnsInProgress(OrganizationIntegration oi) + { + oi.Type = IntegrationType.Teams; + oi.Configuration = JsonSerializer.Serialize(new TeamsIntegration( + TenantId: "tenant", Teams: [new TeamInfo() { DisplayName = "Team", Id = "TeamId", TenantId = "tenant" }] + )); + + var model = new OrganizationIntegrationResponseModel(oi); + + Assert.Equal(OrganizationIntegrationStatus.InProgress, model.Status); + } + + [Theory, BitAutoData] + public void Status_Teams_WithCompletedConfig_ReturnsCompleted(OrganizationIntegration oi) + { + oi.Type = IntegrationType.Teams; + oi.Configuration = JsonSerializer.Serialize(new TeamsIntegration( + TenantId: "tenant", + Teams: [new TeamInfo() { DisplayName = "Team", Id = "TeamId", TenantId = "tenant" }], + ServiceUrl: new Uri("https://example.com"), + ChannelId: "channellId" + )); + + var model = new OrganizationIntegrationResponseModel(oi); + + Assert.Equal(OrganizationIntegrationStatus.Completed, model.Status); + } + + [Theory, BitAutoData] + public void Status_Webhook_AlwaysCompleted(OrganizationIntegration oi) + { + oi.Type = IntegrationType.Webhook; + oi.Configuration = null; + + var model = new OrganizationIntegrationResponseModel(oi); + Assert.Equal(OrganizationIntegrationStatus.Completed, model.Status); + + model.Configuration = "{}"; + Assert.Equal(OrganizationIntegrationStatus.Completed, model.Status); + } + + [Theory, BitAutoData] + public void Status_Hec_NullConfig_ReturnsInvalid(OrganizationIntegration oi) + { + oi.Type = IntegrationType.Hec; + oi.Configuration = null; + + var model = new OrganizationIntegrationResponseModel(oi); + + Assert.Equal(OrganizationIntegrationStatus.Invalid, model.Status); + } + + [Theory, BitAutoData] + public void Status_Hec_WithConfig_ReturnsCompleted(OrganizationIntegration oi) + { + oi.Type = IntegrationType.Hec; + oi.Configuration = "{}"; + + var model = new OrganizationIntegrationResponseModel(oi); + + Assert.Equal(OrganizationIntegrationStatus.Completed, model.Status); + } + + [Theory, BitAutoData] + public void Status_Datadog_NullConfig_ReturnsInvalid(OrganizationIntegration oi) + { + oi.Type = IntegrationType.Datadog; + oi.Configuration = null; + + var model = new OrganizationIntegrationResponseModel(oi); + + Assert.Equal(OrganizationIntegrationStatus.Invalid, model.Status); + } + + [Theory, BitAutoData] + public void Status_Datadog_WithConfig_ReturnsCompleted(OrganizationIntegration oi) + { + oi.Type = IntegrationType.Datadog; + oi.Configuration = "{}"; + + var model = new OrganizationIntegrationResponseModel(oi); + + Assert.Equal(OrganizationIntegrationStatus.Completed, model.Status); + } +} diff --git a/test/Api.Test/AdminConsole/Models/Response/ProfileOrganizationResponseModelTests.cs b/test/Api.Test/AdminConsole/Models/Response/ProfileOrganizationResponseModelTests.cs new file mode 100644 index 0000000000..c2893c9fce --- /dev/null +++ b/test/Api.Test/AdminConsole/Models/Response/ProfileOrganizationResponseModelTests.cs @@ -0,0 +1,150 @@ +using Bit.Api.AdminConsole.Models.Response; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.Auth.Enums; +using Bit.Core.Auth.Models.Data; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Extensions; +using Bit.Core.Enums; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; +using Bit.Core.Utilities; +using Bit.Test.Common.AutoFixture.Attributes; +using Xunit; + +namespace Bit.Api.Test.AdminConsole.Models.Response; + +public class ProfileOrganizationResponseModelTests +{ + [Theory, BitAutoData] + public void Constructor_ShouldPopulatePropertiesCorrectly(Organization organization) + { + var userId = Guid.NewGuid(); + var organizationUserId = Guid.NewGuid(); + var providerId = Guid.NewGuid(); + var organizationIdsClaimingUser = new[] { organization.Id }; + + var organizationDetails = new OrganizationUserOrganizationDetails + { + OrganizationId = organization.Id, + UserId = userId, + OrganizationUserId = organizationUserId, + Name = organization.Name, + Enabled = organization.Enabled, + Identifier = organization.Identifier, + PlanType = organization.PlanType, + UsePolicies = organization.UsePolicies, + UseSso = organization.UseSso, + UseKeyConnector = organization.UseKeyConnector, + UseScim = organization.UseScim, + UseGroups = organization.UseGroups, + UseDirectory = organization.UseDirectory, + UseEvents = organization.UseEvents, + UseTotp = organization.UseTotp, + Use2fa = organization.Use2fa, + UseApi = organization.UseApi, + UseResetPassword = organization.UseResetPassword, + UseSecretsManager = organization.UseSecretsManager, + UsePasswordManager = organization.UsePasswordManager, + UsersGetPremium = organization.UsersGetPremium, + UseCustomPermissions = organization.UseCustomPermissions, + UseRiskInsights = organization.UseRiskInsights, + UseOrganizationDomains = organization.UseOrganizationDomains, + UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies, + UseAutomaticUserConfirmation = organization.UseAutomaticUserConfirmation, + SelfHost = organization.SelfHost, + Seats = organization.Seats, + MaxCollections = organization.MaxCollections, + MaxStorageGb = organization.MaxStorageGb, + Key = "organization-key", + PublicKey = "public-key", + PrivateKey = "private-key", + LimitCollectionCreation = organization.LimitCollectionCreation, + LimitCollectionDeletion = organization.LimitCollectionDeletion, + LimitItemDeletion = organization.LimitItemDeletion, + AllowAdminAccessToAllCollectionItems = organization.AllowAdminAccessToAllCollectionItems, + ProviderId = providerId, + ProviderName = "Test Provider", + ProviderType = ProviderType.Msp, + SsoEnabled = true, + SsoConfig = new SsoConfigurationData + { + MemberDecryptionType = MemberDecryptionType.KeyConnector, + KeyConnectorUrl = "https://keyconnector.example.com" + }.Serialize(), + SsoExternalId = "external-sso-id", + Permissions = CoreHelpers.ClassToJsonData(new Core.Models.Data.Permissions { ManageUsers = true }), + ResetPasswordKey = "reset-password-key", + FamilySponsorshipFriendlyName = "Family Sponsorship", + FamilySponsorshipLastSyncDate = DateTime.UtcNow.AddDays(-1), + FamilySponsorshipToDelete = false, + FamilySponsorshipValidUntil = DateTime.UtcNow.AddYears(1), + IsAdminInitiated = true, + Status = OrganizationUserStatusType.Confirmed, + Type = OrganizationUserType.Owner, + AccessSecretsManager = true, + SmSeats = 5, + SmServiceAccounts = 10 + }; + + var result = new ProfileOrganizationResponseModel(organizationDetails, organizationIdsClaimingUser); + + Assert.Equal("profileOrganization", result.Object); + Assert.Equal(organization.Id, result.Id); + Assert.Equal(userId, result.UserId); + Assert.Equal(organization.Name, result.Name); + Assert.Equal(organization.Enabled, result.Enabled); + Assert.Equal(organization.Identifier, result.Identifier); + Assert.Equal(organization.PlanType.GetProductTier(), result.ProductTierType); + Assert.Equal(organization.UsePolicies, result.UsePolicies); + Assert.Equal(organization.UseSso, result.UseSso); + Assert.Equal(organization.UseKeyConnector, result.UseKeyConnector); + Assert.Equal(organization.UseScim, result.UseScim); + Assert.Equal(organization.UseGroups, result.UseGroups); + Assert.Equal(organization.UseDirectory, result.UseDirectory); + Assert.Equal(organization.UseEvents, result.UseEvents); + Assert.Equal(organization.UseTotp, result.UseTotp); + Assert.Equal(organization.Use2fa, result.Use2fa); + Assert.Equal(organization.UseApi, result.UseApi); + Assert.Equal(organization.UseResetPassword, result.UseResetPassword); + Assert.Equal(organization.UseSecretsManager, result.UseSecretsManager); + Assert.Equal(organization.UsePasswordManager, result.UsePasswordManager); + Assert.Equal(organization.UsersGetPremium, result.UsersGetPremium); + Assert.Equal(organization.UseCustomPermissions, result.UseCustomPermissions); + Assert.Equal(organization.PlanType.GetProductTier() == ProductTierType.Enterprise, result.UseActivateAutofillPolicy); + Assert.Equal(organization.UseRiskInsights, result.UseRiskInsights); + Assert.Equal(organization.UseOrganizationDomains, result.UseOrganizationDomains); + Assert.Equal(organization.UseAdminSponsoredFamilies, result.UseAdminSponsoredFamilies); + Assert.Equal(organization.UseAutomaticUserConfirmation, result.UseAutomaticUserConfirmation); + Assert.Equal(organization.SelfHost, result.SelfHost); + Assert.Equal(organization.Seats, result.Seats); + Assert.Equal(organization.MaxCollections, result.MaxCollections); + Assert.Equal(organization.MaxStorageGb, result.MaxStorageGb); + Assert.Equal(organizationDetails.Key, result.Key); + Assert.True(result.HasPublicAndPrivateKeys); + Assert.Equal(organization.LimitCollectionCreation, result.LimitCollectionCreation); + Assert.Equal(organization.LimitCollectionDeletion, result.LimitCollectionDeletion); + Assert.Equal(organization.LimitItemDeletion, result.LimitItemDeletion); + Assert.Equal(organization.AllowAdminAccessToAllCollectionItems, result.AllowAdminAccessToAllCollectionItems); + Assert.Equal(organizationDetails.ProviderId, result.ProviderId); + Assert.Equal(organizationDetails.ProviderName, result.ProviderName); + Assert.Equal(organizationDetails.ProviderType, result.ProviderType); + Assert.Equal(organizationDetails.SsoEnabled, result.SsoEnabled); + Assert.True(result.KeyConnectorEnabled); + Assert.Equal("https://keyconnector.example.com", result.KeyConnectorUrl); + Assert.Equal(MemberDecryptionType.KeyConnector, result.SsoMemberDecryptionType); + Assert.True(result.SsoBound); + Assert.Equal(organizationDetails.Status, result.Status); + Assert.Equal(organizationDetails.Type, result.Type); + Assert.Equal(organizationDetails.OrganizationUserId, result.OrganizationUserId); + Assert.True(result.UserIsClaimedByOrganization); + Assert.NotNull(result.Permissions); + Assert.True(result.ResetPasswordEnrolled); + Assert.Equal(organizationDetails.AccessSecretsManager, result.AccessSecretsManager); + Assert.Equal(organizationDetails.FamilySponsorshipFriendlyName, result.FamilySponsorshipFriendlyName); + Assert.Equal(organizationDetails.FamilySponsorshipLastSyncDate, result.FamilySponsorshipLastSyncDate); + Assert.Equal(organizationDetails.FamilySponsorshipToDelete, result.FamilySponsorshipToDelete); + Assert.Equal(organizationDetails.FamilySponsorshipValidUntil, result.FamilySponsorshipValidUntil); + Assert.True(result.IsAdminInitiated); + Assert.False(result.FamilySponsorshipAvailable); + } +} diff --git a/test/Api.Test/AdminConsole/Models/Response/ProfileProviderOrganizationResponseModelTests.cs b/test/Api.Test/AdminConsole/Models/Response/ProfileProviderOrganizationResponseModelTests.cs new file mode 100644 index 0000000000..a131f90724 --- /dev/null +++ b/test/Api.Test/AdminConsole/Models/Response/ProfileProviderOrganizationResponseModelTests.cs @@ -0,0 +1,129 @@ +using Bit.Api.AdminConsole.Models.Response; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums.Provider; +using Bit.Core.AdminConsole.Models.Data.Provider; +using Bit.Core.Auth.Enums; +using Bit.Core.Auth.Models.Data; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Extensions; +using Bit.Core.Enums; +using Bit.Test.Common.AutoFixture.Attributes; +using Xunit; + +namespace Bit.Api.Test.AdminConsole.Models.Response; + +public class ProfileProviderOrganizationResponseModelTests +{ + [Theory, BitAutoData] + public void Constructor_ShouldPopulatePropertiesCorrectly(Organization organization) + { + var userId = Guid.NewGuid(); + var providerId = Guid.NewGuid(); + var providerUserId = Guid.NewGuid(); + + var organizationDetails = new ProviderUserOrganizationDetails + { + OrganizationId = organization.Id, + UserId = userId, + Name = organization.Name, + Enabled = organization.Enabled, + Identifier = organization.Identifier, + PlanType = organization.PlanType, + UsePolicies = organization.UsePolicies, + UseSso = organization.UseSso, + UseKeyConnector = organization.UseKeyConnector, + UseScim = organization.UseScim, + UseGroups = organization.UseGroups, + UseDirectory = organization.UseDirectory, + UseEvents = organization.UseEvents, + UseTotp = organization.UseTotp, + Use2fa = organization.Use2fa, + UseApi = organization.UseApi, + UseResetPassword = organization.UseResetPassword, + UseSecretsManager = organization.UseSecretsManager, + UsePasswordManager = organization.UsePasswordManager, + UsersGetPremium = organization.UsersGetPremium, + UseCustomPermissions = organization.UseCustomPermissions, + UseRiskInsights = organization.UseRiskInsights, + UseOrganizationDomains = organization.UseOrganizationDomains, + UseAdminSponsoredFamilies = organization.UseAdminSponsoredFamilies, + UseAutomaticUserConfirmation = organization.UseAutomaticUserConfirmation, + SelfHost = organization.SelfHost, + Seats = organization.Seats, + MaxCollections = organization.MaxCollections, + MaxStorageGb = organization.MaxStorageGb, + Key = "provider-org-key", + PublicKey = "public-key", + PrivateKey = "private-key", + LimitCollectionCreation = organization.LimitCollectionCreation, + LimitCollectionDeletion = organization.LimitCollectionDeletion, + LimitItemDeletion = organization.LimitItemDeletion, + AllowAdminAccessToAllCollectionItems = organization.AllowAdminAccessToAllCollectionItems, + ProviderId = providerId, + ProviderName = "Test MSP Provider", + ProviderType = ProviderType.Msp, + SsoEnabled = true, + SsoConfig = new SsoConfigurationData + { + MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption + }.Serialize(), + Status = ProviderUserStatusType.Confirmed, + Type = ProviderUserType.ProviderAdmin, + ProviderUserId = providerUserId + }; + + var result = new ProfileProviderOrganizationResponseModel(organizationDetails); + + Assert.Equal("profileProviderOrganization", result.Object); + Assert.Equal(organization.Id, result.Id); + Assert.Equal(userId, result.UserId); + Assert.Equal(organization.Name, result.Name); + Assert.Equal(organization.Enabled, result.Enabled); + Assert.Equal(organization.Identifier, result.Identifier); + Assert.Equal(organization.PlanType.GetProductTier(), result.ProductTierType); + Assert.Equal(organization.UsePolicies, result.UsePolicies); + Assert.Equal(organization.UseSso, result.UseSso); + Assert.Equal(organization.UseKeyConnector, result.UseKeyConnector); + Assert.Equal(organization.UseScim, result.UseScim); + Assert.Equal(organization.UseGroups, result.UseGroups); + Assert.Equal(organization.UseDirectory, result.UseDirectory); + Assert.Equal(organization.UseEvents, result.UseEvents); + Assert.Equal(organization.UseTotp, result.UseTotp); + Assert.Equal(organization.Use2fa, result.Use2fa); + Assert.Equal(organization.UseApi, result.UseApi); + Assert.Equal(organization.UseResetPassword, result.UseResetPassword); + Assert.Equal(organization.UseSecretsManager, result.UseSecretsManager); + Assert.Equal(organization.UsePasswordManager, result.UsePasswordManager); + Assert.Equal(organization.UsersGetPremium, result.UsersGetPremium); + Assert.Equal(organization.UseCustomPermissions, result.UseCustomPermissions); + Assert.Equal(organization.PlanType.GetProductTier() == ProductTierType.Enterprise, result.UseActivateAutofillPolicy); + Assert.Equal(organization.UseRiskInsights, result.UseRiskInsights); + Assert.Equal(organization.UseOrganizationDomains, result.UseOrganizationDomains); + Assert.Equal(organization.UseAdminSponsoredFamilies, result.UseAdminSponsoredFamilies); + Assert.Equal(organization.UseAutomaticUserConfirmation, result.UseAutomaticUserConfirmation); + Assert.Equal(organization.SelfHost, result.SelfHost); + Assert.Equal(organization.Seats, result.Seats); + Assert.Equal(organization.MaxCollections, result.MaxCollections); + Assert.Equal(organization.MaxStorageGb, result.MaxStorageGb); + Assert.Equal(organizationDetails.Key, result.Key); + Assert.True(result.HasPublicAndPrivateKeys); + Assert.Equal(organization.LimitCollectionCreation, result.LimitCollectionCreation); + Assert.Equal(organization.LimitCollectionDeletion, result.LimitCollectionDeletion); + Assert.Equal(organization.LimitItemDeletion, result.LimitItemDeletion); + Assert.Equal(organization.AllowAdminAccessToAllCollectionItems, result.AllowAdminAccessToAllCollectionItems); + Assert.Equal(organizationDetails.ProviderId, result.ProviderId); + Assert.Equal(organizationDetails.ProviderName, result.ProviderName); + Assert.Equal(organizationDetails.ProviderType, result.ProviderType); + Assert.Equal(OrganizationUserStatusType.Confirmed, result.Status); + Assert.Equal(OrganizationUserType.Owner, result.Type); + Assert.Equal(organizationDetails.SsoEnabled, result.SsoEnabled); + Assert.False(result.KeyConnectorEnabled); + Assert.Null(result.KeyConnectorUrl); + Assert.Equal(MemberDecryptionType.TrustedDeviceEncryption, result.SsoMemberDecryptionType); + Assert.False(result.SsoBound); + Assert.NotNull(result.Permissions); + Assert.False(result.Permissions.ManageUsers); + Assert.False(result.ResetPasswordEnrolled); + Assert.False(result.AccessSecretsManager); + } +} diff --git a/test/Api.Test/AdminConsole/Public/Controllers/PoliciesControllerTests.cs b/test/Api.Test/AdminConsole/Public/Controllers/PoliciesControllerTests.cs new file mode 100644 index 0000000000..c2360f5f9a --- /dev/null +++ b/test/Api.Test/AdminConsole/Public/Controllers/PoliciesControllerTests.cs @@ -0,0 +1,87 @@ +using Bit.Api.AdminConsole.Public.Controllers; +using Bit.Api.AdminConsole.Public.Models.Request; +using Bit.Core; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; +using Bit.Core.Context; +using Bit.Core.Services; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Api.Test.AdminConsole.Public.Controllers; + +[ControllerCustomize(typeof(PoliciesController))] +[SutProviderCustomize] +public class PoliciesControllerTests +{ + [Theory] + [BitAutoData] + public async Task Put_WhenPolicyValidatorsRefactorEnabled_UsesVNextSavePolicyCommand( + Guid organizationId, + PolicyType policyType, + PolicyUpdateRequestModel model, + Policy policy, + SutProvider sutProvider) + { + // Arrange + policy.Data = null; + sutProvider.GetDependency() + .OrganizationId.Returns(organizationId); + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor) + .Returns(true); + sutProvider.GetDependency() + .SaveAsync(Arg.Any()) + .Returns(policy); + + // Act + await sutProvider.Sut.Put(policyType, model); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .SaveAsync(Arg.Is(m => + m.PolicyUpdate.OrganizationId == organizationId && + m.PolicyUpdate.Type == policyType && + m.PolicyUpdate.Enabled == model.Enabled.GetValueOrDefault() && + m.PerformedBy is SystemUser)); + } + + [Theory] + [BitAutoData] + public async Task Put_WhenPolicyValidatorsRefactorDisabled_UsesLegacySavePolicyCommand( + Guid organizationId, + PolicyType policyType, + PolicyUpdateRequestModel model, + Policy policy, + SutProvider sutProvider) + { + // Arrange + policy.Data = null; + sutProvider.GetDependency() + .OrganizationId.Returns(organizationId); + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PolicyValidatorsRefactor) + .Returns(false); + sutProvider.GetDependency() + .SaveAsync(Arg.Any()) + .Returns(policy); + + // Act + await sutProvider.Sut.Put(policyType, model); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .SaveAsync(Arg.Is(p => + p.OrganizationId == organizationId && + p.Type == policyType && + p.Enabled == model.Enabled)); + } +} diff --git a/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs b/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs index e81d51281d..f1aa11d068 100644 --- a/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs +++ b/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs @@ -11,6 +11,8 @@ using Bit.Core.Auth.UserFeatures.UserMasterPassword.Interfaces; using Bit.Core.Entities; using Bit.Core.Exceptions; using Bit.Core.KeyManagement.Kdf; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Queries.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Test.Common.AutoFixture.Attributes; @@ -33,10 +35,10 @@ public class AccountsControllerTests : IDisposable private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; private readonly ITdeOffboardingPasswordCommand _tdeOffboardingPasswordCommand; private readonly IFeatureService _featureService; + private readonly IUserAccountKeysQuery _userAccountKeysQuery; private readonly ITwoFactorEmailService _twoFactorEmailService; private readonly IChangeKdfCommand _changeKdfCommand; - public AccountsControllerTests() { _userService = Substitute.For(); @@ -48,6 +50,7 @@ public class AccountsControllerTests : IDisposable _twoFactorIsEnabledQuery = Substitute.For(); _tdeOffboardingPasswordCommand = Substitute.For(); _featureService = Substitute.For(); + _userAccountKeysQuery = Substitute.For(); _twoFactorEmailService = Substitute.For(); _changeKdfCommand = Substitute.For(); @@ -61,6 +64,7 @@ public class AccountsControllerTests : IDisposable _tdeOffboardingPasswordCommand, _twoFactorIsEnabledQuery, _featureService, + _userAccountKeysQuery, _twoFactorEmailService, _changeKdfCommand ); @@ -614,6 +618,16 @@ public class AccountsControllerTests : IDisposable await _twoFactorEmailService.Received(1).SendNewDeviceVerificationEmailAsync(user); } + [Theory] + [BitAutoData] + public async Task PostKdf_UserNotFound_ShouldFail(PasswordRequestModel model) + { + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(Task.FromResult(null)); + + // Act + await Assert.ThrowsAsync(() => _sut.PostKdf(model)); + } + [Theory] [BitAutoData] public async Task PostKdf_WithNullAuthenticationData_ShouldFail( @@ -623,7 +637,9 @@ public class AccountsControllerTests : IDisposable model.AuthenticationData = null; // Act - await Assert.ThrowsAsync(() => _sut.PostKdf(model)); + var exception = await Assert.ThrowsAsync(() => _sut.PostKdf(model)); + + Assert.Contains("AuthenticationData and UnlockData must be provided.", exception.Message); } [Theory] @@ -635,7 +651,41 @@ public class AccountsControllerTests : IDisposable model.UnlockData = null; // Act - await Assert.ThrowsAsync(() => _sut.PostKdf(model)); + var exception = await Assert.ThrowsAsync(() => _sut.PostKdf(model)); + + Assert.Contains("AuthenticationData and UnlockData must be provided.", exception.Message); + } + + [Theory] + [BitAutoData] + public async Task PostKdf_ChangeKdfFailed_ShouldFail( + User user, PasswordRequestModel model) + { + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(Task.FromResult(user)); + _changeKdfCommand.ChangeKdfAsync(Arg.Any(), Arg.Any(), + Arg.Any(), Arg.Any()) + .Returns(Task.FromResult(IdentityResult.Failed(new IdentityError { Description = "Change KDF failed" }))); + + // Act + var exception = await Assert.ThrowsAsync(() => _sut.PostKdf(model)); + + Assert.NotNull(exception.ModelState); + Assert.Contains("Change KDF failed", + exception.ModelState.Values.SelectMany(x => x.Errors).Select(x => x.ErrorMessage)); + } + + [Theory] + [BitAutoData] + public async Task PostKdf_ChangeKdfSuccess_NoError( + User user, PasswordRequestModel model) + { + _userService.GetUserByPrincipalAsync(Arg.Any()).Returns(Task.FromResult(user)); + _changeKdfCommand.ChangeKdfAsync(Arg.Any(), Arg.Any(), + Arg.Any(), Arg.Any()) + .Returns(Task.FromResult(IdentityResult.Success)); + + // Act + await _sut.PostKdf(model); } // Below are helper functions that currently belong to this diff --git a/test/Api.Test/Auth/Models/Request/OrganizationSsoRequestModelTests.cs b/test/Api.Test/Auth/Models/Request/OrganizationSsoRequestModelTests.cs new file mode 100644 index 0000000000..8348ba885d --- /dev/null +++ b/test/Api.Test/Auth/Models/Request/OrganizationSsoRequestModelTests.cs @@ -0,0 +1,313 @@ +using System.ComponentModel.DataAnnotations; +using Bit.Api.Auth.Models.Request.Organizations; +using Bit.Core.Auth.Entities; +using Bit.Core.Auth.Enums; +using Bit.Core.Services; +using Bit.Core.Sso; +using Microsoft.Extensions.Localization; +using NSubstitute; +using Xunit; + +namespace Bit.Api.Test.Auth.Models.Request; + +public class OrganizationSsoRequestModelTests +{ + [Fact] + public void ToSsoConfig_WithOrganizationId_CreatesNewSsoConfig() + { + // Arrange + var organizationId = Guid.NewGuid(); + var model = new OrganizationSsoRequestModel + { + Enabled = true, + Identifier = "test-identifier", + Data = new SsoConfigurationDataRequest + { + ConfigType = SsoType.OpenIdConnect, + Authority = "https://example.com", + ClientId = "test-client", + ClientSecret = "test-secret" + } + }; + + // Act + var result = model.ToSsoConfig(organizationId); + + // Assert + Assert.NotNull(result); + Assert.Equal(organizationId, result.OrganizationId); + Assert.True(result.Enabled); + } + + [Fact] + public void ToSsoConfig_WithExistingConfig_UpdatesExistingConfig() + { + // Arrange + var organizationId = Guid.NewGuid(); + var existingConfig = new SsoConfig + { + Id = 1, + OrganizationId = organizationId, + Enabled = false + }; + + var model = new OrganizationSsoRequestModel + { + Enabled = true, + Identifier = "updated-identifier", + Data = new SsoConfigurationDataRequest + { + ConfigType = SsoType.Saml2, + IdpEntityId = "test-entity", + IdpSingleSignOnServiceUrl = "https://sso.example.com" + } + }; + + // Act + var result = model.ToSsoConfig(existingConfig); + + // Assert + Assert.Same(existingConfig, result); + Assert.Equal(organizationId, result.OrganizationId); + Assert.True(result.Enabled); + } +} + +public class SsoConfigurationDataRequestTests +{ + private readonly TestI18nService _i18nService; + private readonly ValidationContext _validationContext; + + public SsoConfigurationDataRequestTests() + { + _i18nService = new TestI18nService(); + var serviceProvider = Substitute.For(); + serviceProvider.GetService(typeof(II18nService)).Returns(_i18nService); + _validationContext = new ValidationContext(new object(), serviceProvider, null); + } + + [Fact] + public void ToConfigurationData_MapsProperties() + { + // Arrange + var model = new SsoConfigurationDataRequest + { + ConfigType = SsoType.OpenIdConnect, + MemberDecryptionType = MemberDecryptionType.KeyConnector, + Authority = "https://authority.example.com", + ClientId = "test-client-id", + ClientSecret = "test-client-secret", + IdpX509PublicCert = "-----BEGIN CERTIFICATE-----\nMIIC...test\n-----END CERTIFICATE-----", + SpOutboundSigningAlgorithm = null // Test default + }; + + // Act + var result = model.ToConfigurationData(); + + // Assert + Assert.Equal(SsoType.OpenIdConnect, result.ConfigType); + Assert.Equal(MemberDecryptionType.KeyConnector, result.MemberDecryptionType); + Assert.Equal("https://authority.example.com", result.Authority); + Assert.Equal("test-client-id", result.ClientId); + Assert.Equal("test-client-secret", result.ClientSecret); + Assert.Equal("MIIC...test", result.IdpX509PublicCert); // PEM headers stripped + Assert.Equal(SamlSigningAlgorithms.Sha256, result.SpOutboundSigningAlgorithm); // Default applied + Assert.Null(result.IdpArtifactResolutionServiceUrl); // Always null + } + + [Fact] + public void KeyConnectorEnabled_Setter_UpdatesMemberDecryptionType() + { + // Arrange + var model = new SsoConfigurationDataRequest(); + + // Act & Assert +#pragma warning disable CS0618 // Type or member is obsolete + model.KeyConnectorEnabled = true; + Assert.Equal(MemberDecryptionType.KeyConnector, model.MemberDecryptionType); + + model.KeyConnectorEnabled = false; + Assert.Equal(MemberDecryptionType.MasterPassword, model.MemberDecryptionType); +#pragma warning restore CS0618 // Type or member is obsolete + } + + // Validation Tests + [Fact] + public void Validate_OpenIdConnect_ValidData_NoErrors() + { + // Arrange + var model = new SsoConfigurationDataRequest + { + ConfigType = SsoType.OpenIdConnect, + Authority = "https://example.com", + ClientId = "test-client", + ClientSecret = "test-secret" + }; + + // Act + var results = model.Validate(_validationContext).ToList(); + + // Assert + Assert.Empty(results); + } + + [Theory] + [InlineData("", "test-client", "test-secret", "AuthorityValidationError")] + [InlineData("https://example.com", "", "test-secret", "ClientIdValidationError")] + [InlineData("https://example.com", "test-client", "", "ClientSecretValidationError")] + public void Validate_OpenIdConnect_MissingRequiredFields_ReturnsErrors(string authority, string clientId, string clientSecret, string expectedError) + { + // Arrange + var model = new SsoConfigurationDataRequest + { + ConfigType = SsoType.OpenIdConnect, + Authority = authority, + ClientId = clientId, + ClientSecret = clientSecret + }; + + // Act + var results = model.Validate(_validationContext).ToList(); + + // Assert + Assert.Single(results); + Assert.Equal(expectedError, results[0].ErrorMessage); + } + + [Fact] + public void Validate_Saml2_ValidData_NoErrors() + { + // Arrange + var model = new SsoConfigurationDataRequest + { + ConfigType = SsoType.Saml2, + IdpEntityId = "https://idp.example.com", + IdpSingleSignOnServiceUrl = "https://sso.example.com", + IdpSingleLogoutServiceUrl = "https://logout.example.com" + }; + + // Act + var results = model.Validate(_validationContext).ToList(); + + // Assert + Assert.Empty(results); + } + + [Theory] + [InlineData("", "https://sso.example.com", "IdpEntityIdValidationError")] + [InlineData("not-a-valid-uri", "", "IdpSingleSignOnServiceUrlValidationError")] + public void Validate_Saml2_MissingRequiredFields_ReturnsErrors(string entityId, string signOnUrl, string expectedError) + { + // Arrange + var model = new SsoConfigurationDataRequest + { + ConfigType = SsoType.Saml2, + IdpEntityId = entityId, + IdpSingleSignOnServiceUrl = signOnUrl + }; + + // Act + var results = model.Validate(_validationContext).ToList(); + + // Assert + Assert.Contains(results, r => r.ErrorMessage == expectedError); + } + + [Theory] + [InlineData("not-a-url")] + [InlineData("ftp://example.com")] + [InlineData("https://example.com