diff --git a/.checkmarx/config.yml b/.checkmarx/config.yml index 641da0eacb..e40c43b662 100644 --- a/.checkmarx/config.yml +++ b/.checkmarx/config.yml @@ -11,3 +11,7 @@ checkmarx: filter: "!test" kics: filter: "!dev,!.devcontainer" + sca: + filter: "!dev,!.devcontainer" + containers: + filter: "!dev,!.devcontainer" diff --git a/.claude/settings.json b/.claude/settings.json new file mode 100644 index 0000000000..8cf0d87c5c --- /dev/null +++ b/.claude/settings.json @@ -0,0 +1,10 @@ +{ + "extraKnownMarketplaces": { + "bitwarden-marketplace": { + "source": { + "source": "github", + "repo": "bitwarden/ai-plugins" + } + } + } +} diff --git a/.config/dotnet-tools.json b/.config/dotnet-tools.json index 227f59ad8a..846414bf49 100644 --- a/.config/dotnet-tools.json +++ b/.config/dotnet-tools.json @@ -3,7 +3,7 @@ "isRoot": true, "tools": { "swashbuckle.aspnetcore.cli": { - "version": "9.0.4", + "version": "10.1.0", "commands": ["swagger"] }, "dotnet-ef": { diff --git a/.devcontainer/community_dev/devcontainer.json b/.devcontainer/community_dev/devcontainer.json index ce3b8a21c6..c59ad3b839 100644 --- a/.devcontainer/community_dev/devcontainer.json +++ b/.devcontainer/community_dev/devcontainer.json @@ -3,10 +3,12 @@ "dockerComposeFile": "../../.devcontainer/bitwarden_common/docker-compose.yml", "service": "bitwarden_server", "workspaceFolder": "/workspace", + "initializeCommand": "mkdir -p dev/.data/keys dev/.data/mssql dev/.data/azurite dev/helpers/mssql", "features": { "ghcr.io/devcontainers/features/node:1": { - "version": "16" - } + "version": "22" + }, + "ghcr.io/devcontainers/features/rust:1": {} }, "mounts": [ { @@ -21,5 +23,27 @@ "extensions": ["ms-dotnettools.csdevkit"] } }, - "postCreateCommand": "bash .devcontainer/community_dev/postCreateCommand.sh" + "postCreateCommand": "bash .devcontainer/community_dev/postCreateCommand.sh", + "forwardPorts": [1080, 1433, 3306, 5432], + "portsAttributes": { + "default": { + "onAutoForward": "ignore" + }, + "1080": { + "label": "Mail Catcher", + "onAutoForward": "notify" + }, + "1433": { + "label": "SQL Server", + "onAutoForward": "notify" + }, + "3306": { + "label": "MySQL", + "onAutoForward": "notify" + }, + "5432": { + "label": "PostgreSQL", + "onAutoForward": "notify" + } + } } diff --git a/.devcontainer/community_dev/postCreateCommand.sh b/.devcontainer/community_dev/postCreateCommand.sh index 8f1813ed78..8ae3854168 100755 --- a/.devcontainer/community_dev/postCreateCommand.sh +++ b/.devcontainer/community_dev/postCreateCommand.sh @@ -3,11 +3,46 @@ export DEV_DIR=/workspace/dev export CONTAINER_CONFIG=/workspace/.devcontainer/community_dev git config --global --add safe.directory /workspace +if [[ -z "${CODESPACES}" ]]; then + allow_interactive=1 +else + echo "Doing non-interactive setup" + allow_interactive=0 +fi + +get_option() { + # Helper function for reading the value of an environment variable + # primarily but then falling back to an interactive question if allowed + # and lastly falling back to a default value input when either other + # option is available. + name_of_var="$1" + question_text="$2" + default_value="$3" + is_secret="$4" + + if [[ -n "${!name_of_var}" ]]; then + # If the env variable they gave us has a value, then use that value + echo "${!name_of_var}" + elif [[ "$allow_interactive" == 1 ]]; then + # If we can be interactive, then use the text they gave us to request input + if [[ "$is_secret" == 1 ]]; then + read -r -s -p "$question_text" response + echo "$response" + else + read -r -p "$question_text" response + echo "$response" + fi + else + # If no environment variable and not interactive, then just give back default value + echo "$default_value" + fi +} + get_installation_id_and_key() { pushd ./dev >/dev/null || exit echo "Please enter your installation id and key from https://bitwarden.com/host:" - read -r -p "Installation id: " INSTALLATION_ID - read -r -p "Installation key: " INSTALLATION_KEY + INSTALLATION_ID="$(get_option "INSTALLATION_ID" "Installation id: " "00000000-0000-0000-0000-000000000001")" + INSTALLATION_KEY="$(get_option "INSTALLATION_KEY" "Installation key: " "" 1)" jq ".globalSettings.installation.id = \"$INSTALLATION_ID\" | .globalSettings.installation.key = \"$INSTALLATION_KEY\"" \ secrets.json.example >secrets.json # create/overwrite secrets.json @@ -30,11 +65,10 @@ configure_other_vars() { } one_time_setup() { - read -r -p \ - "Would you like to configure your secrets and certificates for the first time? + do_secrets_json_setup="$(get_option "SETUP_SECRETS_JSON" "Would you like to configure your secrets and certificates for the first time? WARNING: This will overwrite any existing secrets.json and certificate files. -Proceed? [y/N] " response - if [[ "$response" =~ ^([yY][eE][sS]|[yY])+$ ]]; then +Proceed? [y/N] " "n")" + if [[ "$do_secrets_json_setup" =~ ^([yY][eE][sS]|[yY])+$ ]]; then echo "Running one-time setup script..." sleep 1 get_installation_id_and_key @@ -50,11 +84,4 @@ Proceed? [y/N] " response fi } -# main -if [[ -z "${CODESPACES}" ]]; then - one_time_setup -else - # Ignore interactive elements when running in codespaces since they are not supported there - # TODO Write codespaces specific instructions and link here - echo "Running in codespaces, follow instructions here: https://contributing.bitwarden.com/getting-started/server/guide/ to continue the setup" -fi +one_time_setup diff --git a/.devcontainer/internal_dev/devcontainer.json b/.devcontainer/internal_dev/devcontainer.json index 862b9297c4..99e3057024 100644 --- a/.devcontainer/internal_dev/devcontainer.json +++ b/.devcontainer/internal_dev/devcontainer.json @@ -6,10 +6,12 @@ ], "service": "bitwarden_server", "workspaceFolder": "/workspace", + "initializeCommand": "mkdir -p dev/.data/keys dev/.data/mssql dev/.data/azurite dev/helpers/mssql", "features": { "ghcr.io/devcontainers/features/node:1": { - "version": "16" - } + "version": "22" + }, + "ghcr.io/devcontainers/features/rust:1": {} }, "mounts": [ { @@ -24,9 +26,18 @@ "extensions": ["ms-dotnettools.csdevkit"] } }, + "onCreateCommand": "bash .devcontainer/internal_dev/onCreateCommand.sh", "postCreateCommand": "bash .devcontainer/internal_dev/postCreateCommand.sh", - "forwardPorts": [1080, 1433, 3306, 5432, 10000, 10001, 10002], + "forwardPorts": [ + 1080, 1433, 3306, 5432, 10000, 10001, 10002, + 4000, 4001, 33656, 33657, 44519, 44559, + 46273, 46274, 50024, 51822, 51823, + 54103, 61840, 61841, 62911, 62912 + ], "portsAttributes": { + "default": { + "onAutoForward": "ignore" + }, "1080": { "label": "Mail Catcher", "onAutoForward": "notify" @@ -48,12 +59,76 @@ "onAutoForward": "notify" }, "10001": { - "label": "Azurite Storage Queue ", + "label": "Azurite Storage Queue", "onAutoForward": "notify" }, "10002": { "label": "Azurite Storage Table", "onAutoForward": "notify" + }, + "4000": { + "label": "Api (Cloud)", + "onAutoForward": "notify" + }, + "4001": { + "label": "Api (SelfHost)", + "onAutoForward": "notify" + }, + "33656": { + "label": "Identity (Cloud)", + "onAutoForward": "notify" + }, + "33657": { + "label": "Identity (SelfHost)", + "onAutoForward": "notify" + }, + "44519": { + "label": "Billing", + "onAutoForward": "notify" + }, + "44559": { + "label": "Scim", + "onAutoForward": "notify" + }, + "46273": { + "label": "Events (Cloud)", + "onAutoForward": "notify" + }, + "46274": { + "label": "Events (SelfHost)", + "onAutoForward": "notify" + }, + "50024": { + "label": "Icons", + "onAutoForward": "notify" + }, + "51822": { + "label": "Sso (Cloud)", + "onAutoForward": "notify" + }, + "51823": { + "label": "Sso (SelfHost)", + "onAutoForward": "notify" + }, + "54103": { + "label": "EventsProcessor", + "onAutoForward": "notify" + }, + "61840": { + "label": "Notifications (Cloud)", + "onAutoForward": "notify" + }, + "61841": { + "label": "Notifications (SelfHost)", + "onAutoForward": "notify" + }, + "62911": { + "label": "Admin (Cloud)", + "onAutoForward": "notify" + }, + "62912": { + "label": "Admin (SelfHost)", + "onAutoForward": "notify" } } } diff --git a/.devcontainer/internal_dev/onCreateCommand.sh b/.devcontainer/internal_dev/onCreateCommand.sh new file mode 100644 index 0000000000..71d466aae9 --- /dev/null +++ b/.devcontainer/internal_dev/onCreateCommand.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash +export REPO_ROOT="$(git rev-parse --show-toplevel)" + +file="$REPO_ROOT/dev/custom-root-ca.crt" + +if [ -e "$file" ]; then + echo "Adding custom root CA" + sudo cp "$file" /usr/local/share/ca-certificates/ + sudo update-ca-certificates +else + echo "No custom root CA found, skipping..." +fi diff --git a/.devcontainer/internal_dev/postCreateCommand.sh b/.devcontainer/internal_dev/postCreateCommand.sh index 3fd278be26..ceef0ef0f5 100755 --- a/.devcontainer/internal_dev/postCreateCommand.sh +++ b/.devcontainer/internal_dev/postCreateCommand.sh @@ -108,7 +108,7 @@ Press to continue." fi run_mssql_migrations="$(get_option "RUN_MSSQL_MIGRATIONS" "Would you like us to run MSSQL Migrations for you? [y/N] " "n")" - if [[ "$do_azurite_setup" =~ ^([yY][eE][sS]|[yY])+$ ]]; then + if [[ "$run_mssql_migrations" =~ ^([yY][eE][sS]|[yY])+$ ]]; then echo "Running migrations..." sleep 5 # wait for DB container to start dotnet run --project "$REPO_ROOT/util/MsSqlMigratorUtility" "$SQL_CONNECTION_STRING" diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index f0c85d98c1..c3e95e724b 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -11,6 +11,9 @@ **/docker-compose.yml @bitwarden/team-appsec @bitwarden/dept-bre **/entrypoint.sh @bitwarden/team-appsec @bitwarden/dept-bre +# Scanning tools +.checkmarx/ @bitwarden/team-appsec + ## BRE team owns these workflows ## .github/workflows/publish.yml @bitwarden/dept-bre @@ -94,9 +97,7 @@ src/Admin/Views/Tools @bitwarden/team-billing-dev .github/workflows/test-database.yml @bitwarden/team-platform-dev .github/workflows/test.yml @bitwarden/team-platform-dev **/*Platform* @bitwarden/team-platform-dev -**/.dockerignore @bitwarden/team-platform-dev -**/Dockerfile @bitwarden/team-platform-dev -**/entrypoint.sh @bitwarden/team-platform-dev + # The PushType enum is expected to be editted by anyone without need for Platform review src/Core/Platform/Push/PushType.cs diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index edbc9d98cc..224020991d 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -9,27 +9,3 @@ ## 📸 Screenshots - -## ⏰ Reminders before review - -- Contributor guidelines followed -- All formatters and local linters executed and passed -- Written new unit and / or integration tests where applicable -- Protected functional changes with optionality (feature flags) -- Used internationalization (i18n) for all UI strings -- CI builds passed -- Communicated to DevOps any deployment requirements -- Updated any necessary documentation (Confluence, contributing docs) or informed the documentation team - -## 🦮 Reviewer guidelines - - - -- 👍 (`:+1:`) or similar for great changes -- 📝 (`:memo:`) or ℹ️ (`:information_source:`) for notes or general info -- ❓ (`:question:`) for questions -- 🤔 (`:thinking:`) or 💭 (`:thought_balloon:`) for more open inquiry that's not quite a confirmed issue and could potentially benefit from discussion -- 🎨 (`:art:`) for suggestions / improvements -- ❌ (`:x:`) or ⚠️ (`:warning:`) for more significant problems or concerns needing attention -- 🌱 (`:seedling:`) or ♻️ (`:recycle:`) for future improvements or indications of technical debt -- ⛏ (`:pick:`) for minor or nitpick changes diff --git a/.github/renovate.json5 b/.github/renovate.json5 index 77539ef839..0796c4dbdf 100644 --- a/.github/renovate.json5 +++ b/.github/renovate.json5 @@ -21,12 +21,6 @@ commitMessagePrefix: "[deps] AC:", reviewers: ["team:team-admin-console-dev"], }, - { - matchFileNames: ["src/Admin/package.json", "src/Sso/package.json"], - description: "Admin & SSO npm packages", - commitMessagePrefix: "[deps] Auth:", - reviewers: ["team:team-auth-dev"], - }, { matchPackageNames: [ "DuoUniversal", @@ -182,6 +176,14 @@ matchUpdateTypes: ["minor"], addLabels: ["hold"], }, + { + groupName: "Admin and SSO npm dependencies", + matchFileNames: ["src/Admin/package.json", "src/Sso/package.json"], + matchUpdateTypes: ["minor", "patch"], + description: "Admin & SSO npm packages", + commitMessagePrefix: "[deps] Auth:", + reviewers: ["team:team-auth-dev"], + }, { matchPackageNames: ["/^Microsoft\\.EntityFrameworkCore\\./", "/^dotnet-ef/"], groupName: "EntityFrameworkCore", diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a7717be4e8..bf9778651a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -31,7 +31,7 @@ jobs: persist-credentials: false - name: Set up .NET - uses: actions/setup-dotnet@2016bd2012dba4e32de620c46fe006a3ac9f0602 # v5.0.1 + uses: actions/setup-dotnet@baa11fbfe1d6520db94683bd5c7a3818018e4309 # v5.1.0 - name: Verify format run: dotnet format --verify-no-changes @@ -119,10 +119,10 @@ jobs: fi - name: Set up .NET - uses: actions/setup-dotnet@2016bd2012dba4e32de620c46fe006a3ac9f0602 # v5.0.1 + uses: actions/setup-dotnet@baa11fbfe1d6520db94683bd5c7a3818018e4309 # v5.1.0 - name: Set up Node - uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # v6.1.0 + uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # v6.2.0 with: cache: "npm" cache-dependency-path: "**/package-lock.json" @@ -245,7 +245,7 @@ jobs: - name: Install Cosign if: github.event_name != 'pull_request' && github.ref == 'refs/heads/main' - uses: sigstore/cosign-installer@7e8b541eb2e61bf99390e1afd4be13a184e9ebc5 # v3.10.1 + uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0 - name: Sign image with Cosign if: github.event_name != 'pull_request' && github.ref == 'refs/heads/main' @@ -263,14 +263,14 @@ jobs: - name: Scan Docker image id: container-scan - uses: anchore/scan-action@3c9a191a0fbab285ca6b8530b5de5a642cba332f # v7.2.2 + uses: anchore/scan-action@0d444ed77d83ee2ba7f5ced0d90d640a1281d762 # v7.3.0 with: image: ${{ steps.image-tags.outputs.primary_tag }} fail-build: false output-format: sarif - name: Upload Grype results to GitHub - uses: github/codeql-action/upload-sarif@5d4e8d1aca955e8d8589aabd499c5cae939e33c7 # v4.31.9 + uses: github/codeql-action/upload-sarif@cdefb33c0f6224e58673d9004f47f7cb3e328b89 # v4.31.10 with: sarif_file: ${{ steps.container-scan.outputs.sarif }} sha: ${{ contains(github.event_name, 'pull_request') && github.event.pull_request.head.sha || github.sha }} @@ -294,7 +294,7 @@ jobs: persist-credentials: false - name: Set up .NET - uses: actions/setup-dotnet@2016bd2012dba4e32de620c46fe006a3ac9f0602 # v5.0.1 + uses: actions/setup-dotnet@baa11fbfe1d6520db94683bd5c7a3818018e4309 # v5.1.0 - name: Log in to Azure uses: bitwarden/gh-actions/azure-login@main @@ -420,7 +420,7 @@ jobs: persist-credentials: false - name: Set up .NET - uses: actions/setup-dotnet@2016bd2012dba4e32de620c46fe006a3ac9f0602 # v5.0.1 + uses: actions/setup-dotnet@baa11fbfe1d6520db94683bd5c7a3818018e4309 # v5.1.0 - name: Print environment run: | diff --git a/.github/workflows/test-database.yml b/.github/workflows/test-database.yml index 4630c18e40..25ff9d0488 100644 --- a/.github/workflows/test-database.yml +++ b/.github/workflows/test-database.yml @@ -49,7 +49,7 @@ jobs: persist-credentials: false - name: Set up .NET - uses: actions/setup-dotnet@2016bd2012dba4e32de620c46fe006a3ac9f0602 # v5.0.1 + uses: actions/setup-dotnet@baa11fbfe1d6520db94683bd5c7a3818018e4309 # v5.1.0 - name: Restore tools run: dotnet tool restore @@ -156,7 +156,7 @@ jobs: run: 'docker logs "$(docker ps --quiet --filter "name=mssql")"' - name: Report test results - uses: dorny/test-reporter@fe45e9537387dac839af0d33ba56eed8e24189e8 # v2.3.0 + uses: dorny/test-reporter@b082adf0eced0765477756c2a610396589b8c637 # v2.5.0 if: ${{ github.event.pull_request.head.repo.full_name == github.repository && !cancelled() }} with: name: Test Results @@ -183,7 +183,7 @@ jobs: persist-credentials: false - name: Set up .NET - uses: actions/setup-dotnet@2016bd2012dba4e32de620c46fe006a3ac9f0602 # v5.0.1 + uses: actions/setup-dotnet@baa11fbfe1d6520db94683bd5c7a3818018e4309 # v5.1.0 - name: Print environment run: | diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a6d07bb650..12b5355c33 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -32,7 +32,7 @@ jobs: persist-credentials: false - name: Set up .NET - uses: actions/setup-dotnet@2016bd2012dba4e32de620c46fe006a3ac9f0602 # v5.0.1 + uses: actions/setup-dotnet@baa11fbfe1d6520db94683bd5c7a3818018e4309 # v5.1.0 - name: Install rust uses: dtolnay/rust-toolchain@f7ccc83f9ed1e5b9c81d8a67d7ad1a747e22a561 # stable @@ -59,7 +59,7 @@ jobs: run: dotnet test ./bitwarden_license/test --configuration Debug --logger "trx;LogFileName=bw-test-results.trx" /p:CoverletOutputFormatter="cobertura" --collect:"XPlat Code Coverage" - name: Report test results - uses: dorny/test-reporter@fe45e9537387dac839af0d33ba56eed8e24189e8 # v2.3.0 + uses: dorny/test-reporter@b082adf0eced0765477756c2a610396589b8c637 # v2.5.0 if: ${{ github.event.pull_request.head.repo.full_name == github.repository && !cancelled() }} with: name: Test Results diff --git a/Directory.Build.props b/Directory.Build.props index e7a8422605..bad9b2f9ad 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -3,7 +3,7 @@ net8.0 - 2026.1.0 + 2026.2.0 Bit.$(MSBuildProjectName) enable @@ -13,6 +13,10 @@ true + + false + false + diff --git a/bitwarden_license/src/Scim/Scim.csproj b/bitwarden_license/src/Scim/Scim.csproj index 7d1ea317b2..d3858e1225 100644 --- a/bitwarden_license/src/Scim/Scim.csproj +++ b/bitwarden_license/src/Scim/Scim.csproj @@ -1,4 +1,5 @@  + bitwarden-Scim diff --git a/bitwarden_license/src/Scim/appsettings.Production.json b/bitwarden_license/src/Scim/appsettings.Production.json index d9efbcda12..a6578c08dc 100644 --- a/bitwarden_license/src/Scim/appsettings.Production.json +++ b/bitwarden_license/src/Scim/appsettings.Production.json @@ -23,11 +23,9 @@ } }, "Logging": { - "IncludeScopes": false, "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" + "Default": "Information", + "Microsoft.AspNetCore": "Warning" }, "Console": { "IncludeScopes": true, diff --git a/bitwarden_license/src/Sso/Controllers/AccountController.cs b/bitwarden_license/src/Sso/Controllers/AccountController.cs index dde2ac7a46..3d998b6a75 100644 --- a/bitwarden_license/src/Sso/Controllers/AccountController.cs +++ b/bitwarden_license/src/Sso/Controllers/AccountController.cs @@ -2,7 +2,7 @@ using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; -using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.Auth.Entities; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models; @@ -45,7 +45,7 @@ public class AccountController : Controller private readonly ISsoConfigRepository _ssoConfigRepository; private readonly ISsoUserRepository _ssoUserRepository; private readonly IUserRepository _userRepository; - private readonly IPolicyRepository _policyRepository; + private readonly IPolicyQuery _policyQuery; private readonly IUserService _userService; private readonly II18nService _i18nService; private readonly UserManager _userManager; @@ -67,7 +67,7 @@ public class AccountController : Controller ISsoConfigRepository ssoConfigRepository, ISsoUserRepository ssoUserRepository, IUserRepository userRepository, - IPolicyRepository policyRepository, + IPolicyQuery policyQuery, IUserService userService, II18nService i18nService, UserManager userManager, @@ -88,7 +88,7 @@ public class AccountController : Controller _userRepository = userRepository; _ssoConfigRepository = ssoConfigRepository; _ssoUserRepository = ssoUserRepository; - _policyRepository = policyRepository; + _policyQuery = policyQuery; _userService = userService; _i18nService = i18nService; _userManager = userManager; @@ -687,9 +687,8 @@ public class AccountController : Controller await _registerUserCommand.RegisterSSOAutoProvisionedUserAsync(newUser, organization); // If the organization has 2fa policy enabled, make sure to default jit user 2fa to email - var twoFactorPolicy = - await _policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.TwoFactorAuthentication); - if (twoFactorPolicy != null && twoFactorPolicy.Enabled) + var twoFactorPolicy = await _policyQuery.RunAsync(organization.Id, PolicyType.TwoFactorAuthentication); + if (twoFactorPolicy.Enabled) { newUser.SetTwoFactorProviders(new Dictionary { diff --git a/bitwarden_license/src/Sso/Sso.csproj b/bitwarden_license/src/Sso/Sso.csproj index 2a1c14ae5a..709e8c2c4a 100644 --- a/bitwarden_license/src/Sso/Sso.csproj +++ b/bitwarden_license/src/Sso/Sso.csproj @@ -1,4 +1,5 @@  + bitwarden-Sso diff --git a/bitwarden_license/src/Sso/Startup.cs b/bitwarden_license/src/Sso/Startup.cs index a2f363d533..c4c676d51f 100644 --- a/bitwarden_license/src/Sso/Startup.cs +++ b/bitwarden_license/src/Sso/Startup.cs @@ -8,7 +8,6 @@ using Bit.Core.Utilities; using Bit.SharedWeb.Utilities; using Bit.Sso.Utilities; using Duende.IdentityServer.Services; -using Microsoft.IdentityModel.Logging; using Stripe; namespace Bit.Sso; @@ -91,20 +90,15 @@ public class Startup public void Configure( IApplicationBuilder app, - IWebHostEnvironment env, + IWebHostEnvironment environment, IHostApplicationLifetime appLifetime, GlobalSettings globalSettings, ILogger logger) { - if (env.IsDevelopment() || globalSettings.SelfHosted) - { - IdentityModelEventSource.ShowPII = true; - } - // Add general security headers app.UseMiddleware(); - if (!env.IsDevelopment()) + if (!environment.IsDevelopment()) { var uri = new Uri(globalSettings.BaseServiceUri.Sso); app.Use(async (ctx, next) => @@ -120,7 +114,7 @@ public class Startup app.UseForwardedHeaders(globalSettings); } - if (env.IsDevelopment()) + if (environment.IsDevelopment()) { app.UseDeveloperExceptionPage(); app.UseCookiePolicy(); diff --git a/dev/.env.example b/dev/.env.example index f31b5b9eeb..88fbd44036 100644 --- a/dev/.env.example +++ b/dev/.env.example @@ -34,4 +34,5 @@ RABBITMQ_DEFAULT_PASS=SET_A_PASSWORD_HERE_123 # SETUP_AZURITE=yes # RUN_MSSQL_MIGRATIONS=yes # DEV_CERT_PASSWORD=dev_cert_password_here +# DEV_CERT_CONTENTS=base64_encoded_dev_pfx_here (alternative to placing dev.pfx file manually) # INSTALL_STRIPE_CLI=no diff --git a/dev/.gitignore b/dev/.gitignore index 39b657f453..034b002f7c 100644 --- a/dev/.gitignore +++ b/dev/.gitignore @@ -18,3 +18,4 @@ signingkey.jwk # Reverse Proxy Conifg reverse-proxy.conf +*.crt diff --git a/dev/docker-compose.yml b/dev/docker-compose.yml index c82da051b4..34cdd3fd2d 100644 --- a/dev/docker-compose.yml +++ b/dev/docker-compose.yml @@ -77,6 +77,7 @@ services: - 4306:3306 environment: MARIADB_USER: maria + MARIADB_PASSWORD: ${MARIADB_ROOT_PASSWORD} MARIADB_DATABASE: vault_dev MARIADB_RANDOM_ROOT_PASSWORD: "true" volumes: diff --git a/dev/secrets.json.example b/dev/secrets.json.example index 0d4213aec1..7bf753e938 100644 --- a/dev/secrets.json.example +++ b/dev/secrets.json.example @@ -39,6 +39,14 @@ }, "licenseDirectory": "", "enableNewDeviceVerification": true, - "enableEmailVerification": true + "enableEmailVerification": true, + "communication": { + "bootstrap": "none", + "ssoCookieVendor": { + "idpLoginUrl": "", + "cookieName": "", + "cookieDomain": "" + } + } } } diff --git a/dev/setup_secrets.ps1 b/dev/setup_secrets.ps1 index 5013ca8bac..a41890bc46 100755 --- a/dev/setup_secrets.ps1 +++ b/dev/setup_secrets.ps1 @@ -28,6 +28,7 @@ $projects = @{ Scim = "../bitwarden_license/src/Scim" IntegrationTests = "../test/Infrastructure.IntegrationTest" SeederApi = "../util/SeederApi" + SeederUtility = "../util/DbSeederUtility" } foreach ($key in $projects.keys) { diff --git a/dev/verify_migrations.ps1 b/dev/verify_migrations.ps1 index ad0d34cef1..ce1754e684 100644 --- a/dev/verify_migrations.ps1 +++ b/dev/verify_migrations.ps1 @@ -5,12 +5,19 @@ Validates that new database migration files follow naming conventions and chronological order. .DESCRIPTION - This script validates migration files in util/Migrator/DbScripts/ to ensure: + This script validates migration files to ensure: + + For SQL migrations in util/Migrator/DbScripts/: 1. New migrations follow the naming format: YYYY-MM-DD_NN_Description.sql 2. New migrations are chronologically ordered (filename sorts after existing migrations) 3. Dates use leading zeros (e.g., 2025-01-05, not 2025-1-5) 4. A 2-digit sequence number is included (e.g., _00, _01) + For Entity Framework migrations in util/MySqlMigrations, util/PostgresMigrations, util/SqliteMigrations: + 1. New migrations follow the naming format: YYYYMMDDHHMMSS_Description.cs + 2. Each migration has both .cs and .Designer.cs files + 3. New migrations are chronologically ordered (timestamp sorts after existing migrations) + .PARAMETER BaseRef The base git reference to compare against (e.g., 'main', 'HEAD~1') @@ -58,75 +65,288 @@ $currentMigrations = git ls-tree -r --name-only $CurrentRef -- "$migrationPath/" # Find added migrations $addedMigrations = $currentMigrations | Where-Object { $_ -notin $baseMigrations } +$sqlValidationFailed = $false + if ($addedMigrations.Count -eq 0) { - Write-Host "No new migration files added." - exit 0 + Write-Host "No new SQL migration files added." + Write-Host "" +} +else { + Write-Host "New SQL migration files detected:" + $addedMigrations | ForEach-Object { Write-Host " $_" } + Write-Host "" + + # Get the last migration from base reference + if ($baseMigrations.Count -eq 0) { + Write-Host "No previous SQL migrations found (initial commit?). Skipping chronological validation." + Write-Host "" + } + else { + $lastBaseMigration = Split-Path -Leaf ($baseMigrations | Select-Object -Last 1) + Write-Host "Last SQL migration in base reference: $lastBaseMigration" + Write-Host "" + + # Required format regex: YYYY-MM-DD_NN_Description.sql + $formatRegex = '^[0-9]{4}-[0-9]{2}-[0-9]{2}_[0-9]{2}_.+\.sql$' + + foreach ($migration in $addedMigrations) { + $migrationName = Split-Path -Leaf $migration + + # Validate NEW migration filename format + if ($migrationName -notmatch $formatRegex) { + Write-Host "ERROR: Migration '$migrationName' does not match required format" + Write-Host "Required format: YYYY-MM-DD_NN_Description.sql" + Write-Host " - YYYY: 4-digit year" + Write-Host " - MM: 2-digit month with leading zero (01-12)" + Write-Host " - DD: 2-digit day with leading zero (01-31)" + Write-Host " - NN: 2-digit sequence number (00, 01, 02, etc.)" + Write-Host "Example: 2025-01-15_00_MyMigration.sql" + $sqlValidationFailed = $true + continue + } + + # Compare migration name with last base migration (using ordinal string comparison) + if ([string]::CompareOrdinal($migrationName, $lastBaseMigration) -lt 0) { + Write-Host "ERROR: New migration '$migrationName' is not chronologically after '$lastBaseMigration'" + $sqlValidationFailed = $true + } + else { + Write-Host "OK: '$migrationName' is chronologically after '$lastBaseMigration'" + } + } + + Write-Host "" + } + + if ($sqlValidationFailed) { + Write-Host "FAILED: One or more SQL migrations are incorrectly named or not in chronological order" + Write-Host "" + Write-Host "All new SQL migration files must:" + Write-Host " 1. Follow the naming format: YYYY-MM-DD_NN_Description.sql" + Write-Host " 2. Use leading zeros in dates (e.g., 2025-01-05, not 2025-1-5)" + Write-Host " 3. Include a 2-digit sequence number (e.g., _00, _01)" + Write-Host " 4. Have a filename that sorts after the last migration in base" + Write-Host "" + Write-Host "To fix this issue:" + Write-Host " 1. Locate your migration file(s) in util/Migrator/DbScripts/" + Write-Host " 2. Rename to follow format: YYYY-MM-DD_NN_Description.sql" + Write-Host " 3. Ensure the date is after $lastBaseMigration" + Write-Host "" + Write-Host "Example: 2025-01-15_00_AddNewFeature.sql" + } + else { + Write-Host "SUCCESS: All new SQL migrations are correctly named and in chronological order" + } + + Write-Host "" } -Write-Host "New migration files detected:" -$addedMigrations | ForEach-Object { Write-Host " $_" } +# =========================================================================================== +# Validate Entity Framework Migrations +# =========================================================================================== + +Write-Host "===================================================================" +Write-Host "Validating Entity Framework Migrations" +Write-Host "===================================================================" Write-Host "" -# Get the last migration from base reference -if ($baseMigrations.Count -eq 0) { - Write-Host "No previous migrations found (initial commit?). Skipping validation." - exit 0 -} +$efMigrationPaths = @( + @{ Path = "util/MySqlMigrations/Migrations"; Name = "MySQL" }, + @{ Path = "util/PostgresMigrations/Migrations"; Name = "Postgres" }, + @{ Path = "util/SqliteMigrations/Migrations"; Name = "SQLite" } +) -$lastBaseMigration = Split-Path -Leaf ($baseMigrations | Select-Object -Last 1) -Write-Host "Last migration in base reference: $lastBaseMigration" -Write-Host "" +$efValidationFailed = $false -# Required format regex: YYYY-MM-DD_NN_Description.sql -$formatRegex = '^[0-9]{4}-[0-9]{2}-[0-9]{2}_[0-9]{2}_.+\.sql$' +foreach ($migrationPathInfo in $efMigrationPaths) { + $efPath = $migrationPathInfo.Path + $dbName = $migrationPathInfo.Name -$validationFailed = $false + Write-Host "-------------------------------------------------------------------" + Write-Host "Checking $dbName EF migrations in $efPath" + Write-Host "-------------------------------------------------------------------" + Write-Host "" -foreach ($migration in $addedMigrations) { - $migrationName = Split-Path -Leaf $migration + # Get list of migrations from base reference + try { + $baseMigrations = git ls-tree -r --name-only $BaseRef -- "$efPath/" 2>$null | Where-Object { $_ -like "*.cs" -and $_ -notlike "*DatabaseContextModelSnapshot.cs" } | Sort-Object + if ($LASTEXITCODE -ne 0) { + Write-Host "Warning: Could not retrieve $dbName migrations from base reference '$BaseRef'" + $baseMigrations = @() + } + } + catch { + Write-Host "Warning: Could not retrieve $dbName migrations from base reference '$BaseRef'" + $baseMigrations = @() + } - # Validate NEW migration filename format - if ($migrationName -notmatch $formatRegex) { - Write-Host "ERROR: Migration '$migrationName' does not match required format" - Write-Host "Required format: YYYY-MM-DD_NN_Description.sql" - Write-Host " - YYYY: 4-digit year" - Write-Host " - MM: 2-digit month with leading zero (01-12)" - Write-Host " - DD: 2-digit day with leading zero (01-31)" - Write-Host " - NN: 2-digit sequence number (00, 01, 02, etc.)" - Write-Host "Example: 2025-01-15_00_MyMigration.sql" - $validationFailed = $true + # Get list of migrations from current reference + $currentMigrations = git ls-tree -r --name-only $CurrentRef -- "$efPath/" | Where-Object { $_ -like "*.cs" -and $_ -notlike "*DatabaseContextModelSnapshot.cs" } | Sort-Object + + # Find added migrations + $addedMigrations = $currentMigrations | Where-Object { $_ -notin $baseMigrations } + + if ($addedMigrations.Count -eq 0) { + Write-Host "No new $dbName EF migration files added." + Write-Host "" continue } - # Compare migration name with last base migration (using ordinal string comparison) - if ([string]::CompareOrdinal($migrationName, $lastBaseMigration) -lt 0) { - Write-Host "ERROR: New migration '$migrationName' is not chronologically after '$lastBaseMigration'" - $validationFailed = $true + Write-Host "New $dbName EF migration files detected:" + $addedMigrations | ForEach-Object { Write-Host " $_" } + Write-Host "" + + # Get the last migration from base reference + if ($baseMigrations.Count -eq 0) { + Write-Host "No previous $dbName migrations found. Skipping chronological validation." + Write-Host "" } else { - Write-Host "OK: '$migrationName' is chronologically after '$lastBaseMigration'" + $lastBaseMigration = Split-Path -Leaf ($baseMigrations | Select-Object -Last 1) + Write-Host "Last $dbName migration in base reference: $lastBaseMigration" + Write-Host "" } + + # Required format regex: YYYYMMDDHHMMSS_Description.cs or YYYYMMDDHHMMSS_Description.Designer.cs + $efFormatRegex = '^[0-9]{14}_.+\.cs$' + + # Group migrations by base name (without .Designer.cs suffix) + $migrationGroups = @{} + $unmatchedFiles = @() + + foreach ($migration in $addedMigrations) { + $migrationName = Split-Path -Leaf $migration + + # Extract base name (remove .Designer.cs or .cs) + if ($migrationName -match '^([0-9]{14}_.+?)(?:\.Designer)?\.cs$') { + $baseName = $matches[1] + if (-not $migrationGroups.ContainsKey($baseName)) { + $migrationGroups[$baseName] = @() + } + $migrationGroups[$baseName] += $migrationName + } + else { + # Track files that don't match the expected pattern + $unmatchedFiles += $migrationName + } + } + + # Flag any files that don't match the expected pattern + if ($unmatchedFiles.Count -gt 0) { + Write-Host "ERROR: The following migration files do not match the required format:" + foreach ($unmatchedFile in $unmatchedFiles) { + Write-Host " - $unmatchedFile" + } + Write-Host "" + Write-Host "Required format: YYYYMMDDHHMMSS_Description.cs or YYYYMMDDHHMMSS_Description.Designer.cs" + Write-Host " - YYYYMMDDHHMMSS: 14-digit timestamp (Year, Month, Day, Hour, Minute, Second)" + Write-Host " - Description: Descriptive name using PascalCase" + Write-Host "Example: 20250115120000_AddNewFeature.cs and 20250115120000_AddNewFeature.Designer.cs" + Write-Host "" + $efValidationFailed = $true + } + + foreach ($baseName in $migrationGroups.Keys | Sort-Object) { + $files = $migrationGroups[$baseName] + + # Validate format + $mainFile = "$baseName.cs" + $designerFile = "$baseName.Designer.cs" + + if ($mainFile -notmatch $efFormatRegex) { + Write-Host "ERROR: Migration '$mainFile' does not match required format" + Write-Host "Required format: YYYYMMDDHHMMSS_Description.cs" + Write-Host " - YYYYMMDDHHMMSS: 14-digit timestamp (Year, Month, Day, Hour, Minute, Second)" + Write-Host "Example: 20250115120000_AddNewFeature.cs" + $efValidationFailed = $true + continue + } + + # Check that both .cs and .Designer.cs files exist + $hasCsFile = $files -contains $mainFile + $hasDesignerFile = $files -contains $designerFile + + if (-not $hasCsFile) { + Write-Host "ERROR: Missing main migration file: $mainFile" + $efValidationFailed = $true + } + + if (-not $hasDesignerFile) { + Write-Host "ERROR: Missing designer file: $designerFile" + Write-Host "Each EF migration must have both a .cs and .Designer.cs file" + $efValidationFailed = $true + } + + if (-not $hasCsFile -or -not $hasDesignerFile) { + continue + } + + # Compare migration timestamp with last base migration (using ordinal string comparison) + if ($baseMigrations.Count -gt 0) { + if ([string]::CompareOrdinal($mainFile, $lastBaseMigration) -lt 0) { + Write-Host "ERROR: New migration '$mainFile' is not chronologically after '$lastBaseMigration'" + $efValidationFailed = $true + } + else { + Write-Host "OK: '$mainFile' is chronologically after '$lastBaseMigration'" + } + } + else { + Write-Host "OK: '$mainFile' (no previous migrations to compare)" + } + } + + Write-Host "" +} + +if ($efValidationFailed) { + Write-Host "FAILED: One or more EF migrations are incorrectly named or not in chronological order" + Write-Host "" + Write-Host "All new EF migration files must:" + Write-Host " 1. Follow the naming format: YYYYMMDDHHMMSS_Description.cs" + Write-Host " 2. Include both .cs and .Designer.cs files" + Write-Host " 3. Have a timestamp that sorts after the last migration in base" + Write-Host "" + Write-Host "To fix this issue:" + Write-Host " 1. Locate your migration file(s) in the respective Migrations directory" + Write-Host " 2. Ensure both .cs and .Designer.cs files exist" + Write-Host " 3. Rename to follow format: YYYYMMDDHHMMSS_Description.cs" + Write-Host " 4. Ensure the timestamp is after the last migration" + Write-Host "" + Write-Host "Example: 20250115120000_AddNewFeature.cs and 20250115120000_AddNewFeature.Designer.cs" +} +else { + Write-Host "SUCCESS: All new EF migrations are correctly named and in chronological order" } Write-Host "" +Write-Host "===================================================================" +Write-Host "Validation Summary" +Write-Host "===================================================================" + +if ($sqlValidationFailed -or $efValidationFailed) { + if ($sqlValidationFailed) { + Write-Host "❌ SQL migrations validation FAILED" + } + else { + Write-Host "✓ SQL migrations validation PASSED" + } + + if ($efValidationFailed) { + Write-Host "❌ EF migrations validation FAILED" + } + else { + Write-Host "✓ EF migrations validation PASSED" + } -if ($validationFailed) { - Write-Host "FAILED: One or more migrations are incorrectly named or not in chronological order" Write-Host "" - Write-Host "All new migration files must:" - Write-Host " 1. Follow the naming format: YYYY-MM-DD_NN_Description.sql" - Write-Host " 2. Use leading zeros in dates (e.g., 2025-01-05, not 2025-1-5)" - Write-Host " 3. Include a 2-digit sequence number (e.g., _00, _01)" - Write-Host " 4. Have a filename that sorts after the last migration in base" - Write-Host "" - Write-Host "To fix this issue:" - Write-Host " 1. Locate your migration file(s) in util/Migrator/DbScripts/" - Write-Host " 2. Rename to follow format: YYYY-MM-DD_NN_Description.sql" - Write-Host " 3. Ensure the date is after $lastBaseMigration" - Write-Host "" - Write-Host "Example: 2025-01-15_00_AddNewFeature.sql" + Write-Host "OVERALL RESULT: FAILED" exit 1 } - -Write-Host "SUCCESS: All new migrations are correctly named and in chronological order" -exit 0 +else { + Write-Host "✓ SQL migrations validation PASSED" + Write-Host "✓ EF migrations validation PASSED" + Write-Host "" + Write-Host "OVERALL RESULT: SUCCESS" + exit 0 +} diff --git a/global.json b/global.json index 4cbe3f083a..970250aec9 100644 --- a/global.json +++ b/global.json @@ -6,6 +6,6 @@ "msbuild-sdks": { "Microsoft.Build.Traversal": "4.1.0", "Microsoft.Build.Sql": "1.0.0", - "Bitwarden.Server.Sdk": "1.2.0" + "Bitwarden.Server.Sdk": "1.4.0" } } diff --git a/src/Admin/Admin.csproj b/src/Admin/Admin.csproj index b815ddea82..5733589466 100644 --- a/src/Admin/Admin.csproj +++ b/src/Admin/Admin.csproj @@ -1,4 +1,5 @@ + bitwarden-Admin diff --git a/src/Admin/Controllers/UsersController.cs b/src/Admin/Controllers/UsersController.cs index f42b22b098..eb7f1e8c04 100644 --- a/src/Admin/Controllers/UsersController.cs +++ b/src/Admin/Controllers/UsersController.cs @@ -86,7 +86,7 @@ public class UsersController : Controller return RedirectToAction("Index"); } - var ciphers = await _cipherRepository.GetManyByUserIdAsync(id); + var ciphers = await _cipherRepository.GetManyByUserIdAsync(id, withOrganizations: false); var isTwoFactorEnabled = await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(user); var verifiedDomain = await _userService.IsClaimedByAnyOrganizationAsync(user.Id); diff --git a/src/Admin/Program.cs b/src/Admin/Program.cs index 006a8223b2..80a1ae058c 100644 --- a/src/Admin/Program.cs +++ b/src/Admin/Program.cs @@ -8,7 +8,7 @@ public class Program { Host .CreateDefaultBuilder(args) - .ConfigureCustomAppConfiguration(args) + .UseBitwardenSdk() .ConfigureWebHostDefaults(webBuilder => { webBuilder.ConfigureKestrel(o => diff --git a/src/Admin/appsettings.Production.json b/src/Admin/appsettings.Production.json index 9f797f3111..1d852abfed 100644 --- a/src/Admin/appsettings.Production.json +++ b/src/Admin/appsettings.Production.json @@ -20,11 +20,9 @@ } }, "Logging": { - "IncludeScopes": false, "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" + "Default": "Information", + "Microsoft.AspNetCore": "Warning" }, "Console": { "IncludeScopes": true, diff --git a/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs b/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs index 90d02a46a1..ecea7caa96 100644 --- a/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs +++ b/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs @@ -57,7 +57,7 @@ public class OrganizationUsersController : BaseAdminConsoleController private readonly ICollectionRepository _collectionRepository; private readonly IGroupRepository _groupRepository; private readonly IUserService _userService; - private readonly IPolicyRepository _policyRepository; + private readonly IPolicyQuery _policyQuery; private readonly ICurrentContext _currentContext; private readonly ICountNewSmSeatsRequiredQuery _countNewSmSeatsRequiredQuery; private readonly IUpdateSecretsManagerSubscriptionCommand _updateSecretsManagerSubscriptionCommand; @@ -90,7 +90,7 @@ public class OrganizationUsersController : BaseAdminConsoleController ICollectionRepository collectionRepository, IGroupRepository groupRepository, IUserService userService, - IPolicyRepository policyRepository, + IPolicyQuery policyQuery, ICurrentContext currentContext, ICountNewSmSeatsRequiredQuery countNewSmSeatsRequiredQuery, IUpdateSecretsManagerSubscriptionCommand updateSecretsManagerSubscriptionCommand, @@ -123,7 +123,7 @@ public class OrganizationUsersController : BaseAdminConsoleController _collectionRepository = collectionRepository; _groupRepository = groupRepository; _userService = userService; - _policyRepository = policyRepository; + _policyQuery = policyQuery; _currentContext = currentContext; _countNewSmSeatsRequiredQuery = countNewSmSeatsRequiredQuery; _updateSecretsManagerSubscriptionCommand = updateSecretsManagerSubscriptionCommand; @@ -155,7 +155,7 @@ public class OrganizationUsersController : BaseAdminConsoleController [Authorize] public async Task Get(Guid orgId, Guid id, bool includeGroups = false) { - var (organizationUser, collections) = await _organizationUserRepository.GetDetailsByIdWithCollectionsAsync(id); + var (organizationUser, collections) = await _organizationUserRepository.GetDetailsByIdWithSharedCollectionsAsync(id); if (organizationUser == null || organizationUser.OrganizationId != orgId) { throw new NotFoundException(); @@ -287,14 +287,7 @@ public class OrganizationUsersController : BaseAdminConsoleController var userId = _userService.GetProperUserId(User); IEnumerable> result; - if (_featureService.IsEnabled(FeatureFlagKeys.IncreaseBulkReinviteLimitForCloud)) - { - result = await _bulkResendOrganizationInvitesCommand.BulkResendInvitesAsync(orgId, userId.Value, model.Ids); - } - else - { - result = await _organizationService.ResendInvitesAsync(orgId, userId.Value, model.Ids); - } + result = await _bulkResendOrganizationInvitesCommand.BulkResendInvitesAsync(orgId, userId.Value, model.Ids); return new ListResponseModel( result.Select(t => new OrganizationUserBulkResponseModel(t.Item1.Id, t.Item2))); @@ -331,6 +324,12 @@ public class OrganizationUsersController : BaseAdminConsoleController throw new UnauthorizedAccessException(); } + var organizationUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); + if (organizationUser == null || organizationUser.OrganizationId != orgId) + { + throw new NotFoundException("Organization user mismatch"); + } + var useMasterPasswordPolicy = _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements) ? (await _policyRequirementQuery.GetAsync(user.Id)).AutoEnrollEnabled(orgId) : await ShouldHandleResetPasswordAsync(orgId); @@ -357,10 +356,9 @@ public class OrganizationUsersController : BaseAdminConsoleController return false; } - var masterPasswordPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(orgId, PolicyType.ResetPassword); - var useMasterPasswordPolicy = masterPasswordPolicy != null && - masterPasswordPolicy.Enabled && - masterPasswordPolicy.GetDataModel().AutoEnrollEnabled; + var masterPasswordPolicy = await _policyQuery.RunAsync(orgId, PolicyType.ResetPassword); + var useMasterPasswordPolicy = masterPasswordPolicy.Enabled && + masterPasswordPolicy.GetDataModel().AutoEnrollEnabled; return useMasterPasswordPolicy; } @@ -697,7 +695,16 @@ public class OrganizationUsersController : BaseAdminConsoleController [Authorize] public async Task RestoreAsync(Guid orgId, Guid id) { - await RestoreOrRevokeUserAsync(orgId, id, (orgUser, userId) => _restoreOrganizationUserCommand.RestoreUserAsync(orgUser, userId)); + await RestoreOrRevokeUserAsync(orgId, id, (orgUser, userId) => _restoreOrganizationUserCommand.RestoreUserAsync(orgUser, userId, null)); + } + + + [HttpPut("{id}/restore/vnext")] + [Authorize] + [RequireFeature(FeatureFlagKeys.DefaultUserCollectionRestore)] + public async Task RestoreAsync_vNext(Guid orgId, Guid id, [FromBody] OrganizationUserRestoreRequest request) + { + await RestoreOrRevokeUserAsync(orgId, id, (orgUser, userId) => _restoreOrganizationUserCommand.RestoreUserAsync(orgUser, userId, request.DefaultUserCollectionName)); } [HttpPatch("{id}/restore")] @@ -712,7 +719,9 @@ public class OrganizationUsersController : BaseAdminConsoleController [Authorize] public async Task> BulkRestoreAsync(Guid orgId, [FromBody] OrganizationUserBulkRequestModel model) { - return await RestoreOrRevokeUsersAsync(orgId, model, (orgId, orgUserIds, restoringUserId) => _restoreOrganizationUserCommand.RestoreUsersAsync(orgId, orgUserIds, restoringUserId, _userService)); + return await RestoreOrRevokeUsersAsync(orgId, model, + (orgId, orgUserIds, restoringUserId) => _restoreOrganizationUserCommand.RestoreUsersAsync(orgId, orgUserIds, + restoringUserId, _userService, model.DefaultUserCollectionName)); } [HttpPatch("restore")] diff --git a/src/Api/AdminConsole/Controllers/OrganizationsController.cs b/src/Api/AdminConsole/Controllers/OrganizationsController.cs index 100cd7caf6..a6de8c521f 100644 --- a/src/Api/AdminConsole/Controllers/OrganizationsController.cs +++ b/src/Api/AdminConsole/Controllers/OrganizationsController.cs @@ -48,7 +48,7 @@ public class OrganizationsController : Controller { private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationUserRepository _organizationUserRepository; - private readonly IPolicyRepository _policyRepository; + private readonly IPolicyQuery _policyQuery; private readonly IOrganizationService _organizationService; private readonly IUserService _userService; private readonly ICurrentContext _currentContext; @@ -74,7 +74,7 @@ public class OrganizationsController : Controller public OrganizationsController( IOrganizationRepository organizationRepository, IOrganizationUserRepository organizationUserRepository, - IPolicyRepository policyRepository, + IPolicyQuery policyQuery, IOrganizationService organizationService, IUserService userService, ICurrentContext currentContext, @@ -99,7 +99,7 @@ public class OrganizationsController : Controller { _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; - _policyRepository = policyRepository; + _policyQuery = policyQuery; _organizationService = organizationService; _userService = userService; _currentContext = currentContext; @@ -183,15 +183,14 @@ public class OrganizationsController : Controller return new OrganizationAutoEnrollStatusResponseModel(organization.Id, resetPasswordPolicyRequirement.AutoEnrollEnabled(organization.Id)); } - var resetPasswordPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword); - if (resetPasswordPolicy == null || !resetPasswordPolicy.Enabled || resetPasswordPolicy.Data == null) + var resetPasswordPolicy = await _policyQuery.RunAsync(organization.Id, PolicyType.ResetPassword); + if (!resetPasswordPolicy.Enabled || resetPasswordPolicy.Data == null) { return new OrganizationAutoEnrollStatusResponseModel(organization.Id, false); } var data = JsonSerializer.Deserialize(resetPasswordPolicy.Data, JsonHelpers.IgnoreCase); return new OrganizationAutoEnrollStatusResponseModel(organization.Id, data?.AutoEnrollEnabled ?? false); - } [HttpPost("")] diff --git a/src/Api/AdminConsole/Controllers/PoliciesController.cs b/src/Api/AdminConsole/Controllers/PoliciesController.cs index bce0332d67..fe3600c3dd 100644 --- a/src/Api/AdminConsole/Controllers/PoliciesController.cs +++ b/src/Api/AdminConsole/Controllers/PoliciesController.cs @@ -7,7 +7,6 @@ using Bit.Api.AdminConsole.Models.Request; using Bit.Api.AdminConsole.Models.Response.Helpers; using Bit.Api.AdminConsole.Models.Response.Organizations; using Bit.Api.Models.Response; -using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; @@ -43,6 +42,7 @@ public class PoliciesController : Controller private readonly IUserService _userService; private readonly ISavePolicyCommand _savePolicyCommand; private readonly IVNextSavePolicyCommand _vNextSavePolicyCommand; + private readonly IPolicyQuery _policyQuery; public PoliciesController(IPolicyRepository policyRepository, IOrganizationUserRepository organizationUserRepository, @@ -54,7 +54,8 @@ public class PoliciesController : Controller IOrganizationHasVerifiedDomainsQuery organizationHasVerifiedDomainsQuery, IOrganizationRepository organizationRepository, ISavePolicyCommand savePolicyCommand, - IVNextSavePolicyCommand vNextSavePolicyCommand) + IVNextSavePolicyCommand vNextSavePolicyCommand, + IPolicyQuery policyQuery) { _policyRepository = policyRepository; _organizationUserRepository = organizationUserRepository; @@ -68,27 +69,24 @@ public class PoliciesController : Controller _organizationHasVerifiedDomainsQuery = organizationHasVerifiedDomainsQuery; _savePolicyCommand = savePolicyCommand; _vNextSavePolicyCommand = vNextSavePolicyCommand; + _policyQuery = policyQuery; } [HttpGet("{type}")] - public async Task Get(Guid orgId, int type) + public async Task Get(Guid orgId, PolicyType type) { if (!await _currentContext.ManagePolicies(orgId)) { throw new NotFoundException(); } - var policy = await _policyRepository.GetByOrganizationIdTypeAsync(orgId, (PolicyType)type); - if (policy == null) - { - return new PolicyDetailResponseModel(new Policy { Type = (PolicyType)type }); - } + var policy = await _policyQuery.RunAsync(orgId, type); if (policy.Type is PolicyType.SingleOrg) { - return await policy.GetSingleOrgPolicyDetailResponseAsync(_organizationHasVerifiedDomainsQuery); + return await policy.GetSingleOrgPolicyStatusResponseAsync(_organizationHasVerifiedDomainsQuery); } - return new PolicyDetailResponseModel(policy); + return new PolicyStatusResponseModel(policy); } [HttpGet("")] diff --git a/src/Api/AdminConsole/Models/Request/OrganizationDomainRequestModel.cs b/src/Api/AdminConsole/Models/Request/OrganizationDomainRequestModel.cs index 46b253da31..3a2ada719f 100644 --- a/src/Api/AdminConsole/Models/Request/OrganizationDomainRequestModel.cs +++ b/src/Api/AdminConsole/Models/Request/OrganizationDomainRequestModel.cs @@ -2,11 +2,13 @@ #nullable disable using System.ComponentModel.DataAnnotations; +using Bit.Core.Utilities; namespace Bit.Api.AdminConsole.Models.Request; public class OrganizationDomainRequestModel { [Required] + [DomainNameValidator] public string DomainName { get; set; } } diff --git a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUserRequestModels.cs b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUserRequestModels.cs index b7a4db3acd..06fe654b73 100644 --- a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUserRequestModels.cs +++ b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUserRequestModels.cs @@ -116,12 +116,17 @@ public class OrganizationUserResetPasswordEnrollmentRequestModel public string ResetPasswordKey { get; set; } public string MasterPasswordHash { get; set; } } - +#nullable enable public class OrganizationUserBulkRequestModel { [Required, MinLength(1)] - public IEnumerable Ids { get; set; } + public IEnumerable Ids { get; set; } = new List(); + + [EncryptedString] + [EncryptedStringLength(1000)] + public string? DefaultUserCollectionName { get; set; } } +#nullable disable public class ResetPasswordWithOrgIdRequestModel : OrganizationUserResetPasswordEnrollmentRequestModel { diff --git a/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUserRestoreRequest.cs b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUserRestoreRequest.cs new file mode 100644 index 0000000000..ff5f877b3a --- /dev/null +++ b/src/Api/AdminConsole/Models/Request/Organizations/OrganizationUserRestoreRequest.cs @@ -0,0 +1,13 @@ +using Bit.Core.Utilities; + +namespace Bit.Api.AdminConsole.Models.Request.Organizations; + +public class OrganizationUserRestoreRequest +{ + /// + /// This is the encrypted default collection name to be used for restored users if required + /// + [EncryptedString] + [EncryptedStringLength(1000)] + public string? DefaultUserCollectionName { get; set; } +} diff --git a/src/Api/AdminConsole/Models/Response/Helpers/PolicyDetailResponses.cs b/src/Api/AdminConsole/Models/Response/Helpers/PolicyStatusResponses.cs similarity index 66% rename from src/Api/AdminConsole/Models/Response/Helpers/PolicyDetailResponses.cs rename to src/Api/AdminConsole/Models/Response/Helpers/PolicyStatusResponses.cs index dded6a4c89..da08cdef0f 100644 --- a/src/Api/AdminConsole/Models/Response/Helpers/PolicyDetailResponses.cs +++ b/src/Api/AdminConsole/Models/Response/Helpers/PolicyStatusResponses.cs @@ -1,19 +1,21 @@ using Bit.Api.AdminConsole.Models.Response.Organizations; -using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; namespace Bit.Api.AdminConsole.Models.Response.Helpers; -public static class PolicyDetailResponses +public static class PolicyStatusResponses { - public static async Task GetSingleOrgPolicyDetailResponseAsync(this Policy policy, IOrganizationHasVerifiedDomainsQuery hasVerifiedDomainsQuery) + public static async Task GetSingleOrgPolicyStatusResponseAsync( + this PolicyStatus policy, IOrganizationHasVerifiedDomainsQuery hasVerifiedDomainsQuery) { if (policy.Type is not PolicyType.SingleOrg) { throw new ArgumentException($"'{nameof(policy)}' must be of type '{nameof(PolicyType.SingleOrg)}'.", nameof(policy)); } - return new PolicyDetailResponseModel(policy, await CanToggleState()); + + return new PolicyStatusResponseModel(policy, await CanToggleState()); async Task CanToggleState() { @@ -25,5 +27,4 @@ public static class PolicyDetailResponses return !policy.Enabled; } } - } diff --git a/src/Api/AdminConsole/Models/Response/Organizations/PolicyDetailResponseModel.cs b/src/Api/AdminConsole/Models/Response/Organizations/PolicyDetailResponseModel.cs deleted file mode 100644 index cb5560e689..0000000000 --- a/src/Api/AdminConsole/Models/Response/Organizations/PolicyDetailResponseModel.cs +++ /dev/null @@ -1,20 +0,0 @@ -using Bit.Core.AdminConsole.Entities; - -namespace Bit.Api.AdminConsole.Models.Response.Organizations; - -public class PolicyDetailResponseModel : PolicyResponseModel -{ - public PolicyDetailResponseModel(Policy policy, string obj = "policy") : base(policy, obj) - { - } - - public PolicyDetailResponseModel(Policy policy, bool canToggleState) : base(policy) - { - CanToggleState = canToggleState; - } - - /// - /// Indicates whether the Policy can be enabled/disabled - /// - public bool CanToggleState { get; set; } = true; -} diff --git a/src/Api/AdminConsole/Models/Response/Organizations/PolicyStatusResponseModel.cs b/src/Api/AdminConsole/Models/Response/Organizations/PolicyStatusResponseModel.cs new file mode 100644 index 0000000000..8c93302a17 --- /dev/null +++ b/src/Api/AdminConsole/Models/Response/Organizations/PolicyStatusResponseModel.cs @@ -0,0 +1,33 @@ +using System.Text.Json; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.Models.Api; + +namespace Bit.Api.AdminConsole.Models.Response.Organizations; + +public class PolicyStatusResponseModel : ResponseModel +{ + public PolicyStatusResponseModel(PolicyStatus policy, bool canToggleState = true) : base("policy") + { + OrganizationId = policy.OrganizationId; + Type = policy.Type; + + if (!string.IsNullOrWhiteSpace(policy.Data)) + { + Data = JsonSerializer.Deserialize>(policy.Data) ?? new(); + } + + Enabled = policy.Enabled; + CanToggleState = canToggleState; + } + + public Guid OrganizationId { get; init; } + public PolicyType Type { get; init; } + public Dictionary Data { get; init; } = new(); + public bool Enabled { get; init; } + + /// + /// Indicates whether the Policy can be enabled/disabled + /// + public bool CanToggleState { get; init; } +} diff --git a/src/Api/AdminConsole/Public/Controllers/MembersController.cs b/src/Api/AdminConsole/Public/Controllers/MembersController.cs index 58e5db18c2..e312f009c9 100644 --- a/src/Api/AdminConsole/Public/Controllers/MembersController.cs +++ b/src/Api/AdminConsole/Public/Controllers/MembersController.cs @@ -2,12 +2,16 @@ using Bit.Api.AdminConsole.Public.Models.Request; using Bit.Api.AdminConsole.Public.Models.Response; using Bit.Api.Models.Public.Response; +using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RestoreUser.v1; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.RevokeUser.v2; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; using Bit.Core.Billing.Services; using Bit.Core.Context; +using Bit.Core.Enums; using Bit.Core.Repositories; using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; @@ -30,6 +34,8 @@ public class MembersController : Controller private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand; private readonly IResendOrganizationInviteCommand _resendOrganizationInviteCommand; + private readonly IRevokeOrganizationUserCommand _revokeOrganizationUserCommandV2; + private readonly IRestoreOrganizationUserCommand _restoreOrganizationUserCommand; public MembersController( IOrganizationUserRepository organizationUserRepository, @@ -42,7 +48,9 @@ public class MembersController : Controller IOrganizationRepository organizationRepository, ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, IRemoveOrganizationUserCommand removeOrganizationUserCommand, - IResendOrganizationInviteCommand resendOrganizationInviteCommand) + IResendOrganizationInviteCommand resendOrganizationInviteCommand, + IRevokeOrganizationUserCommand revokeOrganizationUserCommandV2, + IRestoreOrganizationUserCommand restoreOrganizationUserCommand) { _organizationUserRepository = organizationUserRepository; _groupRepository = groupRepository; @@ -55,6 +63,8 @@ public class MembersController : Controller _twoFactorIsEnabledQuery = twoFactorIsEnabledQuery; _removeOrganizationUserCommand = removeOrganizationUserCommand; _resendOrganizationInviteCommand = resendOrganizationInviteCommand; + _revokeOrganizationUserCommandV2 = revokeOrganizationUserCommandV2; + _restoreOrganizationUserCommand = restoreOrganizationUserCommand; } /// @@ -70,7 +80,7 @@ public class MembersController : Controller [ProducesResponseType((int)HttpStatusCode.NotFound)] public async Task Get(Guid id) { - var (orgUser, collections) = await _organizationUserRepository.GetDetailsByIdWithCollectionsAsync(id); + var (orgUser, collections) = await _organizationUserRepository.GetDetailsByIdWithSharedCollectionsAsync(id); if (orgUser == null || orgUser.OrganizationId != _currentContext.OrganizationId) { return new NotFoundResult(); @@ -113,7 +123,7 @@ public class MembersController : Controller [ProducesResponseType(typeof(ListResponseModel), (int)HttpStatusCode.OK)] public async Task List() { - var organizationUserUserDetails = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(_currentContext.OrganizationId!.Value, includeCollections: true); + var organizationUserUserDetails = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(_currentContext.OrganizationId!.Value, includeSharedCollections: true); var orgUsersTwoFactorIsEnabled = await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(organizationUserUserDetails); var memberResponses = organizationUserUserDetails.Select(u => @@ -258,4 +268,59 @@ public class MembersController : Controller await _resendOrganizationInviteCommand.ResendInviteAsync(_currentContext.OrganizationId!.Value, null, id); return new OkResult(); } + + /// + /// Revoke a member's access to an organization. + /// + /// The ID of the member to be revoked. + [HttpPost("{id}/revoke")] + [ProducesResponseType((int)HttpStatusCode.OK)] + [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task Revoke(Guid id) + { + var organizationUser = await _organizationUserRepository.GetByIdAsync(id); + if (organizationUser == null || organizationUser.OrganizationId != _currentContext.OrganizationId) + { + return new NotFoundResult(); + } + + var request = new RevokeOrganizationUsersRequest( + _currentContext.OrganizationId!.Value, + [id], + new SystemUser(EventSystemUser.PublicApi) + ); + + var results = await _revokeOrganizationUserCommandV2.RevokeUsersAsync(request); + var result = results.Single(); + + return result.Result.Match( + error => new BadRequestObjectResult(new ErrorResponseModel(error.Message)), + _ => new OkResult() + ); + } + + /// + /// Restore a member. + /// + /// + /// Restores a previously revoked member of the organization. + /// + /// The identifier of the member to be restored. + [HttpPost("{id}/restore")] + [ProducesResponseType((int)HttpStatusCode.OK)] + [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] + [ProducesResponseType((int)HttpStatusCode.NotFound)] + public async Task Restore(Guid id) + { + var organizationUser = await _organizationUserRepository.GetByIdAsync(id); + if (organizationUser == null || organizationUser.OrganizationId != _currentContext.OrganizationId) + { + return new NotFoundResult(); + } + + await _restoreOrganizationUserCommand.RestoreUserAsync(organizationUser, EventSystemUser.PublicApi); + + return new OkResult(); + } } diff --git a/src/Api/Api.csproj b/src/Api/Api.csproj index d25b989d11..07dc066227 100644 --- a/src/Api/Api.csproj +++ b/src/Api/Api.csproj @@ -1,4 +1,6 @@  + + bitwarden-Api false @@ -36,7 +38,7 @@ - + diff --git a/src/Api/Auth/Controllers/EmergencyAccessController.cs b/src/Api/Auth/Controllers/EmergencyAccessController.cs index 016cd82fe2..bd87e82c8a 100644 --- a/src/Api/Auth/Controllers/EmergencyAccessController.cs +++ b/src/Api/Auth/Controllers/EmergencyAccessController.cs @@ -7,7 +7,7 @@ using Bit.Api.Auth.Models.Request; using Bit.Api.Auth.Models.Response; using Bit.Api.Models.Response; using Bit.Api.Vault.Models.Response; -using Bit.Core.Auth.Services; +using Bit.Core.Auth.UserFeatures.EmergencyAccess; using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; diff --git a/src/Api/Auth/Jobs/EmergencyAccessNotificationJob.cs b/src/Api/Auth/Jobs/EmergencyAccessNotificationJob.cs index c67cb9db3f..f58eaafaab 100644 --- a/src/Api/Auth/Jobs/EmergencyAccessNotificationJob.cs +++ b/src/Api/Auth/Jobs/EmergencyAccessNotificationJob.cs @@ -1,7 +1,7 @@ // FIXME: Update this file to be null safe and then delete the line below #nullable disable -using Bit.Core.Auth.Services; +using Bit.Core.Auth.UserFeatures.EmergencyAccess; using Bit.Core.Jobs; using Quartz; diff --git a/src/Api/Auth/Jobs/EmergencyAccessTimeoutJob.cs b/src/Api/Auth/Jobs/EmergencyAccessTimeoutJob.cs index f23774f060..63b861d920 100644 --- a/src/Api/Auth/Jobs/EmergencyAccessTimeoutJob.cs +++ b/src/Api/Auth/Jobs/EmergencyAccessTimeoutJob.cs @@ -1,7 +1,7 @@ // FIXME: Update this file to be null safe and then delete the line below #nullable disable -using Bit.Core.Auth.Services; +using Bit.Core.Auth.UserFeatures.EmergencyAccess; using Bit.Core.Jobs; using Quartz; diff --git a/src/Api/Auth/Models/Request/Accounts/PasswordRequestModel.cs b/src/Api/Auth/Models/Request/Accounts/PasswordRequestModel.cs index 8fa51e9f34..ab8c727852 100644 --- a/src/Api/Auth/Models/Request/Accounts/PasswordRequestModel.cs +++ b/src/Api/Auth/Models/Request/Accounts/PasswordRequestModel.cs @@ -1,7 +1,5 @@ -#nullable enable - -using System.ComponentModel.DataAnnotations; -using Bit.Api.KeyManagement.Models.Requests; +using System.ComponentModel.DataAnnotations; +using Bit.Core.KeyManagement.Models.Api.Request; namespace Bit.Api.Auth.Models.Request.Accounts; diff --git a/src/Api/Auth/Models/Request/Accounts/SetInitialPasswordRequestModel.cs b/src/Api/Auth/Models/Request/Accounts/SetInitialPasswordRequestModel.cs index 55ffdca94b..37a7901fee 100644 --- a/src/Api/Auth/Models/Request/Accounts/SetInitialPasswordRequestModel.cs +++ b/src/Api/Auth/Models/Request/Accounts/SetInitialPasswordRequestModel.cs @@ -1,5 +1,4 @@ using System.ComponentModel.DataAnnotations; -using Bit.Api.KeyManagement.Models.Requests; using Bit.Core.Auth.Models.Api.Request.Accounts; using Bit.Core.Auth.Models.Data; using Bit.Core.Entities; diff --git a/src/Api/Auth/Models/Response/EmergencyAccessResponseModel.cs b/src/Api/Auth/Models/Response/EmergencyAccessResponseModel.cs index 640c9bb3e0..dff766da12 100644 --- a/src/Api/Auth/Models/Response/EmergencyAccessResponseModel.cs +++ b/src/Api/Auth/Models/Response/EmergencyAccessResponseModel.cs @@ -112,6 +112,7 @@ public class EmergencyAccessTakeoverResponseModel : ResponseModel KdfIterations = grantor.KdfIterations; KdfMemory = grantor.KdfMemory; KdfParallelism = grantor.KdfParallelism; + Salt = grantor.GetMasterPasswordSalt(); } public int KdfIterations { get; private set; } @@ -119,6 +120,7 @@ public class EmergencyAccessTakeoverResponseModel : ResponseModel public int? KdfParallelism { get; private set; } public KdfType Kdf { get; private set; } public string KeyEncrypted { get; private set; } + public string Salt { get; private set; } } public class EmergencyAccessViewResponseModel : ResponseModel diff --git a/src/Api/Billing/Controllers/OrganizationSponsorshipsController.cs b/src/Api/Billing/Controllers/OrganizationSponsorshipsController.cs index 7ca85d52a8..8a1467dfa2 100644 --- a/src/Api/Billing/Controllers/OrganizationSponsorshipsController.cs +++ b/src/Api/Billing/Controllers/OrganizationSponsorshipsController.cs @@ -6,7 +6,7 @@ using Bit.Api.Models.Response; using Bit.Api.Models.Response.Organizations; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationConnections.Interfaces; -using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Exceptions; @@ -38,7 +38,7 @@ public class OrganizationSponsorshipsController : Controller private readonly ICloudSyncSponsorshipsCommand _syncSponsorshipsCommand; private readonly ICurrentContext _currentContext; private readonly IUserService _userService; - private readonly IPolicyRepository _policyRepository; + private readonly IPolicyQuery _policyQuery; private readonly IFeatureService _featureService; public OrganizationSponsorshipsController( @@ -55,7 +55,7 @@ public class OrganizationSponsorshipsController : Controller ICloudSyncSponsorshipsCommand syncSponsorshipsCommand, IUserService userService, ICurrentContext currentContext, - IPolicyRepository policyRepository, + IPolicyQuery policyQuery, IFeatureService featureService) { _organizationSponsorshipRepository = organizationSponsorshipRepository; @@ -71,7 +71,7 @@ public class OrganizationSponsorshipsController : Controller _syncSponsorshipsCommand = syncSponsorshipsCommand; _userService = userService; _currentContext = currentContext; - _policyRepository = policyRepository; + _policyQuery = policyQuery; _featureService = featureService; } @@ -81,10 +81,10 @@ public class OrganizationSponsorshipsController : Controller public async Task CreateSponsorship(Guid sponsoringOrgId, [FromBody] OrganizationSponsorshipCreateRequestModel model) { var sponsoringOrg = await _organizationRepository.GetByIdAsync(sponsoringOrgId); - var freeFamiliesSponsorshipPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(sponsoringOrgId, + var freeFamiliesSponsorshipPolicy = await _policyQuery.RunAsync(sponsoringOrgId, PolicyType.FreeFamiliesSponsorshipPolicy); - if (freeFamiliesSponsorshipPolicy?.Enabled == true) + if (freeFamiliesSponsorshipPolicy.Enabled) { throw new BadRequestException("Free Bitwarden Families sponsorship has been disabled by your organization administrator."); } @@ -108,10 +108,10 @@ public class OrganizationSponsorshipsController : Controller [SelfHosted(NotSelfHostedOnly = true)] public async Task ResendSponsorshipOffer(Guid sponsoringOrgId, [FromQuery] string sponsoredFriendlyName) { - var freeFamiliesSponsorshipPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(sponsoringOrgId, + var freeFamiliesSponsorshipPolicy = await _policyQuery.RunAsync(sponsoringOrgId, PolicyType.FreeFamiliesSponsorshipPolicy); - if (freeFamiliesSponsorshipPolicy?.Enabled == true) + if (freeFamiliesSponsorshipPolicy.Enabled) { throw new BadRequestException("Free Bitwarden Families sponsorship has been disabled by your organization administrator."); } @@ -138,9 +138,9 @@ public class OrganizationSponsorshipsController : Controller var (isValid, sponsorship) = await _validateRedemptionTokenCommand.ValidateRedemptionTokenAsync(sponsorshipToken, (await CurrentUser).Email); if (isValid && sponsorship.SponsoringOrganizationId.HasValue) { - var policy = await _policyRepository.GetByOrganizationIdTypeAsync(sponsorship.SponsoringOrganizationId.Value, + var policy = await _policyQuery.RunAsync(sponsorship.SponsoringOrganizationId.Value, PolicyType.FreeFamiliesSponsorshipPolicy); - isFreeFamilyPolicyEnabled = policy?.Enabled ?? false; + isFreeFamilyPolicyEnabled = policy.Enabled; } var response = PreValidateSponsorshipResponseModel.From(isValid, isFreeFamilyPolicyEnabled); @@ -165,10 +165,10 @@ public class OrganizationSponsorshipsController : Controller throw new BadRequestException("Can only redeem sponsorship for an organization you own."); } - var freeFamiliesSponsorshipPolicy = await _policyRepository.GetByOrganizationIdTypeAsync( + var freeFamiliesSponsorshipPolicy = await _policyQuery.RunAsync( model.SponsoredOrganizationId, PolicyType.FreeFamiliesSponsorshipPolicy); - if (freeFamiliesSponsorshipPolicy?.Enabled == true) + if (freeFamiliesSponsorshipPolicy.Enabled) { throw new BadRequestException("Free Bitwarden Families sponsorship has been disabled by your organization administrator."); } diff --git a/src/Api/Billing/Controllers/TaxController.cs b/src/Api/Billing/Controllers/PreviewInvoiceController.cs similarity index 62% rename from src/Api/Billing/Controllers/TaxController.cs rename to src/Api/Billing/Controllers/PreviewInvoiceController.cs index 4ead414589..c958454618 100644 --- a/src/Api/Billing/Controllers/TaxController.cs +++ b/src/Api/Billing/Controllers/PreviewInvoiceController.cs @@ -1,8 +1,9 @@ using Bit.Api.Billing.Attributes; -using Bit.Api.Billing.Models.Requests.Tax; +using Bit.Api.Billing.Models.Requests.PreviewInvoice; using Bit.Core.AdminConsole.Entities; using Bit.Core.Billing.Organizations.Commands; using Bit.Core.Billing.Premium.Commands; +using Bit.Core.Entities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc.ModelBinding; @@ -10,10 +11,11 @@ using Microsoft.AspNetCore.Mvc.ModelBinding; namespace Bit.Api.Billing.Controllers; [Authorize("Application")] -[Route("billing/tax")] -public class TaxController( +[Route("billing/preview-invoice")] +public class PreviewInvoiceController( IPreviewOrganizationTaxCommand previewOrganizationTaxCommand, - IPreviewPremiumTaxCommand previewPremiumTaxCommand) : BaseBillingController + IPreviewPremiumTaxCommand previewPremiumTaxCommand, + IPreviewPremiumUpgradeProrationCommand previewPremiumUpgradeProrationCommand) : BaseBillingController { [HttpPost("organizations/subscriptions/purchase")] public async Task PreviewOrganizationSubscriptionPurchaseTaxAsync( @@ -21,11 +23,7 @@ public class TaxController( { var (purchase, billingAddress) = request.ToDomain(); var result = await previewOrganizationTaxCommand.Run(purchase, billingAddress); - return Handle(result.Map(pair => new - { - pair.Tax, - pair.Total - })); + return Handle(result.Map(pair => new { pair.Tax, pair.Total })); } [HttpPost("organizations/{organizationId:guid}/subscription/plan-change")] @@ -36,11 +34,7 @@ public class TaxController( { var (planChange, billingAddress) = request.ToDomain(); var result = await previewOrganizationTaxCommand.Run(organization, planChange, billingAddress); - return Handle(result.Map(pair => new - { - pair.Tax, - pair.Total - })); + return Handle(result.Map(pair => new { pair.Tax, pair.Total })); } [HttpPut("organizations/{organizationId:guid}/subscription/update")] @@ -51,11 +45,7 @@ public class TaxController( { var update = request.ToDomain(); var result = await previewOrganizationTaxCommand.Run(organization, update); - return Handle(result.Map(pair => new - { - pair.Tax, - pair.Total - })); + return Handle(result.Map(pair => new { pair.Tax, pair.Total })); } [HttpPost("premium/subscriptions/purchase")] @@ -64,10 +54,29 @@ public class TaxController( { var (purchase, billingAddress) = request.ToDomain(); var result = await previewPremiumTaxCommand.Run(purchase, billingAddress); - return Handle(result.Map(pair => new + return Handle(result.Map(pair => new { pair.Tax, pair.Total })); + } + + [HttpPost("premium/subscriptions/upgrade")] + [InjectUser] + public async Task PreviewPremiumUpgradeProrationAsync( + [BindNever] User user, + [FromBody] PreviewPremiumUpgradeProrationRequest request) + { + var (planType, billingAddress) = request.ToDomain(); + + var result = await previewPremiumUpgradeProrationCommand.Run( + user, + planType, + billingAddress); + + return Handle(result.Map(proration => new { - pair.Tax, - pair.Total + proration.NewPlanProratedAmount, + proration.Credit, + proration.Tax, + proration.Total, + proration.NewPlanProratedMonths })); } } diff --git a/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs b/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs index 6c56d6db3a..241e595333 100644 --- a/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs +++ b/src/Api/Billing/Controllers/VNext/AccountBillingVNextController.cs @@ -132,8 +132,8 @@ public class AccountBillingVNextController( [BindNever] User user, [FromBody] UpgradePremiumToOrganizationRequest request) { - var (organizationName, key, planType) = request.ToDomain(); - var result = await upgradePremiumToOrganizationCommand.Run(user, organizationName, key, planType); + var (organizationName, key, planType, billingAddress) = request.ToDomain(); + var result = await upgradePremiumToOrganizationCommand.Run(user, organizationName, key, planType, billingAddress); return Handle(result); } } diff --git a/src/Api/Billing/Models/Requests/Premium/UpgradePremiumToOrganizationRequest.cs b/src/Api/Billing/Models/Requests/Premium/UpgradePremiumToOrganizationRequest.cs index 14375efc78..00b1da4bba 100644 --- a/src/Api/Billing/Models/Requests/Premium/UpgradePremiumToOrganizationRequest.cs +++ b/src/Api/Billing/Models/Requests/Premium/UpgradePremiumToOrganizationRequest.cs @@ -1,5 +1,6 @@ using System.ComponentModel.DataAnnotations; using System.Text.Json.Serialization; +using Bit.Api.Billing.Models.Requests.Payment; using Bit.Core.Billing.Enums; namespace Bit.Api.Billing.Models.Requests.Premium; @@ -14,24 +15,30 @@ public class UpgradePremiumToOrganizationRequest [Required] [JsonConverter(typeof(JsonStringEnumConverter))] - public ProductTierType Tier { get; set; } + public required ProductTierType TargetProductTierType { get; set; } [Required] - [JsonConverter(typeof(JsonStringEnumConverter))] - public PlanCadenceType Cadence { get; set; } + public required MinimalBillingAddressRequest BillingAddress { get; set; } - private PlanType PlanType => - Tier switch + private PlanType PlanType + { + get { - ProductTierType.Families => PlanType.FamiliesAnnually, - ProductTierType.Teams => Cadence == PlanCadenceType.Monthly - ? PlanType.TeamsMonthly - : PlanType.TeamsAnnually, - ProductTierType.Enterprise => Cadence == PlanCadenceType.Monthly - ? PlanType.EnterpriseMonthly - : PlanType.EnterpriseAnnually, - _ => throw new InvalidOperationException("Cannot upgrade to an Organization subscription that isn't Families, Teams or Enterprise.") - }; + if (TargetProductTierType is not (ProductTierType.Families or ProductTierType.Teams or ProductTierType.Enterprise)) + { + throw new InvalidOperationException($"Cannot upgrade Premium subscription to {TargetProductTierType} plan."); + } - public (string OrganizationName, string Key, PlanType PlanType) ToDomain() => (OrganizationName, Key, PlanType); + return TargetProductTierType switch + { + ProductTierType.Families => PlanType.FamiliesAnnually, + ProductTierType.Teams => PlanType.TeamsAnnually, + ProductTierType.Enterprise => PlanType.EnterpriseAnnually, + _ => throw new InvalidOperationException($"Unexpected ProductTierType: {TargetProductTierType}") + }; + } + } + + public (string OrganizationName, string Key, PlanType PlanType, Core.Billing.Payment.Models.BillingAddress BillingAddress) ToDomain() => + (OrganizationName, Key, PlanType, BillingAddress.ToDomain()); } diff --git a/src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionPlanChangeTaxRequest.cs b/src/Api/Billing/Models/Requests/PreviewInvoice/PreviewOrganizationSubscriptionPlanChangeTaxRequest.cs similarity index 91% rename from src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionPlanChangeTaxRequest.cs rename to src/Api/Billing/Models/Requests/PreviewInvoice/PreviewOrganizationSubscriptionPlanChangeTaxRequest.cs index 9233a53c85..ccb8f948af 100644 --- a/src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionPlanChangeTaxRequest.cs +++ b/src/Api/Billing/Models/Requests/PreviewInvoice/PreviewOrganizationSubscriptionPlanChangeTaxRequest.cs @@ -4,7 +4,7 @@ using Bit.Api.Billing.Models.Requests.Payment; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Payment.Models; -namespace Bit.Api.Billing.Models.Requests.Tax; +namespace Bit.Api.Billing.Models.Requests.PreviewInvoice; public record PreviewOrganizationSubscriptionPlanChangeTaxRequest { diff --git a/src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionPurchaseTaxRequest.cs b/src/Api/Billing/Models/Requests/PreviewInvoice/PreviewOrganizationSubscriptionPurchaseTaxRequest.cs similarity index 91% rename from src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionPurchaseTaxRequest.cs rename to src/Api/Billing/Models/Requests/PreviewInvoice/PreviewOrganizationSubscriptionPurchaseTaxRequest.cs index dcc5911f3d..40bec9dec3 100644 --- a/src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionPurchaseTaxRequest.cs +++ b/src/Api/Billing/Models/Requests/PreviewInvoice/PreviewOrganizationSubscriptionPurchaseTaxRequest.cs @@ -4,7 +4,7 @@ using Bit.Api.Billing.Models.Requests.Payment; using Bit.Core.Billing.Organizations.Models; using Bit.Core.Billing.Payment.Models; -namespace Bit.Api.Billing.Models.Requests.Tax; +namespace Bit.Api.Billing.Models.Requests.PreviewInvoice; public record PreviewOrganizationSubscriptionPurchaseTaxRequest { diff --git a/src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionUpdateTaxRequest.cs b/src/Api/Billing/Models/Requests/PreviewInvoice/PreviewOrganizationSubscriptionUpdateTaxRequest.cs similarity index 84% rename from src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionUpdateTaxRequest.cs rename to src/Api/Billing/Models/Requests/PreviewInvoice/PreviewOrganizationSubscriptionUpdateTaxRequest.cs index ae96214ae3..4568fea972 100644 --- a/src/Api/Billing/Models/Requests/Tax/PreviewOrganizationSubscriptionUpdateTaxRequest.cs +++ b/src/Api/Billing/Models/Requests/PreviewInvoice/PreviewOrganizationSubscriptionUpdateTaxRequest.cs @@ -1,7 +1,7 @@ using Bit.Api.Billing.Models.Requests.Organizations; using Bit.Core.Billing.Organizations.Models; -namespace Bit.Api.Billing.Models.Requests.Tax; +namespace Bit.Api.Billing.Models.Requests.PreviewInvoice; public class PreviewOrganizationSubscriptionUpdateTaxRequest { diff --git a/src/Api/Billing/Models/Requests/Tax/PreviewPremiumSubscriptionPurchaseTaxRequest.cs b/src/Api/Billing/Models/Requests/PreviewInvoice/PreviewPremiumSubscriptionPurchaseTaxRequest.cs similarity index 90% rename from src/Api/Billing/Models/Requests/Tax/PreviewPremiumSubscriptionPurchaseTaxRequest.cs rename to src/Api/Billing/Models/Requests/PreviewInvoice/PreviewPremiumSubscriptionPurchaseTaxRequest.cs index 76b8a5a444..d1707cf6de 100644 --- a/src/Api/Billing/Models/Requests/Tax/PreviewPremiumSubscriptionPurchaseTaxRequest.cs +++ b/src/Api/Billing/Models/Requests/PreviewInvoice/PreviewPremiumSubscriptionPurchaseTaxRequest.cs @@ -2,7 +2,7 @@ using Bit.Api.Billing.Models.Requests.Payment; using Bit.Core.Billing.Payment.Models; -namespace Bit.Api.Billing.Models.Requests.Tax; +namespace Bit.Api.Billing.Models.Requests.PreviewInvoice; public record PreviewPremiumSubscriptionPurchaseTaxRequest { diff --git a/src/Api/Billing/Models/Requests/PreviewInvoice/PreviewPremiumUpgradeProrationRequest.cs b/src/Api/Billing/Models/Requests/PreviewInvoice/PreviewPremiumUpgradeProrationRequest.cs new file mode 100644 index 0000000000..68d7a8d002 --- /dev/null +++ b/src/Api/Billing/Models/Requests/PreviewInvoice/PreviewPremiumUpgradeProrationRequest.cs @@ -0,0 +1,39 @@ +using System.ComponentModel.DataAnnotations; +using System.Text.Json.Serialization; +using Bit.Api.Billing.Models.Requests.Payment; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Payment.Models; + +namespace Bit.Api.Billing.Models.Requests.PreviewInvoice; + +public record PreviewPremiumUpgradeProrationRequest +{ + [Required] + [JsonConverter(typeof(JsonStringEnumConverter))] + public required ProductTierType TargetProductTierType { get; set; } + + [Required] + public required MinimalBillingAddressRequest BillingAddress { get; set; } + + private PlanType PlanType + { + get + { + if (TargetProductTierType is not (ProductTierType.Families or ProductTierType.Teams or ProductTierType.Enterprise)) + { + throw new InvalidOperationException($"Cannot upgrade Premium subscription to {TargetProductTierType} plan."); + } + + return TargetProductTierType switch + { + ProductTierType.Families => PlanType.FamiliesAnnually, + ProductTierType.Teams => PlanType.TeamsAnnually, + ProductTierType.Enterprise => PlanType.EnterpriseAnnually, + _ => throw new InvalidOperationException($"Unexpected ProductTierType: {TargetProductTierType}") + }; + } + } + + public (PlanType, BillingAddress) ToDomain() => + (PlanType, BillingAddress.ToDomain()); +} diff --git a/src/Api/Controllers/CollectionsController.cs b/src/Api/Controllers/CollectionsController.cs index b3542cfde2..a1de65299f 100644 --- a/src/Api/Controllers/CollectionsController.cs +++ b/src/Api/Controllers/CollectionsController.cs @@ -81,7 +81,7 @@ public class CollectionsController : Controller [HttpGet("details")] public async Task> GetManyWithDetails(Guid orgId) { - var allOrgCollections = await _collectionRepository.GetManyByOrganizationIdWithPermissionsAsync( + var allOrgCollections = await _collectionRepository.GetManySharedByOrganizationIdWithPermissionsAsync( orgId, _currentContext.UserId.Value, true); var readAllAuthorized = diff --git a/src/Api/Controllers/SsoCookieVendorController.cs b/src/Api/Controllers/SsoCookieVendorController.cs new file mode 100644 index 0000000000..4d45415a4f --- /dev/null +++ b/src/Api/Controllers/SsoCookieVendorController.cs @@ -0,0 +1,119 @@ +using Bit.Core.Settings; +using Bit.Core.Utilities; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; + +namespace Bit.Api.Controllers; + +/// +/// Provides an endpoint to read an SSO cookie and redirect to a custom URI +/// scheme. The load balancer/reverse proxy must be configured such that +/// requests to this endpoint do not have the auth cookie stripped. +/// +[Route("sso-cookie-vendor")] +[SelfHosted(SelfHostedOnly = true)] +public class SsoCookieVendorController(IGlobalSettings globalSettings) : Controller +{ + private readonly IGlobalSettings _globalSettings = globalSettings; + private const int _maxShardCount = 20; + private const int _maxUriLength = 8192; + + /// + /// Reads SSO cookie (shards supported) and redirects to the bitwarden:// + /// URI with cookie value(s). + /// + /// + /// 302 redirect on success, 404 if no cookies found, 400 if URI too long, + /// 500 if misconfigured + /// + [HttpGet] + [AllowAnonymous] + public IActionResult Get() + { + var bootstrap = _globalSettings.Communication?.Bootstrap; + if (string.IsNullOrEmpty(bootstrap) || !bootstrap.Equals("ssoCookieVendor", StringComparison.OrdinalIgnoreCase)) + { + return NotFound(); + } + + var cookieName = _globalSettings.Communication?.SsoCookieVendor?.CookieName; + if (string.IsNullOrWhiteSpace(cookieName)) + { + return StatusCode(500, "SSO cookie vendor is not properly configured"); + } + + var uri = string.Empty; + if (TryGetCookie(cookieName, out var cookie)) + { + uri = BuildRedirectUri(cookie); + } + else if (TryGetShardedCookie(cookieName, out var shardedCookie)) + { + uri = BuildRedirectUri(shardedCookie); + } + + if (uri == string.Empty) + { + return NotFound("No SSO cookies found"); + } + + if (uri.Length > _maxUriLength) + { + return BadRequest(); + } + + return Redirect(uri); + } + + private bool TryGetCookie(string cookieName, out Dictionary cookie) + { + cookie = []; + + if (Request.Cookies.TryGetValue(cookieName, out var value) && !string.IsNullOrEmpty(value)) + { + cookie[cookieName] = value; + return true; + } + + return false; + } + + private bool TryGetShardedCookie(string cookieName, out Dictionary cookies) + { + var shardedCookies = new Dictionary(); + + for (var i = 0; i < _maxShardCount; i++) + { + var shardName = $"{cookieName}-{i}"; + if (Request.Cookies.TryGetValue(shardName, out var value) && !string.IsNullOrEmpty(value)) + { + shardedCookies[shardName] = value; + } + else + { + // Stop at first missing shard to maintain order integrity + break; + } + } + + cookies = shardedCookies; + return shardedCookies.Count > 0; + } + + private static string BuildRedirectUri(Dictionary cookies) + { + var queryParams = new List(); + + foreach (var kvp in cookies) + { + var encodedValue = Uri.EscapeDataString(kvp.Value); + queryParams.Add($"{kvp.Key}={encodedValue}"); + } + + // Add a sentinel value so clients can detect a truncated URI, in the + // event a user agent decides the URI is too long. + queryParams.Add("d=1"); + + return $"bitwarden://sso_cookie_vendor?{string.Join("&", queryParams)}"; + } +} diff --git a/src/Api/KeyManagement/Validators/OrganizationUserRotationValidator.cs b/src/Api/KeyManagement/Validators/OrganizationUserRotationValidator.cs index 5023521fe3..835965e2d6 100644 --- a/src/Api/KeyManagement/Validators/OrganizationUserRotationValidator.cs +++ b/src/Api/KeyManagement/Validators/OrganizationUserRotationValidator.cs @@ -34,8 +34,7 @@ public class OrganizationUserRotationValidator : IRotationValidator o.ResetPasswordKey != null).ToList(); - + existing = existing.Where(o => !string.IsNullOrEmpty(o.ResetPasswordKey)).ToList(); foreach (var ou in existing) { diff --git a/src/Api/Models/Public/Response/ErrorResponseModel.cs b/src/Api/Models/Public/Response/ErrorResponseModel.cs index c5bb06d02e..a40b0c9569 100644 --- a/src/Api/Models/Public/Response/ErrorResponseModel.cs +++ b/src/Api/Models/Public/Response/ErrorResponseModel.cs @@ -1,7 +1,5 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations; +using System.Text.Json.Serialization; using Microsoft.AspNetCore.Mvc.ModelBinding; namespace Bit.Api.Models.Public.Response; @@ -46,13 +44,14 @@ public class ErrorResponseModel : IResponseModel { } public ErrorResponseModel(string errorKey, string errorValue) - : this(errorKey, new string[] { errorValue }) + : this(errorKey, [errorValue]) { } public ErrorResponseModel(string errorKey, IEnumerable errorValues) : this(new Dictionary> { { errorKey, errorValues } }) { } + [JsonConstructor] public ErrorResponseModel(string message, Dictionary> errors) { Message = message; @@ -70,10 +69,10 @@ public class ErrorResponseModel : IResponseModel /// /// The request model is invalid. [Required] - public string Message { get; set; } + public string Message { get; init; } /// /// If multiple errors occurred, they are listed in dictionary. Errors related to a specific /// request parameter will include a dictionary key describing that parameter. /// - public Dictionary> Errors { get; set; } + public Dictionary>? Errors { get; } } diff --git a/src/Api/Models/Response/ConfigResponseModel.cs b/src/Api/Models/Response/ConfigResponseModel.cs index d748254206..e6ac35bb39 100644 --- a/src/Api/Models/Response/ConfigResponseModel.cs +++ b/src/Api/Models/Response/ConfigResponseModel.cs @@ -18,6 +18,7 @@ public class ConfigResponseModel : ResponseModel public EnvironmentConfigResponseModel Environment { get; set; } public IDictionary FeatureStates { get; set; } public PushSettings Push { get; set; } + public CommunicationSettings Communication { get; set; } public ServerSettingsResponseModel Settings { get; set; } public ConfigResponseModel() : base("config") @@ -48,6 +49,7 @@ public class ConfigResponseModel : ResponseModel FeatureStates = featureService.GetAll(); var webPushEnabled = FeatureStates.TryGetValue(FeatureFlagKeys.WebPush, out var webPushEnabledValue) ? (bool)webPushEnabledValue : false; Push = PushSettings.Build(webPushEnabled, globalSettings); + Communication = CommunicationSettings.Build(globalSettings); Settings = new ServerSettingsResponseModel { DisableUserRegistration = globalSettings.DisableUserRegistration @@ -88,6 +90,40 @@ public class PushSettings } } +public class CommunicationSettings +{ + public CommunicationBootstrapSettings Bootstrap { get; private init; } + + public static CommunicationSettings Build(IGlobalSettings globalSettings) + { + var bootstrap = CommunicationBootstrapSettings.Build(globalSettings); + return bootstrap == null ? null : new() { Bootstrap = bootstrap }; + } +} + +public class CommunicationBootstrapSettings +{ + public string Type { get; private init; } + public string IdpLoginUrl { get; private init; } + public string CookieName { get; private init; } + public string CookieDomain { get; private init; } + + public static CommunicationBootstrapSettings Build(IGlobalSettings globalSettings) + { + return globalSettings.Communication?.Bootstrap?.ToLowerInvariant() switch + { + "ssocookievendor" => new() + { + Type = "ssoCookieVendor", + IdpLoginUrl = globalSettings.Communication?.SsoCookieVendor?.IdpLoginUrl, + CookieName = globalSettings.Communication?.SsoCookieVendor?.CookieName, + CookieDomain = globalSettings.Communication?.SsoCookieVendor?.CookieDomain + }, + _ => null + }; + } +} + public class ServerSettingsResponseModel { public bool DisableUserRegistration { get; set; } diff --git a/src/Api/Program.cs b/src/Api/Program.cs index bf924af47f..baeaab9fdb 100644 --- a/src/Api/Program.cs +++ b/src/Api/Program.cs @@ -8,7 +8,7 @@ public class Program { Host .CreateDefaultBuilder(args) - .ConfigureCustomAppConfiguration(args) + .UseBitwardenSdk() .ConfigureWebHostDefaults(webBuilder => { webBuilder.UseStartup(); diff --git a/src/Api/Public/Controllers/CollectionsController.cs b/src/Api/Public/Controllers/CollectionsController.cs index a567062a5e..28de4dc16d 100644 --- a/src/Api/Public/Controllers/CollectionsController.cs +++ b/src/Api/Public/Controllers/CollectionsController.cs @@ -67,8 +67,9 @@ public class CollectionsController : Controller { var collections = await _collectionRepository.GetManyByOrganizationIdWithAccessAsync(_currentContext.OrganizationId.Value); - var collectionResponses = collections.Select(c => - new CollectionResponseModel(c.Item1, c.Item2.Groups)); + var collectionResponses = collections + .Where(c => c.Item1.Type != CollectionType.DefaultUserCollection) + .Select(c => new CollectionResponseModel(c.Item1, c.Item2.Groups)); var response = new ListResponseModel(collectionResponses); return new JsonResult(response); diff --git a/src/Api/Startup.cs b/src/Api/Startup.cs index b201cef0f3..5b9015b71a 100644 --- a/src/Api/Startup.cs +++ b/src/Api/Startup.cs @@ -14,8 +14,7 @@ using Bit.Api.Tools.Models.Request; using Bit.Api.Vault.Models.Request; using Bit.Core.Auth.Entities; using Bit.SharedWeb.Health; -using Microsoft.IdentityModel.Logging; -using Microsoft.OpenApi.Models; +using Microsoft.OpenApi; using Bit.SharedWeb.Utilities; using Microsoft.AspNetCore.Diagnostics.HealthChecks; using Microsoft.Extensions.DependencyInjection.Extensions; @@ -238,8 +237,6 @@ public class Startup GlobalSettings globalSettings, ILogger logger) { - IdentityModelEventSource.ShowPII = true; - // Add general security headers app.UseMiddleware(); @@ -304,44 +301,43 @@ public class Startup // Remove all Bitwarden cloud servers and only register the local server config.PreSerializeFilters.Add((swaggerDoc, httpReq) => { - swaggerDoc.Servers.Clear(); - swaggerDoc.Servers.Add(new OpenApiServer - { - Url = globalSettings.BaseServiceUri.Api, - }); + swaggerDoc.Servers = + [ + new() { + Url = globalSettings.BaseServiceUri.Api, + } + ]; - swaggerDoc.Components.SecuritySchemes.Clear(); - swaggerDoc.Components.SecuritySchemes.Add("oauth2-client-credentials", new OpenApiSecurityScheme + swaggerDoc.Components ??= new OpenApiComponents(); + swaggerDoc.Components.SecuritySchemes = new Dictionary { - Type = SecuritySchemeType.OAuth2, - Flows = new OpenApiOAuthFlows { - ClientCredentials = new OpenApiOAuthFlow + "oauth2-client-credentials", + new OpenApiSecurityScheme { - TokenUrl = new Uri($"{globalSettings.BaseServiceUri.Identity}/connect/token"), - Scopes = new Dictionary + Type = SecuritySchemeType.OAuth2, + Flows = new OpenApiOAuthFlows + { + ClientCredentials = new OpenApiOAuthFlow + { + TokenUrl = new Uri($"{globalSettings.BaseServiceUri.Identity}/connect/token"), + Scopes = new Dictionary { { ApiScopes.ApiOrganization, "Organization APIs" } } + } + } } } - }); + }; - swaggerDoc.SecurityRequirements.Clear(); - swaggerDoc.SecurityRequirements.Add(new OpenApiSecurityRequirement - { + swaggerDoc.Security = + [ + new OpenApiSecurityRequirement { - new OpenApiSecurityScheme - { - Reference = new OpenApiReference - { - Type = ReferenceType.SecurityScheme, - Id = "oauth2-client-credentials" - } - }, - [ApiScopes.ApiOrganization] - } - }); + [new OpenApiSecuritySchemeReference("oauth2-client-credentials")] = [ApiScopes.ApiOrganization] + }, + ]; }); }); diff --git a/src/Api/Tools/Controllers/ImportCiphersController.cs b/src/Api/Tools/Controllers/ImportCiphersController.cs index 8b3ec5e26c..bebf7cbf29 100644 --- a/src/Api/Tools/Controllers/ImportCiphersController.cs +++ b/src/Api/Tools/Controllers/ImportCiphersController.cs @@ -74,11 +74,6 @@ public class ImportCiphersController : Controller throw new BadRequestException("You cannot import this much data at once."); } - if (model.Ciphers.Any(c => c.ArchivedDate.HasValue)) - { - throw new BadRequestException("You cannot import archived items into an organization."); - } - var orgId = new Guid(organizationId); var collections = model.Collections.Select(c => c.ToCollection(orgId)).ToList(); diff --git a/src/Api/Tools/Controllers/SendsController.cs b/src/Api/Tools/Controllers/SendsController.cs index f9f71d076d..af7fe8f12b 100644 --- a/src/Api/Tools/Controllers/SendsController.cs +++ b/src/Api/Tools/Controllers/SendsController.cs @@ -239,9 +239,8 @@ public class SendsController : Controller { throw new BadRequestException("Could not locate send"); } - if (send.MaxAccessCount.GetValueOrDefault(int.MaxValue) <= send.AccessCount || - send.ExpirationDate.GetValueOrDefault(DateTime.MaxValue) < DateTime.UtcNow || send.Disabled || - send.DeletionDate < DateTime.UtcNow) + + if (!INonAnonymousSendCommand.SendCanBeAccessed(send)) { throw new NotFoundException(); } @@ -253,9 +252,19 @@ public class SendsController : Controller sendResponse.CreatorIdentifier = creator.Email; } - send.AccessCount++; - await _sendRepository.ReplaceAsync(send); - await _pushNotificationService.PushSyncSendUpdateAsync(send); + /* + * AccessCount is incremented differently for File and Text Send types: + * - Text Sends are incremented at every access + * - File Sends are incremented only when the file is downloaded + * + * Note that this endpoint is initially called for all Send types + */ + if (send.Type == SendType.Text) + { + send.AccessCount++; + await _sendRepository.ReplaceAsync(send); + await _pushNotificationService.PushSyncSendUpdateAsync(send); + } return new ObjectResult(sendResponse); } @@ -272,19 +281,14 @@ public class SendsController : Controller { throw new BadRequestException("Could not locate send"); } - if (send.MaxAccessCount.GetValueOrDefault(int.MaxValue) <= send.AccessCount || - send.ExpirationDate.GetValueOrDefault(DateTime.MaxValue) < DateTime.UtcNow || send.Disabled || - send.DeletionDate < DateTime.UtcNow) + + var (url, result) = await _nonAnonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId); + + if (result.Equals(SendAccessResult.Denied)) { throw new NotFoundException(); } - var url = await _sendFileStorageService.GetSendFileDownloadUrlAsync(send, fileId); - - send.AccessCount++; - await _sendRepository.ReplaceAsync(send); - await _pushNotificationService.PushSyncSendUpdateAsync(send); - return new ObjectResult(new SendFileDownloadDataResponseModel() { Id = fileId, Url = url }); } @@ -399,19 +403,7 @@ public class SendsController : Controller [HttpPut("{id}/remove-password")] public async Task PutRemovePassword(string id) { - var userId = _userService.GetProperUserId(User) ?? throw new InvalidOperationException("User ID not found"); - var send = await _sendRepository.GetByIdAsync(new Guid(id)); - if (send == null || send.UserId != userId) - { - throw new NotFoundException(); - } - - // This endpoint exists because PUT preserves existing Password/Emails when not provided. - // This allows clients to update other fields without re-submitting sensitive auth data. - send.Password = null; - send.AuthType = AuthType.None; - await _nonAnonymousSendCommand.SaveSendAsync(send); - return new SendResponseModel(send); + return await this.PutRemoveAuth(id); } // Removes ALL authentication (email or password) if any is present diff --git a/src/Api/Utilities/ServiceCollectionExtensions.cs b/src/Api/Utilities/ServiceCollectionExtensions.cs index b773abf6ef..38d0bf4407 100644 --- a/src/Api/Utilities/ServiceCollectionExtensions.cs +++ b/src/Api/Utilities/ServiceCollectionExtensions.cs @@ -7,7 +7,7 @@ using Bit.SharedWeb.Health; using Bit.SharedWeb.Swagger; using Bit.SharedWeb.Utilities; using Microsoft.AspNetCore.Authorization; -using Microsoft.OpenApi.Models; +using Microsoft.OpenApi; namespace Bit.Api.Utilities; diff --git a/src/Api/Vault/Controllers/CiphersController.cs b/src/Api/Vault/Controllers/CiphersController.cs index 9e107b491d..c9ca7525fa 100644 --- a/src/Api/Vault/Controllers/CiphersController.cs +++ b/src/Api/Vault/Controllers/CiphersController.cs @@ -976,14 +976,14 @@ public class CiphersController : Controller public async Task DeleteAdmin(Guid id) { var userId = _userService.GetProperUserId(User).Value; - var cipher = await GetByIdAsync(id, userId); + var cipher = await GetByIdAsyncAdmin(id); if (cipher == null || !cipher.OrganizationId.HasValue || !await CanDeleteOrRestoreCipherAsAdminAsync(cipher.OrganizationId.Value, new[] { cipher.Id })) { throw new NotFoundException(); } - await _cipherService.DeleteAsync(cipher, userId, true); + await _cipherService.DeleteAsync(new CipherDetails(cipher), userId, true); } [HttpPost("{id}/delete-admin")] diff --git a/src/Api/Vault/Controllers/SyncController.cs b/src/Api/Vault/Controllers/SyncController.cs index 6ac8d06ba0..b186e4b601 100644 --- a/src/Api/Vault/Controllers/SyncController.cs +++ b/src/Api/Vault/Controllers/SyncController.cs @@ -6,6 +6,7 @@ using Bit.Core; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Auth.Repositories; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; using Bit.Core.Context; using Bit.Core.Entities; @@ -44,6 +45,7 @@ public class SyncController : Controller private readonly IFeatureService _featureService; private readonly IApplicationCacheService _applicationCacheService; private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; + private readonly IWebAuthnCredentialRepository _webAuthnCredentialRepository; private readonly IUserAccountKeysQuery _userAccountKeysQuery; public SyncController( @@ -61,6 +63,7 @@ public class SyncController : Controller IFeatureService featureService, IApplicationCacheService applicationCacheService, ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, + IWebAuthnCredentialRepository webAuthnCredentialRepository, IUserAccountKeysQuery userAccountKeysQuery) { _userService = userService; @@ -77,6 +80,7 @@ public class SyncController : Controller _featureService = featureService; _applicationCacheService = applicationCacheService; _twoFactorIsEnabledQuery = twoFactorIsEnabledQuery; + _webAuthnCredentialRepository = webAuthnCredentialRepository; _userAccountKeysQuery = userAccountKeysQuery; } @@ -120,6 +124,9 @@ public class SyncController : Controller var organizationIdsClaimingActiveUser = organizationClaimingActiveUser.Select(o => o.Id); var organizationAbilities = await _applicationCacheService.GetOrganizationAbilitiesAsync(); + var webAuthnCredentials = _featureService.IsEnabled(FeatureFlagKeys.PM2035PasskeyUnlock) + ? await _webAuthnCredentialRepository.GetManyByUserIdAsync(user.Id) + : []; UserAccountKeysData userAccountKeys = null; // JIT TDE users and some broken/old users may not have a private key. @@ -130,7 +137,7 @@ public class SyncController : Controller var response = new SyncResponseModel(_globalSettings, user, userAccountKeys, userTwoFactorEnabled, userHasPremiumFromOrganization, organizationAbilities, organizationIdsClaimingActiveUser, organizationUserDetails, providerUserDetails, providerUserOrganizationDetails, - folders, collections, ciphers, collectionCiphersGroupDict, excludeDomains, policies, sends); + folders, collections, ciphers, collectionCiphersGroupDict, excludeDomains, policies, sends, webAuthnCredentials); return response; } diff --git a/src/Api/Vault/Models/Response/SyncResponseModel.cs b/src/Api/Vault/Models/Response/SyncResponseModel.cs index c965320b94..8f90452c6c 100644 --- a/src/Api/Vault/Models/Response/SyncResponseModel.cs +++ b/src/Api/Vault/Models/Response/SyncResponseModel.cs @@ -6,6 +6,9 @@ using Bit.Api.Models.Response; using Bit.Api.Tools.Models.Response; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Models.Data.Provider; +using Bit.Core.Auth.Entities; +using Bit.Core.Auth.Enums; +using Bit.Core.Auth.Models.Api.Response; using Bit.Core.Entities; using Bit.Core.KeyManagement.Models.Api.Response; using Bit.Core.KeyManagement.Models.Data; @@ -39,7 +42,8 @@ public class SyncResponseModel() : ResponseModel("sync") IDictionary> collectionCiphersDict, bool excludeDomains, IEnumerable policies, - IEnumerable sends) + IEnumerable sends, + IEnumerable webAuthnCredentials) : this() { Profile = new ProfileResponseModel(user, userAccountKeysData, organizationUserDetails, providerUserDetails, @@ -57,6 +61,16 @@ public class SyncResponseModel() : ResponseModel("sync") Domains = excludeDomains ? null : new DomainsResponseModel(user, false); Policies = policies?.Select(p => new PolicyResponseModel(p)) ?? new List(); Sends = sends.Select(s => new SendResponseModel(s)); + var webAuthnPrfOptions = webAuthnCredentials + .Where(c => c.GetPrfStatus() == WebAuthnPrfStatus.Enabled) + .Select(c => new WebAuthnPrfDecryptionOption( + c.EncryptedPrivateKey, + c.EncryptedUserKey, + c.CredentialId, + [] // transports as empty array + )) + .ToArray(); + UserDecryption = new UserDecryptionResponseModel { MasterPasswordUnlock = user.HasMasterPassword() @@ -72,7 +86,8 @@ public class SyncResponseModel() : ResponseModel("sync") MasterKeyEncryptedUserKey = user.Key!, Salt = user.Email.ToLowerInvariant() } - : null + : null, + WebAuthnPrfOptions = webAuthnPrfOptions.Length > 0 ? webAuthnPrfOptions : null }; } diff --git a/src/Api/appsettings.Production.json b/src/Api/appsettings.Production.json index d9efbcda12..a6578c08dc 100644 --- a/src/Api/appsettings.Production.json +++ b/src/Api/appsettings.Production.json @@ -23,11 +23,9 @@ } }, "Logging": { - "IncludeScopes": false, "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" + "Default": "Information", + "Microsoft.AspNetCore": "Warning" }, "Console": { "IncludeScopes": true, diff --git a/src/Billing/Billing.csproj b/src/Billing/Billing.csproj index 27ee9a7ce3..dee50af0bb 100644 --- a/src/Billing/Billing.csproj +++ b/src/Billing/Billing.csproj @@ -21,7 +21,7 @@ - + diff --git a/src/Billing/Services/Implementations/StripeEventService.cs b/src/Billing/Services/Implementations/StripeEventService.cs index 03ca8eeb10..03865b48fe 100644 --- a/src/Billing/Services/Implementations/StripeEventService.cs +++ b/src/Billing/Services/Implementations/StripeEventService.cs @@ -9,6 +9,7 @@ namespace Bit.Billing.Services.Implementations; public class StripeEventService( GlobalSettings globalSettings, + ILogger logger, IOrganizationRepository organizationRepository, IProviderRepository providerRepository, ISetupIntentCache setupIntentCache, @@ -148,26 +149,36 @@ public class StripeEventService( { var setupIntent = await GetSetupIntent(localStripeEvent); + logger.LogInformation("Extracted Setup Intent ({SetupIntentId}) from Stripe 'setup_intent.succeeded' event", setupIntent.Id); + var subscriberId = await setupIntentCache.GetSubscriberIdForSetupIntent(setupIntent.Id); + + logger.LogInformation("Retrieved subscriber ID ({SubscriberId}) from cache for Setup Intent ({SetupIntentId})", subscriberId, setupIntent.Id); + if (subscriberId == null) { + logger.LogError("Cached subscriber ID for Setup Intent ({SetupIntentId}) is null", setupIntent.Id); return null; } var organization = await organizationRepository.GetByIdAsync(subscriberId.Value); + logger.LogInformation("Retrieved organization ({OrganizationId}) via subscriber ID for Setup Intent ({SetupIntentId})", organization?.Id, setupIntent.Id); if (organization is { GatewayCustomerId: not null }) { var organizationCustomer = await stripeFacade.GetCustomer(organization.GatewayCustomerId); + logger.LogInformation("Retrieved customer ({CustomerId}) via organization ID for Setup Intent ({SetupIntentId})", organization.Id, setupIntent.Id); return organizationCustomer.Metadata; } var provider = await providerRepository.GetByIdAsync(subscriberId.Value); + logger.LogInformation("Retrieved provider ({ProviderId}) via subscriber ID for Setup Intent ({SetupIntentId})", provider?.Id, setupIntent.Id); if (provider is not { GatewayCustomerId: not null }) { return null; } var providerCustomer = await stripeFacade.GetCustomer(provider.GatewayCustomerId); + logger.LogInformation("Retrieved customer ({CustomerId}) via provider ID for Setup Intent ({SetupIntentId})", provider.Id, setupIntent.Id); return providerCustomer.Metadata; } } diff --git a/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs b/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs index c10368d8c0..4507d9e308 100644 --- a/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs +++ b/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs @@ -1,17 +1,15 @@ -using Bit.Billing.Constants; -using Bit.Billing.Jobs; -using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; -using Bit.Core.Billing.Constants; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Subscriptions.Models; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; -using Quartz; using Stripe; using Stripe.TestHelpers; +using static Bit.Core.Billing.Constants.StripeConstants; using Event = Stripe.Event; namespace Bit.Billing.Services.Implementations; @@ -25,14 +23,11 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler private readonly IOrganizationSponsorshipRenewCommand _organizationSponsorshipRenewCommand; private readonly IUserService _userService; private readonly IOrganizationRepository _organizationRepository; - private readonly ISchedulerFactory _schedulerFactory; private readonly IOrganizationEnableCommand _organizationEnableCommand; private readonly IOrganizationDisableCommand _organizationDisableCommand; private readonly IPricingClient _pricingClient; - private readonly IFeatureService _featureService; private readonly IProviderRepository _providerRepository; private readonly IProviderService _providerService; - private readonly ILogger _logger; private readonly IPushNotificationAdapter _pushNotificationAdapter; public SubscriptionUpdatedHandler( @@ -43,14 +38,11 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler IOrganizationSponsorshipRenewCommand organizationSponsorshipRenewCommand, IUserService userService, IOrganizationRepository organizationRepository, - ISchedulerFactory schedulerFactory, IOrganizationEnableCommand organizationEnableCommand, IOrganizationDisableCommand organizationDisableCommand, IPricingClient pricingClient, - IFeatureService featureService, IProviderRepository providerRepository, IProviderService providerService, - ILogger logger, IPushNotificationAdapter pushNotificationAdapter) { _stripeEventService = stripeEventService; @@ -62,183 +54,147 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler _userService = userService; _organizationRepository = organizationRepository; _providerRepository = providerRepository; - _schedulerFactory = schedulerFactory; _organizationEnableCommand = organizationEnableCommand; _organizationDisableCommand = organizationDisableCommand; _pricingClient = pricingClient; - _featureService = featureService; _providerRepository = providerRepository; _providerService = providerService; - _logger = logger; _pushNotificationAdapter = pushNotificationAdapter; } - /// - /// Handles the event type from Stripe. - /// - /// public async Task HandleAsync(Event parsedEvent) { var subscription = await _stripeEventService.GetSubscription(parsedEvent, true, ["customer", "discounts", "latest_invoice", "test_clock"]); - var (organizationId, userId, providerId) = _stripeEventUtilityService.GetIdsFromMetadata(subscription.Metadata); + SubscriberId subscriberId = subscription; var currentPeriodEnd = subscription.GetCurrentPeriodEnd(); - switch (subscription.Status) + if (SubscriptionWentUnpaid(parsedEvent, subscription)) { - case StripeSubscriptionStatus.Unpaid or StripeSubscriptionStatus.IncompleteExpired - when organizationId.HasValue: - { - await _organizationDisableCommand.DisableAsync(organizationId.Value, currentPeriodEnd); - if (subscription.Status == StripeSubscriptionStatus.Unpaid && - subscription.LatestInvoice is { BillingReason: "subscription_cycle" or "subscription_create" }) - { - await ScheduleCancellationJobAsync(subscription.Id, organizationId.Value); - } - break; - } - case StripeSubscriptionStatus.Unpaid or StripeSubscriptionStatus.IncompleteExpired when providerId.HasValue: - { - await HandleUnpaidProviderSubscriptionAsync(providerId.Value, parsedEvent, subscription); - break; - } - case StripeSubscriptionStatus.Unpaid or StripeSubscriptionStatus.IncompleteExpired: - { - if (!userId.HasValue) - { - break; - } - - if (await IsPremiumSubscriptionAsync(subscription)) - { - await CancelSubscription(subscription.Id); - await VoidOpenInvoices(subscription.Id); - } - - await _userService.DisablePremiumAsync(userId.Value, currentPeriodEnd); - - break; - } - case StripeSubscriptionStatus.Incomplete when userId.HasValue: - { - // Handle Incomplete subscriptions for Premium users that have open invoices from failed payments - // This prevents duplicate subscriptions when users retry the subscription flow - if (await IsPremiumSubscriptionAsync(subscription) && - subscription.LatestInvoice is { Status: StripeInvoiceStatus.Open }) - { - await CancelSubscription(subscription.Id); - await VoidOpenInvoices(subscription.Id); - await _userService.DisablePremiumAsync(userId.Value, currentPeriodEnd); - } - - break; - } - case StripeSubscriptionStatus.Active when organizationId.HasValue: - { - await _organizationEnableCommand.EnableAsync(organizationId.Value); - var organization = await _organizationRepository.GetByIdAsync(organizationId.Value); - if (organization != null) - { - await _pushNotificationAdapter.NotifyEnabledChangedAsync(organization); - } - break; - } - case StripeSubscriptionStatus.Active when providerId.HasValue: - { - var provider = await _providerRepository.GetByIdAsync(providerId.Value); - if (provider != null) - { - provider.Enabled = true; - await _providerService.UpdateAsync(provider); - - if (IsProviderSubscriptionNowActive(parsedEvent, subscription)) - { - // Update the CancelAtPeriodEnd subscription option to prevent the now active provider subscription from being cancelled - var subscriptionUpdateOptions = new SubscriptionUpdateOptions { CancelAtPeriodEnd = false }; - await _stripeFacade.UpdateSubscription(subscription.Id, subscriptionUpdateOptions); - } - } - break; - } - case StripeSubscriptionStatus.Active: - { - if (userId.HasValue) - { - await _userService.EnablePremiumAsync(userId.Value, currentPeriodEnd); - } - break; - } + await DisableSubscriberAsync(subscriberId, currentPeriodEnd); + await SetSubscriptionToCancelAsync(subscription); + } + else if (SubscriptionBecameActive(parsedEvent, subscription)) + { + await EnableSubscriberAsync(subscriberId, currentPeriodEnd); + await RemovePendingCancellationAsync(subscription); } - if (organizationId.HasValue) - { - await _organizationService.UpdateExpirationDateAsync(organizationId.Value, currentPeriodEnd); - if (_stripeEventUtilityService.IsSponsoredSubscription(subscription) && currentPeriodEnd.HasValue) + await subscriberId.Match( + userId => _userService.UpdatePremiumExpirationAsync(userId.Value, currentPeriodEnd), + async organizationId => { - await _organizationSponsorshipRenewCommand.UpdateExpirationDateAsync(organizationId.Value, currentPeriodEnd.Value); + await _organizationService.UpdateExpirationDateAsync(organizationId.Value, currentPeriodEnd); + + if (_stripeEventUtilityService.IsSponsoredSubscription(subscription) && currentPeriodEnd.HasValue) + { + await _organizationSponsorshipRenewCommand.UpdateExpirationDateAsync(organizationId.Value, currentPeriodEnd.Value); + } + + await RemovePasswordManagerCouponIfRemovingSecretsManagerTrialAsync(parsedEvent, subscription); + }, + _ => Task.CompletedTask); + } + + private static bool SubscriptionWentUnpaid( + Event parsedEvent, + Subscription currentSubscription) => + parsedEvent.Data.PreviousAttributes.ToObject() is Subscription + { + Status: + SubscriptionStatus.Trialing or + SubscriptionStatus.Active or + SubscriptionStatus.PastDue + } && currentSubscription is + { + Status: SubscriptionStatus.Unpaid, + LatestInvoice.BillingReason: BillingReasons.SubscriptionCreate or BillingReasons.SubscriptionCycle + }; + + private static bool SubscriptionBecameActive( + Event parsedEvent, + Subscription currentSubscription) => + parsedEvent.Data.PreviousAttributes.ToObject() is Subscription + { + Status: + SubscriptionStatus.Incomplete or + SubscriptionStatus.Unpaid + } && currentSubscription is + { + Status: SubscriptionStatus.Active, + LatestInvoice.BillingReason: BillingReasons.SubscriptionCreate or BillingReasons.SubscriptionCycle + }; + + private Task DisableSubscriberAsync(SubscriberId subscriberId, DateTime? currentPeriodEnd) => + subscriberId.Match( + userId => _userService.DisablePremiumAsync(userId.Value, currentPeriodEnd), + async organizationId => + { + await _organizationDisableCommand.DisableAsync(organizationId.Value, currentPeriodEnd); + var organization = await _organizationRepository.GetByIdAsync(organizationId.Value); + if (organization != null) + { + await _pushNotificationAdapter.NotifyEnabledChangedAsync(organization); + } + }, + async providerId => + { + var provider = await _providerRepository.GetByIdAsync(providerId.Value); + if (provider != null) + { + provider.Enabled = false; + await _providerService.UpdateAsync(provider); + } + }); + + private Task EnableSubscriberAsync(SubscriberId subscriberId, DateTime? currentPeriodEnd) => + subscriberId.Match( + userId => _userService.EnablePremiumAsync(userId.Value, currentPeriodEnd), + async organizationId => + { + await _organizationEnableCommand.EnableAsync(organizationId.Value, currentPeriodEnd); + var organization = await _organizationRepository.GetByIdAsync(organizationId.Value); + if (organization != null) + { + await _pushNotificationAdapter.NotifyEnabledChangedAsync(organization); + } + }, + async providerId => + { + var provider = await _providerRepository.GetByIdAsync(providerId.Value); + if (provider != null) + { + provider.Enabled = true; + await _providerService.UpdateAsync(provider); + } + }); + + private async Task SetSubscriptionToCancelAsync(Subscription subscription) + { + if (subscription.TestClock != null) + { + await WaitForTestClockToAdvanceAsync(subscription.TestClock); + } + + var now = subscription.TestClock?.FrozenTime ?? DateTime.UtcNow; + + await _stripeFacade.UpdateSubscription(subscription.Id, new SubscriptionUpdateOptions + { + CancelAt = now.AddDays(7), + ProrationBehavior = ProrationBehavior.None, + CancellationDetails = new SubscriptionCancellationDetailsOptions + { + Comment = $"Automation: Setting unpaid subscription to cancel 7 days from {now:yyyy-MM-dd}." } - - await RemovePasswordManagerCouponIfRemovingSecretsManagerTrialAsync(parsedEvent, subscription); - } - else if (userId.HasValue) - { - await _userService.UpdatePremiumExpirationAsync(userId.Value, currentPeriodEnd); - } + }); } - private async Task CancelSubscription(string subscriptionId) => - await _stripeFacade.CancelSubscription(subscriptionId, new SubscriptionCancelOptions()); - - private async Task VoidOpenInvoices(string subscriptionId) - { - var options = new InvoiceListOptions + private async Task RemovePendingCancellationAsync(Subscription subscription) + => await _stripeFacade.UpdateSubscription(subscription.Id, new SubscriptionUpdateOptions { - Status = StripeInvoiceStatus.Open, - Subscription = subscriptionId - }; - var invoices = await _stripeFacade.ListInvoices(options); - foreach (var invoice in invoices) - { - await _stripeFacade.VoidInvoice(invoice.Id); - } - } - - private async Task IsPremiumSubscriptionAsync(Subscription subscription) - { - var premiumPlans = await _pricingClient.ListPremiumPlans(); - var premiumPriceIds = premiumPlans.SelectMany(p => new[] { p.Seat.StripePriceId, p.Storage.StripePriceId }).ToHashSet(); - return subscription.Items.Any(i => premiumPriceIds.Contains(i.Price.Id)); - } - - /// - /// Checks if the provider subscription status has changed from a non-active to an active status type - /// If the previous status is already active(active,past-due,trialing),canceled,or null, then this will return false. - /// - /// The event containing the previous subscription status - /// The current subscription status - /// A boolean that represents whether the event status has changed from a non-active status to an active status - private static bool IsProviderSubscriptionNowActive(Event parsedEvent, Subscription subscription) - { - if (parsedEvent.Data.PreviousAttributes == null) - { - return false; - } - - var previousSubscription = parsedEvent - .Data - .PreviousAttributes - .ToObject() as Subscription; - - return previousSubscription?.Status switch - { - StripeSubscriptionStatus.IncompleteExpired - or StripeSubscriptionStatus.Paused - or StripeSubscriptionStatus.Incomplete - or StripeSubscriptionStatus.Unpaid - when subscription.Status == StripeSubscriptionStatus.Active => true, - _ => false - }; - } + CancelAtPeriodEnd = false, + ProrationBehavior = ProrationBehavior.None + }); /// /// Removes the Password Manager coupon if the organization is removing the Secrets Manager trial. @@ -275,17 +231,24 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler .PreviousAttributes .ToObject() as Subscription; + // Get all plan IDs that include Secrets Manager support to check if the organization has secret manager in the + // previous and/or current subscriptions. + var planIdsOfPlansWithSecretManager = (await _pricingClient.ListPlans()) + .Where(orgPlan => orgPlan.SupportsSecretsManager && orgPlan.SecretsManager.StripeSeatPlanId != null) + .Select(orgPlan => orgPlan.SecretsManager.StripeSeatPlanId) + .ToHashSet(); + // This being false doesn't necessarily mean that the organization doesn't subscribe to Secrets Manager. // If there are changes to any subscription item, Stripe sends every item in the subscription, both // changed and unchanged. var previousSubscriptionHasSecretsManager = previousSubscription?.Items is not null && previousSubscription.Items.Any( - previousSubscriptionItem => previousSubscriptionItem.Plan.Id == plan.SecretsManager.StripeSeatPlanId); + previousSubscriptionItem => planIdsOfPlansWithSecretManager.Contains(previousSubscriptionItem.Plan.Id)); var currentSubscriptionHasSecretsManager = subscription.Items.Any( - currentSubscriptionItem => currentSubscriptionItem.Plan.Id == plan.SecretsManager.StripeSeatPlanId); + currentSubscriptionItem => planIdsOfPlansWithSecretManager.Contains(currentSubscriptionItem.Plan.Id)); if (!previousSubscriptionHasSecretsManager || currentSubscriptionHasSecretsManager) { @@ -298,7 +261,7 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler ?.Id == "sm-standalone"; var subscriptionHasSecretsManagerTrial = subscription.Discounts.Select(discount => discount.Coupon.Id) - .Contains(StripeConstants.CouponIDs.SecretsManagerStandalone); + .Contains(CouponIDs.SecretsManagerStandalone); if (customerHasSecretsManagerTrial) { @@ -311,75 +274,6 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler } } - private async Task ScheduleCancellationJobAsync(string subscriptionId, Guid organizationId) - { - var scheduler = await _schedulerFactory.GetScheduler(); - - var job = JobBuilder.Create() - .WithIdentity($"cancel-sub-{subscriptionId}", "subscription-cancellations") - .UsingJobData("subscriptionId", subscriptionId) - .UsingJobData("organizationId", organizationId.ToString()) - .Build(); - - var trigger = TriggerBuilder.Create() - .WithIdentity($"cancel-trigger-{subscriptionId}", "subscription-cancellations") - .StartAt(DateTimeOffset.UtcNow.AddDays(7)) - .Build(); - - await scheduler.ScheduleJob(job, trigger); - } - - private async Task HandleUnpaidProviderSubscriptionAsync( - Guid providerId, - Event parsedEvent, - Subscription currentSubscription) - { - var provider = await _providerRepository.GetByIdAsync(providerId); - if (provider == null) - { - return; - } - - try - { - provider.Enabled = false; - await _providerService.UpdateAsync(provider); - - if (parsedEvent.Data.PreviousAttributes != null) - { - var previousSubscription = parsedEvent.Data.PreviousAttributes.ToObject() as Subscription; - - if (previousSubscription is - { - Status: - StripeSubscriptionStatus.Trialing or - StripeSubscriptionStatus.Active or - StripeSubscriptionStatus.PastDue - } && currentSubscription is - { - Status: StripeSubscriptionStatus.Unpaid, - LatestInvoice.BillingReason: "subscription_cycle" or "subscription_create" - }) - { - if (currentSubscription.TestClock != null) - { - await WaitForTestClockToAdvanceAsync(currentSubscription.TestClock); - } - - var now = currentSubscription.TestClock?.FrozenTime ?? DateTime.UtcNow; - - var subscriptionUpdateOptions = new SubscriptionUpdateOptions { CancelAt = now.AddDays(7) }; - - await _stripeFacade.UpdateSubscription(currentSubscription.Id, subscriptionUpdateOptions); - } - } - } - catch (Exception exception) - { - _logger.LogError(exception, "An error occurred while trying to disable and schedule subscription cancellation for provider ({ProviderID})", providerId); - } - } - private async Task WaitForTestClockToAdvanceAsync(TestClock testClock) { while (testClock.Status != "ready") diff --git a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs index 004828dc48..ae2a76a7ce 100644 --- a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs +++ b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs @@ -627,7 +627,7 @@ public class UpcomingInvoiceHandler( { BaseMonthlyRenewalPrice = (premiumPlan.Seat.Price / 12).ToString("C", new CultureInfo("en-US")), DiscountAmount = $"{coupon.PercentOff}%", - DiscountedMonthlyRenewalPrice = (discountedAnnualRenewalPrice / 12).ToString("C", new CultureInfo("en-US")) + DiscountedAnnualRenewalPrice = discountedAnnualRenewalPrice.ToString("C", new CultureInfo("en-US")) } }; diff --git a/src/Core/AdminConsole/Models/Data/Organizations/Policies/MasterPasswordPolicyData.cs b/src/Core/AdminConsole/Models/Data/Organizations/Policies/MasterPasswordPolicyData.cs index b66244ba5f..228d7a26f1 100644 --- a/src/Core/AdminConsole/Models/Data/Organizations/Policies/MasterPasswordPolicyData.cs +++ b/src/Core/AdminConsole/Models/Data/Organizations/Policies/MasterPasswordPolicyData.cs @@ -1,11 +1,21 @@ -using System.Text.Json.Serialization; +using System.ComponentModel.DataAnnotations; +using System.Text.Json.Serialization; namespace Bit.Core.AdminConsole.Models.Data.Organizations.Policies; public class MasterPasswordPolicyData : IPolicyDataModel { + /// + /// Minimum password complexity score (0-4). Null indicates no complexity requirement. + /// [JsonPropertyName("minComplexity")] + [Range(0, 4)] public int? MinComplexity { get; set; } + + /// + /// Minimum password length (12-128). Null indicates no minimum length requirement. + /// [JsonPropertyName("minLength")] + [Range(12, 128)] public int? MinLength { get; set; } [JsonPropertyName("requireLower")] public bool? RequireLower { get; set; } diff --git a/src/Core/AdminConsole/Models/Data/Organizations/Policies/PolicyStatus.cs b/src/Core/AdminConsole/Models/Data/Organizations/Policies/PolicyStatus.cs new file mode 100644 index 0000000000..68c754f6ba --- /dev/null +++ b/src/Core/AdminConsole/Models/Data/Organizations/Policies/PolicyStatus.cs @@ -0,0 +1,26 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.Utilities; + +namespace Bit.Core.AdminConsole.Models.Data.Organizations.Policies; + +public class PolicyStatus +{ + public PolicyStatus(Guid organizationId, PolicyType policyType, Policy? policy = null) + { + OrganizationId = policy?.OrganizationId ?? organizationId; + Data = policy?.Data; + Type = policy?.Type ?? policyType; + Enabled = policy?.Enabled ?? false; + } + + public Guid OrganizationId { get; set; } + public PolicyType Type { get; set; } + public bool Enabled { get; set; } + public string? Data { get; set; } + + public T GetDataModel() where T : IPolicyDataModel, new() + { + return CoreHelpers.LoadClassFromJsonData(Data); + } +} diff --git a/src/Core/AdminConsole/Models/Mail/Mailer/OrganizationConfirmation/OrganizationConfirmationEnterpriseTeamsView.html.hbs b/src/Core/AdminConsole/Models/Mail/Mailer/OrganizationConfirmation/OrganizationConfirmationEnterpriseTeamsView.html.hbs index 3c8f498403..cf310a19af 100644 --- a/src/Core/AdminConsole/Models/Mail/Mailer/OrganizationConfirmation/OrganizationConfirmationEnterpriseTeamsView.html.hbs +++ b/src/Core/AdminConsole/Models/Mail/Mailer/OrganizationConfirmation/OrganizationConfirmationEnterpriseTeamsView.html.hbs @@ -778,7 +778,7 @@

- © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + © {{ CurrentYear }} Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa Barbara, CA, USA

diff --git a/src/Core/AdminConsole/Models/Mail/Mailer/OrganizationConfirmation/OrganizationConfirmationFamilyFreeView.html.hbs b/src/Core/AdminConsole/Models/Mail/Mailer/OrganizationConfirmation/OrganizationConfirmationFamilyFreeView.html.hbs index c0f838e0c7..40ea484aa2 100644 --- a/src/Core/AdminConsole/Models/Mail/Mailer/OrganizationConfirmation/OrganizationConfirmationFamilyFreeView.html.hbs +++ b/src/Core/AdminConsole/Models/Mail/Mailer/OrganizationConfirmation/OrganizationConfirmationFamilyFreeView.html.hbs @@ -946,7 +946,7 @@

- © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + © {{ CurrentYear }} Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa Barbara, CA, USA

diff --git a/src/Core/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommand.cs index 5783301a0b..bd30112945 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommand.cs @@ -1,5 +1,5 @@ using Bit.Core.AdminConsole.Enums; -using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -11,7 +11,7 @@ using Microsoft.AspNetCore.Identity; namespace Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery; public class AdminRecoverAccountCommand(IOrganizationRepository organizationRepository, - IPolicyRepository policyRepository, + IPolicyQuery policyQuery, IUserRepository userRepository, IMailService mailService, IEventService eventService, @@ -30,9 +30,8 @@ public class AdminRecoverAccountCommand(IOrganizationRepository organizationRepo } // Enterprise policy must be enabled - var resetPasswordPolicy = - await policyRepository.GetByOrganizationIdTypeAsync(orgId, PolicyType.ResetPassword); - if (resetPasswordPolicy == null || !resetPasswordPolicy.Enabled) + var resetPasswordPolicy = await policyQuery.RunAsync(orgId, PolicyType.ResetPassword); + if (!resetPasswordPolicy.Enabled) { throw new BadRequestException("Organization does not have the password reset policy enabled."); } diff --git a/src/Core/AdminConsole/OrganizationFeatures/Collections/CollectionUtils.cs b/src/Core/AdminConsole/OrganizationFeatures/Collections/CollectionUtils.cs new file mode 100644 index 0000000000..116992146f --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Collections/CollectionUtils.cs @@ -0,0 +1,53 @@ +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Utilities; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Collections; + +public static class CollectionUtils +{ + ///

+ /// Arranges Collection and CollectionUser objects to create default user collections. + /// + /// The organization ID. + /// The IDs for organization users who need default collections. + /// The encrypted string to use as the default collection name. + /// A tuple containing the collections and collection users. + public static (ICollection collections, ICollection collectionUsers) + BuildDefaultUserCollections(Guid organizationId, IEnumerable organizationUserIds, + string defaultCollectionName) + { + var now = DateTime.UtcNow; + + var collectionUsers = new List(); + var collections = new List(); + + foreach (var orgUserId in organizationUserIds) + { + var collectionId = CoreHelpers.GenerateComb(); + + collections.Add(new Collection + { + Id = collectionId, + OrganizationId = organizationId, + Name = defaultCollectionName, + CreationDate = now, + RevisionDate = now, + Type = CollectionType.DefaultUserCollection, + DefaultUserCollectionEmail = null + + }); + + collectionUsers.Add(new CollectionUser + { + CollectionId = collectionId, + OrganizationUserId = orgUserId, + ReadOnly = false, + HidePasswords = false, + Manage = true, + }); + } + + return (collections, collectionUsers); + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommand.cs index e6cc3da2a2..aec6380ce2 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommand.cs @@ -157,6 +157,6 @@ public class VerifyOrganizationDomainCommand( var organization = await organizationRepository.GetByIdAsync(domain.OrganizationId); - await mailService.SendClaimedDomainUserEmailAsync(new ClaimedUserDomainClaimedEmails(domainUserEmails, organization)); + await mailService.SendClaimedDomainUserEmailAsync(new ClaimedUserDomainClaimedEmails(domainUserEmails, organization, domain.DomainName)); } } diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUserCommand.cs index 1b488677ae..0292381857 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUserCommand.cs @@ -4,9 +4,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.OrganizationConfirmation; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; -using Bit.Core.Entities; using Bit.Core.Enums; -using Bit.Core.Models.Data; using Bit.Core.Platform.Push; using Bit.Core.Repositories; using Bit.Core.Services; @@ -83,19 +81,10 @@ public class AutomaticallyConfirmOrganizationUserCommand(IOrganizationUserReposi return; } - await collectionRepository.CreateAsync( - new Collection - { - OrganizationId = request.Organization!.Id, - Name = request.DefaultUserCollectionName, - Type = CollectionType.DefaultUserCollection - }, - groups: null, - [new CollectionAccessSelection - { - Id = request.OrganizationUser!.Id, - Manage = true - }]); + await collectionRepository.CreateDefaultCollectionsAsync( + request.Organization!.Id, + [request.OrganizationUser!.Id], + request.DefaultUserCollectionName); } catch (Exception ex) { diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUsersValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUsersValidator.cs index 3375120516..f067f529ea 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUsersValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUsersValidator.cs @@ -3,7 +3,6 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimed using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; -using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Utilities.v2; using Bit.Core.AdminConsole.Utilities.v2.Validation; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; @@ -20,7 +19,7 @@ public class AutomaticallyConfirmOrganizationUsersValidator( IPolicyRequirementQuery policyRequirementQuery, IAutomaticUserConfirmationPolicyEnforcementValidator automaticUserConfirmationPolicyEnforcementValidator, IUserService userService, - IPolicyRepository policyRepository) : IAutomaticallyConfirmOrganizationUsersValidator + IPolicyQuery policyQuery) : IAutomaticallyConfirmOrganizationUsersValidator { public async Task> ValidateAsync( AutomaticallyConfirmOrganizationUserValidationRequest request) @@ -74,7 +73,7 @@ public class AutomaticallyConfirmOrganizationUsersValidator( } private async Task OrganizationHasAutomaticallyConfirmUsersPolicyEnabledAsync(AutomaticallyConfirmOrganizationUserValidationRequest request) => - await policyRepository.GetByOrganizationIdTypeAsync(request.OrganizationId, PolicyType.AutomaticUserConfirmation) is { Enabled: true } + (await policyQuery.RunAsync(request.OrganizationId, PolicyType.AutomaticUserConfirmation)).Enabled && request.Organization is { UseAutomaticUserConfirmation: true }; private async Task OrganizationUserConformsToTwoFactorRequiredPolicyAsync(AutomaticallyConfirmOrganizationUserValidationRequest request) diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs index 0b82ac7ea4..02f3346ba6 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs @@ -14,7 +14,6 @@ using Bit.Core.Billing.Enums; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; -using Bit.Core.Models.Data; using Bit.Core.Platform.Push; using Bit.Core.Repositories; using Bit.Core.Services; @@ -294,21 +293,10 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand return; } - var defaultCollection = new Collection - { - OrganizationId = organizationUser.OrganizationId, - Name = defaultUserCollectionName, - Type = CollectionType.DefaultUserCollection - }; - var collectionUser = new CollectionAccessSelection - { - Id = organizationUser.Id, - ReadOnly = false, - HidePasswords = false, - Manage = true - }; - - await _collectionRepository.CreateAsync(defaultCollection, groups: null, users: [collectionUser]); + await _collectionRepository.CreateDefaultCollectionsAsync( + organizationUser.OrganizationId, + [organizationUser.Id], + defaultUserCollectionName); } /// @@ -339,7 +327,7 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand return; } - await _collectionRepository.UpsertDefaultCollectionsAsync(organizationId, eligibleOrganizationUserIds, defaultUserCollectionName); + await _collectionRepository.CreateDefaultCollectionsAsync(organizationId, eligibleOrganizationUserIds, defaultUserCollectionName); } /// diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/SendOrganizationInvitesCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/SendOrganizationInvitesCommand.cs index cd5066d11b..61f428414f 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/SendOrganizationInvitesCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/SendOrganizationInvitesCommand.cs @@ -4,7 +4,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models; -using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.Auth.Models.Business; using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Auth.Repositories; @@ -19,7 +19,7 @@ namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUse public class SendOrganizationInvitesCommand( IUserRepository userRepository, ISsoConfigRepository ssoConfigurationRepository, - IPolicyRepository policyRepository, + IPolicyQuery policyQuery, IOrgUserInviteTokenableFactory orgUserInviteTokenableFactory, IDataProtectorTokenFactory dataProtectorTokenFactory, IMailService mailService) : ISendOrganizationInvitesCommand @@ -58,7 +58,7 @@ public class SendOrganizationInvitesCommand( // need to check the policy if the org has SSO enabled. var orgSsoLoginRequiredPolicyEnabled = orgSsoEnabled && organization.UsePolicies && - (await policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.RequireSso))?.Enabled == true; + (await policyQuery.RunAsync(organization.Id, PolicyType.RequireSso)).Enabled; // Generate the list of org users and expiring tokens // create helper function to create expiring tokens diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/IRestoreOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/IRestoreOrganizationUserCommand.cs index e5e5bfb482..82ea0a1c11 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/IRestoreOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/IRestoreOrganizationUserCommand.cs @@ -20,7 +20,7 @@ public interface IRestoreOrganizationUserCommand /// /// Revoked user to be restored. /// UserId of the user performing the action. - Task RestoreUserAsync(OrganizationUser organizationUser, Guid? restoringUserId); + Task RestoreUserAsync(OrganizationUser organizationUser, Guid? restoringUserId, string? defaultCollectionName); /// /// Validates that the requesting user can perform the action. There is also a check done to ensure the organization @@ -50,5 +50,5 @@ public interface IRestoreOrganizationUserCommand /// Passed in from caller to avoid circular dependency /// List of organization user Ids and strings. A successful restoration will have an empty string. /// If an error occurs, the error message will be provided. - Task>> RestoreUsersAsync(Guid organizationId, IEnumerable organizationUserIds, Guid? restoringUserId, IUserService userService); + Task>> RestoreUsersAsync(Guid organizationId, IEnumerable organizationUserIds, Guid? restoringUserId, IUserService userService, string? defaultCollectionName); } diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs index ec42c8b402..dd9c73a21d 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs @@ -31,9 +31,10 @@ public class RestoreOrganizationUserCommand( IOrganizationService organizationService, IFeatureService featureService, IPolicyRequirementQuery policyRequirementQuery, + ICollectionRepository collectionRepository, IAutomaticUserConfirmationPolicyEnforcementValidator automaticUserConfirmationPolicyEnforcementValidator) : IRestoreOrganizationUserCommand { - public async Task RestoreUserAsync(OrganizationUser organizationUser, Guid? restoringUserId) + public async Task RestoreUserAsync(OrganizationUser organizationUser, Guid? restoringUserId, string defaultCollectionName) { if (restoringUserId.HasValue && organizationUser.UserId == restoringUserId.Value) { @@ -46,7 +47,7 @@ public class RestoreOrganizationUserCommand( throw new BadRequestException("Only owners can restore other owners."); } - await RepositoryRestoreUserAsync(organizationUser); + await RepositoryRestoreUserAsync(organizationUser, defaultCollectionName); await eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Restored); if (organizationUser.UserId.HasValue) @@ -57,7 +58,7 @@ public class RestoreOrganizationUserCommand( public async Task RestoreUserAsync(OrganizationUser organizationUser, EventSystemUser systemUser) { - await RepositoryRestoreUserAsync(organizationUser); + await RepositoryRestoreUserAsync(organizationUser, null); // users stored by a system user will not get a default collection at this point. await eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Restored, systemUser); @@ -67,7 +68,7 @@ public class RestoreOrganizationUserCommand( } } - private async Task RepositoryRestoreUserAsync(OrganizationUser organizationUser) + private async Task RepositoryRestoreUserAsync(OrganizationUser organizationUser, string defaultCollectionName) { if (organizationUser.Status != OrganizationUserStatusType.Revoked) { @@ -93,7 +94,7 @@ public class RestoreOrganizationUserCommand( .twoFactorIsEnabled; } - if (organization.PlanType == PlanType.Free) + if (organization.PlanType == PlanType.Free && organizationUser.UserId.HasValue) { await CheckUserForOtherFreeOrganizationOwnershipAsync(organizationUser); } @@ -104,7 +105,16 @@ public class RestoreOrganizationUserCommand( await organizationUserRepository.RestoreAsync(organizationUser.Id, status); - organizationUser.Status = status; + if (organizationUser.UserId.HasValue + && (await policyRequirementQuery.GetAsync(organizationUser.UserId.Value)).State == OrganizationDataOwnershipState.Enabled + && status == OrganizationUserStatusType.Confirmed + && featureService.IsEnabled(FeatureFlagKeys.DefaultUserCollectionRestore) + && !string.IsNullOrWhiteSpace(defaultCollectionName)) + { + await collectionRepository.CreateDefaultCollectionsAsync(organizationUser.OrganizationId, + [organizationUser.Id], + defaultCollectionName); + } } private async Task CheckUserForOtherFreeOrganizationOwnershipAsync(OrganizationUser organizationUser) @@ -156,7 +166,8 @@ public class RestoreOrganizationUserCommand( } public async Task>> RestoreUsersAsync(Guid organizationId, - IEnumerable organizationUserIds, Guid? restoringUserId, IUserService userService) + IEnumerable organizationUserIds, Guid? restoringUserId, IUserService userService, + string defaultCollectionName) { var orgUsers = await organizationUserRepository.GetManyAsync(organizationUserIds); var filteredUsers = orgUsers.Where(u => u.OrganizationId == organizationId) @@ -224,12 +235,14 @@ public class RestoreOrganizationUserCommand( await organizationUserRepository.RestoreAsync(organizationUser.Id, status); organizationUser.Status = status; - await eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Restored); + if (organizationUser.UserId.HasValue) { await pushNotificationService.PushSyncOrgKeysAsync(organizationUser.UserId.Value); } + await eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Restored); + result.Add(Tuple.Create(organizationUser, "")); } catch (BadRequestException e) @@ -238,9 +251,37 @@ public class RestoreOrganizationUserCommand( } } + if (featureService.IsEnabled(FeatureFlagKeys.DefaultUserCollectionRestore)) + { + await CreateDefaultCollectionsForConfirmedUsersAsync(organizationId, defaultCollectionName, + result.Where(r => r.Item2 == "").Select(x => x.Item1).ToList()); + } + return result; } + private async Task CreateDefaultCollectionsForConfirmedUsersAsync(Guid organizationId, string defaultCollectionName, + ICollection restoredUsers) + { + if (!string.IsNullOrWhiteSpace(defaultCollectionName)) + { + var organizationUsersDataOwnershipEnabled = (await policyRequirementQuery + .GetManyByOrganizationIdAsync(organizationId)) + .ToList(); + + var usersToCreateDefaultCollectionsFor = restoredUsers.Where(x => + organizationUsersDataOwnershipEnabled.Contains(x.Id) + && x.Status == OrganizationUserStatusType.Confirmed).ToList(); + + if (usersToCreateDefaultCollectionsFor.Count != 0) + { + await collectionRepository.CreateDefaultCollectionsAsync(organizationId, + usersToCreateDefaultCollectionsFor.Select(x => x.Id), + defaultCollectionName); + } + } + } + private async Task CheckPoliciesBeforeRestoreAsync(OrganizationUser orgUser, bool userHasTwoFactorEnabled) { // An invited OrganizationUser isn't linked with a user account yet, so these checks are irrelevant diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyQuery.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyQuery.cs new file mode 100644 index 0000000000..02eeeaa847 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/IPolicyQuery.cs @@ -0,0 +1,17 @@ +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies; + +public interface IPolicyQuery +{ + /// + /// Retrieves a summary view of an organization's usage of a policy specified by the . + /// + /// + /// This query is the entrypoint for consumers interested in understanding how a particular + /// has been applied to an organization; the resultant is not indicative of explicit + /// policy configuration. + /// + Task RunAsync(Guid organizationId, PolicyType policyType); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyQuery.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyQuery.cs new file mode 100644 index 0000000000..0ee6f9ab06 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/PolicyQuery.cs @@ -0,0 +1,14 @@ +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.Repositories; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations; + +public class PolicyQuery(IPolicyRepository policyRepository) : IPolicyQuery +{ + public async Task RunAsync(Guid organizationId, PolicyType policyType) + { + var dbPolicy = await policyRepository.GetByOrganizationIdTypeAsync(organizationId, policyType); + return new PolicyStatus(organizationId, policyType, dbPolicy); + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/AutomaticUserConfirmationPolicyRequirement.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/AutomaticUserConfirmationPolicyRequirement.cs index 3430f33a77..9b6cf86257 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/AutomaticUserConfirmationPolicyRequirement.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/AutomaticUserConfirmationPolicyRequirement.cs @@ -19,7 +19,7 @@ namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements /// Collection of policy details that apply to this user id public class AutomaticUserConfirmationPolicyRequirement(IEnumerable policyDetails) : IPolicyRequirement { - public bool CannotBeGrantedEmergencyAccess() => policyDetails.Any(); + public bool CannotHaveEmergencyAccess() => policyDetails.Any(); public bool CannotJoinProvider() => policyDetails.Any(); diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs index f69935715d..6e0c3aa8d9 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyServiceCollectionExtensions.cs @@ -18,6 +18,7 @@ public static class PolicyServiceCollectionExtensions services.AddScoped(); services.AddScoped(); services.AddScoped(); + services.AddScoped(); services.AddScoped(); services.AddPolicyValidators(); diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidator.cs index 7a47baa65a..104a5751ff 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidator.cs @@ -57,14 +57,15 @@ public class OrganizationDataOwnershipPolicyValidator( var userOrgIds = requirements .Select(requirement => requirement.GetDefaultCollectionRequestOnPolicyEnable(policyUpdate.OrganizationId)) .Where(request => request.ShouldCreateDefaultCollection) - .Select(request => request.OrganizationUserId); + .Select(request => request.OrganizationUserId) + .ToList(); if (!userOrgIds.Any()) { return; } - await collectionRepository.UpsertDefaultCollectionsAsync( + await collectionRepository.CreateDefaultCollectionsBulkAsync( policyUpdate.OrganizationId, userOrgIds, defaultCollectionName); diff --git a/src/Core/AdminConsole/Repositories/IOrganizationRepository.cs b/src/Core/AdminConsole/Repositories/IOrganizationRepository.cs index da7a77000b..d79923fdd1 100644 --- a/src/Core/AdminConsole/Repositories/IOrganizationRepository.cs +++ b/src/Core/AdminConsole/Repositories/IOrganizationRepository.cs @@ -21,7 +21,9 @@ public interface IOrganizationRepository : IRepository Task> GetOwnerEmailAddressesById(Guid organizationId); /// - /// Gets the organizations that have a verified domain matching the user's email domain. + /// Gets the organizations that have claimed the user's account. Currently, only one organization may claim a user. + /// This requires that the organization has claimed the user's domain and the user is an organization member. + /// It excludes invited members. /// Task> GetByVerifiedUserEmailDomainAsync(Guid userId); diff --git a/src/Core/AdminConsole/Repositories/IOrganizationUserRepository.cs b/src/Core/AdminConsole/Repositories/IOrganizationUserRepository.cs index 41622c24b7..583e86ab4d 100644 --- a/src/Core/AdminConsole/Repositories/IOrganizationUserRepository.cs +++ b/src/Core/AdminConsole/Repositories/IOrganizationUserRepository.cs @@ -28,21 +28,21 @@ public interface IOrganizationUserRepository : IRepository /// The id of the OrganizationUser /// A tuple containing the OrganizationUser and its associated collections - Task<(OrganizationUserUserDetails? OrganizationUser, ICollection Collections)> GetDetailsByIdWithCollectionsAsync(Guid id); + Task<(OrganizationUserUserDetails? OrganizationUser, ICollection Collections)> GetDetailsByIdWithSharedCollectionsAsync(Guid id); /// /// Returns the OrganizationUsers and their associated collections (excluding DefaultUserCollections). /// /// The id of the organization /// Whether to include groups - /// Whether to include collections + /// Whether to include shared collections /// A list of OrganizationUserUserDetails - Task> GetManyDetailsByOrganizationAsync(Guid organizationId, bool includeGroups = false, bool includeCollections = false); + Task> GetManyDetailsByOrganizationAsync(Guid organizationId, bool includeGroups = false, bool includeSharedCollections = false); /// /// /// This method is optimized for performance. /// Reduces database round trips by fetching all data in fewer queries. /// - Task> GetManyDetailsByOrganizationAsync_vNext(Guid organizationId, bool includeGroups = false, bool includeCollections = false); + Task> GetManyDetailsByOrganizationAsync_vNext(Guid organizationId, bool includeGroups = false, bool includeSharedCollections = false); Task> GetManyDetailsByUserAsync(Guid userId, OrganizationUserStatusType? status = null); Task GetDetailsByUserAsync(Guid userId, Guid organizationId, diff --git a/src/Core/AdminConsole/Services/IOrganizationService.cs b/src/Core/AdminConsole/Services/IOrganizationService.cs index f509ac8358..af0d327a68 100644 --- a/src/Core/AdminConsole/Services/IOrganizationService.cs +++ b/src/Core/AdminConsole/Services/IOrganizationService.cs @@ -27,7 +27,6 @@ public interface IOrganizationService OrganizationUserInvite invite, string externalId); Task> InviteUsersAsync(Guid organizationId, Guid? invitingUserId, EventSystemUser? systemUser, IEnumerable<(OrganizationUserInvite invite, string externalId)> invites); - Task>> ResendInvitesAsync(Guid organizationId, Guid? invitingUserId, IEnumerable organizationUsersId); Task UpdateUserResetPasswordEnrollmentAsync(Guid organizationId, Guid userId, string resetPasswordKey, Guid? callingUserId); Task DeleteSsoUserAsync(Guid userId, Guid? organizationId); Task ReplaceAndUpdateCacheAsync(Organization org, EventType? orgEvent = null); diff --git a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs index e1fcbb970d..d87bc65042 100644 --- a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs +++ b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs @@ -14,7 +14,6 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; -using Bit.Core.AdminConsole.Utilities.DebuggingInstruments; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Repositories; using Bit.Core.Billing.Constants; @@ -49,7 +48,7 @@ public class OrganizationService : IOrganizationService private readonly IEventService _eventService; private readonly IApplicationCacheService _applicationCacheService; private readonly IStripePaymentService _paymentService; - private readonly IPolicyRepository _policyRepository; + private readonly IPolicyQuery _policyQuery; private readonly IPolicyService _policyService; private readonly ISsoUserRepository _ssoUserRepository; private readonly IGlobalSettings _globalSettings; @@ -76,7 +75,7 @@ public class OrganizationService : IOrganizationService IEventService eventService, IApplicationCacheService applicationCacheService, IStripePaymentService paymentService, - IPolicyRepository policyRepository, + IPolicyQuery policyQuery, IPolicyService policyService, ISsoUserRepository ssoUserRepository, IGlobalSettings globalSettings, @@ -103,7 +102,7 @@ public class OrganizationService : IOrganizationService _eventService = eventService; _applicationCacheService = applicationCacheService; _paymentService = paymentService; - _policyRepository = policyRepository; + _policyQuery = policyQuery; _policyService = policyService; _ssoUserRepository = ssoUserRepository; _globalSettings = globalSettings; @@ -718,32 +717,6 @@ public class OrganizationService : IOrganizationService return (allOrgUsers, events); } - public async Task>> ResendInvitesAsync(Guid organizationId, - Guid? invitingUserId, - IEnumerable organizationUsersId) - { - var orgUsers = await _organizationUserRepository.GetManyAsync(organizationUsersId); - _logger.LogUserInviteStateDiagnostics(orgUsers); - - var org = await GetOrgById(organizationId); - - var result = new List>(); - foreach (var orgUser in orgUsers) - { - if (orgUser.Status != OrganizationUserStatusType.Invited || orgUser.OrganizationId != organizationId) - { - result.Add(Tuple.Create(orgUser, "User invalid.")); - continue; - } - - await SendInviteAsync(orgUser, org, false); - result.Add(Tuple.Create(orgUser, "")); - } - - return result; - } - - private async Task SendInvitesAsync(IEnumerable orgUsers, Organization organization) => await _sendOrganizationInvitesCommand.SendInvitesAsync(new SendInvitesRequest(orgUsers, organization)); @@ -862,9 +835,8 @@ public class OrganizationService : IOrganizationService } // Make sure the organization has the policy enabled - var resetPasswordPolicy = - await _policyRepository.GetByOrganizationIdTypeAsync(organizationId, PolicyType.ResetPassword); - if (resetPasswordPolicy == null || !resetPasswordPolicy.Enabled) + var resetPasswordPolicy = await _policyQuery.RunAsync(organizationId, PolicyType.ResetPassword); + if (!resetPasswordPolicy.Enabled) { throw new BadRequestException("Organization does not have the password reset policy enabled."); } diff --git a/src/Core/AdminConsole/Utilities/PolicyDataValidator.cs b/src/Core/AdminConsole/Utilities/PolicyDataValidator.cs index 84e63f2a20..d533ca88cf 100644 --- a/src/Core/AdminConsole/Utilities/PolicyDataValidator.cs +++ b/src/Core/AdminConsole/Utilities/PolicyDataValidator.cs @@ -1,4 +1,5 @@ -using System.Text.Json; +using System.ComponentModel.DataAnnotations; +using System.Text.Json; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; @@ -30,7 +31,8 @@ public static class PolicyDataValidator switch (policyType) { case PolicyType.MasterPassword: - CoreHelpers.LoadClassFromJsonData(json); + var masterPasswordData = CoreHelpers.LoadClassFromJsonData(json); + ValidateModel(masterPasswordData, policyType); break; case PolicyType.SendOptions: CoreHelpers.LoadClassFromJsonData(json); @@ -44,11 +46,24 @@ public static class PolicyDataValidator } catch (JsonException ex) { - var fieldInfo = !string.IsNullOrEmpty(ex.Path) ? $": field '{ex.Path}' has invalid type" : ""; + var fieldName = !string.IsNullOrEmpty(ex.Path) ? ex.Path.TrimStart('$', '.') : null; + var fieldInfo = !string.IsNullOrEmpty(fieldName) ? $": {fieldName} has an invalid value" : ""; throw new BadRequestException($"Invalid data for {policyType} policy{fieldInfo}."); } } + private static void ValidateModel(object model, PolicyType policyType) + { + var validationContext = new ValidationContext(model); + var validationResults = new List(); + + if (!Validator.TryValidateObject(model, validationContext, validationResults, true)) + { + var errors = string.Join(", ", validationResults.Select(r => r.ErrorMessage)); + throw new BadRequestException($"Invalid data for {policyType} policy: {errors}"); + } + } + /// /// Validates and deserializes policy metadata based on the policy type. /// diff --git a/src/Core/Auth/Entities/EmergencyAccess.cs b/src/Core/Auth/Entities/EmergencyAccess.cs index 36aaf46a8c..df66541d2a 100644 --- a/src/Core/Auth/Entities/EmergencyAccess.cs +++ b/src/Core/Auth/Entities/EmergencyAccess.cs @@ -1,7 +1,4 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations; using Bit.Core.Auth.Enums; using Bit.Core.Entities; using Bit.Core.Utilities; @@ -14,8 +11,8 @@ public class EmergencyAccess : ITableObject public Guid GrantorId { get; set; } public Guid? GranteeId { get; set; } [MaxLength(256)] - public string Email { get; set; } - public string KeyEncrypted { get; set; } + public string? Email { get; set; } + public string? KeyEncrypted { get; set; } public EmergencyAccessType Type { get; set; } public EmergencyAccessStatusType Status { get; set; } public short WaitTimeDays { get; set; } diff --git a/src/Core/Auth/Enums/EmergencyAccessStatusType.cs b/src/Core/Auth/Enums/EmergencyAccessStatusType.cs index d817d6a950..45bd5fdee4 100644 --- a/src/Core/Auth/Enums/EmergencyAccessStatusType.cs +++ b/src/Core/Auth/Enums/EmergencyAccessStatusType.cs @@ -19,7 +19,7 @@ public enum EmergencyAccessStatusType : byte /// RecoveryInitiated = 3, /// - /// The grantee has excercised their emergency access. + /// The grantee has exercised their emergency access. /// RecoveryApproved = 4, } diff --git a/src/Core/Auth/Identity/TokenProviders/WebAuthnTokenProvider.cs b/src/Core/Auth/Identity/TokenProviders/WebAuthnTokenProvider.cs index 60fb2c5635..a6b1a27713 100644 --- a/src/Core/Auth/Identity/TokenProviders/WebAuthnTokenProvider.cs +++ b/src/Core/Auth/Identity/TokenProviders/WebAuthnTokenProvider.cs @@ -147,16 +147,12 @@ public class WebAuthnTokenProvider : IUserTwoFactorTokenProvider return keys; } - // Support up to 5 keys - for (var i = 1; i <= 5; i++) + // Load all WebAuthn credentials stored in metadata. The number of allowed credentials + // is controlled by credential registration. + foreach (var kvp in provider.MetaData.Where(k => k.Key.StartsWith("Key"))) { - var keyName = $"Key{i}"; - if (provider.MetaData.TryGetValue(keyName, out var value)) - { - var key = new TwoFactorProvider.WebAuthnData((dynamic)value); - - keys.Add(new Tuple(keyName, key)); - } + var key = new TwoFactorProvider.WebAuthnData((dynamic)kvp.Value); + keys.Add(new Tuple(kvp.Key, key)); } return keys; diff --git a/src/Core/Auth/Models/Api/Request/Accounts/RegisterFinishRequestModel.cs b/src/Core/Auth/Models/Api/Request/Accounts/RegisterFinishRequestModel.cs index 0ac7dbbcb4..cb66540a6b 100644 --- a/src/Core/Auth/Models/Api/Request/Accounts/RegisterFinishRequestModel.cs +++ b/src/Core/Auth/Models/Api/Request/Accounts/RegisterFinishRequestModel.cs @@ -1,6 +1,6 @@ -#nullable enable -using Bit.Core.Entities; +using Bit.Core.Entities; using Bit.Core.Enums; +using Bit.Core.KeyManagement.Models.Api.Request; using Bit.Core.Utilities; namespace Bit.Core.Auth.Models.Api.Request.Accounts; @@ -21,19 +21,32 @@ public class RegisterFinishRequestModel : IValidatableObject public required string Email { get; set; } public string? EmailVerificationToken { get; set; } + public MasterPasswordAuthenticationDataRequestModel? MasterPasswordAuthentication { get; set; } + public MasterPasswordUnlockDataRequestModel? MasterPasswordUnlock { get; set; } + + // PM-28143 - Remove property below (made optional during migration to MasterPasswordUnlockData) [StringLength(1000)] - public required string MasterPasswordHash { get; set; } + // Made optional but there will still be a thrown error if it does not exist either here or + // in the MasterPasswordAuthenticationData. + public string? MasterPasswordHash { get; set; } [StringLength(50)] public string? MasterPasswordHint { get; set; } - public required string UserSymmetricKey { get; set; } + // PM-28143 - Remove property below (made optional during migration to MasterPasswordUnlockData) + // Made optional but there will still be a thrown error if it does not exist either here or + // in the MasterPasswordAuthenticationData. + public string? UserSymmetricKey { get; set; } public required KeysRequestModel UserAsymmetricKeys { get; set; } - public required KdfType Kdf { get; set; } - public required int KdfIterations { get; set; } + // PM-28143 - Remove line below (made optional during migration to MasterPasswordUnlockData) + public KdfType? Kdf { get; set; } + // PM-28143 - Remove line below (made optional during migration to MasterPasswordUnlockData) + public int? KdfIterations { get; set; } + // PM-28143 - Remove line below public int? KdfMemory { get; set; } + // PM-28143 - Remove line below public int? KdfParallelism { get; set; } public Guid? OrganizationUserId { get; set; } @@ -54,11 +67,14 @@ public class RegisterFinishRequestModel : IValidatableObject { Email = Email, MasterPasswordHint = MasterPasswordHint, - Kdf = Kdf, - KdfIterations = KdfIterations, - KdfMemory = KdfMemory, - KdfParallelism = KdfParallelism, - Key = UserSymmetricKey, + Kdf = (KdfType)(MasterPasswordUnlock?.Kdf.KdfType ?? Kdf)!, + KdfIterations = (int)(MasterPasswordUnlock?.Kdf.Iterations ?? KdfIterations)!, + // KdfMemory and KdfParallelism are optional (only used for Argon2id) + KdfMemory = MasterPasswordUnlock?.Kdf.Memory ?? KdfMemory, + KdfParallelism = MasterPasswordUnlock?.Kdf.Parallelism ?? KdfParallelism, + // PM-28827 To be added when MasterPasswordSalt is added to the user column + // MasterPasswordSalt = MasterPasswordUnlock?.Salt ?? Email.ToLower().Trim(), + Key = MasterPasswordUnlock?.MasterKeyWrappedUserKey ?? UserSymmetricKey }; UserAsymmetricKeys.ToUser(user); @@ -72,7 +88,9 @@ public class RegisterFinishRequestModel : IValidatableObject { return RegisterFinishTokenType.EmailVerification; } - if (!string.IsNullOrEmpty(OrgInviteToken) && OrganizationUserId.HasValue) + if (!string.IsNullOrEmpty(OrgInviteToken) + && OrganizationUserId.HasValue + && OrganizationUserId.Value != Guid.Empty) { return RegisterFinishTokenType.OrganizationInvite; } @@ -80,11 +98,15 @@ public class RegisterFinishRequestModel : IValidatableObject { return RegisterFinishTokenType.OrgSponsoredFreeFamilyPlan; } - if (!string.IsNullOrWhiteSpace(AcceptEmergencyAccessInviteToken) && AcceptEmergencyAccessId.HasValue) + if (!string.IsNullOrWhiteSpace(AcceptEmergencyAccessInviteToken) + && AcceptEmergencyAccessId.HasValue + && AcceptEmergencyAccessId.Value != Guid.Empty) { return RegisterFinishTokenType.EmergencyAccessInvite; } - if (!string.IsNullOrWhiteSpace(ProviderInviteToken) && ProviderUserId.HasValue) + if (!string.IsNullOrWhiteSpace(ProviderInviteToken) + && ProviderUserId.HasValue + && ProviderUserId.Value != Guid.Empty) { return RegisterFinishTokenType.ProviderInvite; } @@ -92,9 +114,156 @@ public class RegisterFinishRequestModel : IValidatableObject throw new InvalidOperationException("Invalid token type."); } - public IEnumerable Validate(ValidationContext validationContext) { - return KdfSettingsValidator.Validate(Kdf, KdfIterations, KdfMemory, KdfParallelism); + // 1. Authentication data containing hash and hash at root level check + if (MasterPasswordAuthentication != null && MasterPasswordHash != null) + { + if (MasterPasswordAuthentication.MasterPasswordAuthenticationHash != MasterPasswordHash) + { + yield return new ValidationResult( + $"{nameof(MasterPasswordAuthentication.MasterPasswordAuthenticationHash)} and root level {nameof(MasterPasswordHash)} provided and are not equal. Only provide one.", + [nameof(MasterPasswordAuthentication.MasterPasswordAuthenticationHash), nameof(MasterPasswordHash)]); + } + } // 1.5 if there is no master password hash that is unacceptable even though they are both optional in the model + else if (MasterPasswordAuthentication == null && MasterPasswordHash == null) + { + yield return new ValidationResult( + $"{nameof(MasterPasswordAuthentication.MasterPasswordAuthenticationHash)} and {nameof(MasterPasswordHash)} not found on request, one needs to be defined.", + [nameof(MasterPasswordAuthentication.MasterPasswordAuthenticationHash), nameof(MasterPasswordHash)]); + } + + // 2. Validate kdf settings. + if (MasterPasswordUnlock != null) + { + foreach (var validationResult in KdfSettingsValidator.Validate(MasterPasswordUnlock.ToData().Kdf)) + { + yield return validationResult; + } + } + + if (MasterPasswordAuthentication != null) + { + foreach (var validationResult in KdfSettingsValidator.Validate(MasterPasswordAuthentication.ToData().Kdf)) + { + yield return validationResult; + } + } + + // 3. Validate root kdf values if kdf values are not in the unlock and authentication. + if (MasterPasswordUnlock == null && MasterPasswordAuthentication == null) + { + var hasMissingRequiredKdfInputs = false; + if (Kdf == null) + { + yield return new ValidationResult($"{nameof(Kdf)} not found on RequestModel", [nameof(Kdf)]); + hasMissingRequiredKdfInputs = true; + } + if (KdfIterations == null) + { + yield return new ValidationResult($"{nameof(KdfIterations)} not found on RequestModel", [nameof(KdfIterations)]); + hasMissingRequiredKdfInputs = true; + } + + if (!hasMissingRequiredKdfInputs) + { + foreach (var validationResult in KdfSettingsValidator.Validate( + Kdf!.Value, + KdfIterations!.Value, + KdfMemory, + KdfParallelism)) + { + yield return validationResult; + } + } + } + else if (MasterPasswordUnlock == null && MasterPasswordAuthentication != null) + { + // Authentication provided but Unlock missing + yield return new ValidationResult($"{nameof(MasterPasswordUnlock)} not found on RequestModel", [nameof(MasterPasswordUnlock)]); + } + else if (MasterPasswordUnlock != null && MasterPasswordAuthentication == null) + { + // Unlock provided but Authentication missing + yield return new ValidationResult($"{nameof(MasterPasswordAuthentication)} not found on RequestModel", [nameof(MasterPasswordAuthentication)]); + } + + // 3. Lastly, validate access token type and presence. Must be done last because of yield break. + RegisterFinishTokenType tokenType; + var tokenTypeResolved = true; + try + { + tokenType = GetTokenType(); + } + catch (InvalidOperationException) + { + tokenTypeResolved = false; + tokenType = default; + } + + if (!tokenTypeResolved) + { + yield return new ValidationResult("No valid registration token provided"); + yield break; + } + + switch (tokenType) + { + case RegisterFinishTokenType.EmailVerification: + if (string.IsNullOrEmpty(EmailVerificationToken)) + { + yield return new ValidationResult( + $"{nameof(EmailVerificationToken)} absent when processing register/finish.", + [nameof(EmailVerificationToken)]); + } + break; + case RegisterFinishTokenType.OrganizationInvite: + if (string.IsNullOrEmpty(OrgInviteToken)) + { + yield return new ValidationResult( + $"{nameof(OrgInviteToken)} absent when processing register/finish.", + [nameof(OrgInviteToken)]); + } + break; + case RegisterFinishTokenType.OrgSponsoredFreeFamilyPlan: + if (string.IsNullOrEmpty(OrgSponsoredFreeFamilyPlanToken)) + { + yield return new ValidationResult( + $"{nameof(OrgSponsoredFreeFamilyPlanToken)} absent when processing register/finish.", + [nameof(OrgSponsoredFreeFamilyPlanToken)]); + } + break; + case RegisterFinishTokenType.EmergencyAccessInvite: + if (string.IsNullOrEmpty(AcceptEmergencyAccessInviteToken)) + { + yield return new ValidationResult( + $"{nameof(AcceptEmergencyAccessInviteToken)} absent when processing register/finish.", + [nameof(AcceptEmergencyAccessInviteToken)]); + } + if (!AcceptEmergencyAccessId.HasValue || AcceptEmergencyAccessId.Value == Guid.Empty) + { + yield return new ValidationResult( + $"{nameof(AcceptEmergencyAccessId)} absent when processing register/finish.", + [nameof(AcceptEmergencyAccessId)]); + } + break; + case RegisterFinishTokenType.ProviderInvite: + if (string.IsNullOrEmpty(ProviderInviteToken)) + { + yield return new ValidationResult( + $"{nameof(ProviderInviteToken)} absent when processing register/finish.", + [nameof(ProviderInviteToken)]); + } + if (!ProviderUserId.HasValue || ProviderUserId.Value == Guid.Empty) + { + yield return new ValidationResult( + $"{nameof(ProviderUserId)} absent when processing register/finish.", + [nameof(ProviderUserId)]); + } + break; + default: + yield return new ValidationResult("Invalid registration finish request"); + break; + } } } diff --git a/src/Core/Auth/Models/Api/Response/UserDecryptionOptions.cs b/src/Core/Auth/Models/Api/Response/UserDecryptionOptions.cs index aa8a298200..bc22ab1266 100644 --- a/src/Core/Auth/Models/Api/Response/UserDecryptionOptions.cs +++ b/src/Core/Auth/Models/Api/Response/UserDecryptionOptions.cs @@ -45,13 +45,19 @@ public class WebAuthnPrfDecryptionOption { public string EncryptedPrivateKey { get; } public string EncryptedUserKey { get; } + public string CredentialId { get; } + public string[] Transports { get; } public WebAuthnPrfDecryptionOption( string encryptedPrivateKey, - string encryptedUserKey) + string encryptedUserKey, + string credentialId, + string[]? transports = null) { EncryptedPrivateKey = encryptedPrivateKey; EncryptedUserKey = encryptedUserKey; + CredentialId = credentialId; + Transports = transports ?? []; } } diff --git a/src/Core/Auth/Models/Data/EmergencyAccessDetails.cs b/src/Core/Auth/Models/Data/EmergencyAccessDetails.cs index 03661c7276..86c1e6953f 100644 --- a/src/Core/Auth/Models/Data/EmergencyAccessDetails.cs +++ b/src/Core/Auth/Models/Data/EmergencyAccessDetails.cs @@ -1,16 +1,16 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Bit.Core.Auth.Entities; +using Bit.Core.Auth.Entities; namespace Bit.Core.Auth.Models.Data; public class EmergencyAccessDetails : EmergencyAccess { - public string GranteeName { get; set; } - public string GranteeEmail { get; set; } - public string GranteeAvatarColor { get; set; } - public string GrantorName { get; set; } - public string GrantorEmail { get; set; } - public string GrantorAvatarColor { get; set; } + public string? GranteeName { get; set; } + public string? GranteeEmail { get; set; } + public string? GranteeAvatarColor { get; set; } + public string? GrantorName { get; set; } + /// + /// Grantor email is assumed not null because in order to create an emergency access the grantor must be an existing user. + /// + public required string GrantorEmail { get; set; } + public string? GrantorAvatarColor { get; set; } } diff --git a/src/Core/Auth/Models/Data/EmergencyAccessNotify.cs b/src/Core/Auth/Models/Data/EmergencyAccessNotify.cs index 1c0d4bfe8b..492d565717 100644 --- a/src/Core/Auth/Models/Data/EmergencyAccessNotify.cs +++ b/src/Core/Auth/Models/Data/EmergencyAccessNotify.cs @@ -1,14 +1,10 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - - -using Bit.Core.Auth.Entities; +using Bit.Core.Auth.Entities; namespace Bit.Core.Auth.Models.Data; public class EmergencyAccessNotify : EmergencyAccess { - public string GrantorEmail { get; set; } - public string GranteeName { get; set; } - public string GranteeEmail { get; set; } + public string? GrantorEmail { get; set; } + public string? GranteeName { get; set; } + public string? GranteeEmail { get; set; } } diff --git a/src/Core/Auth/Repositories/IEmergencyAccessRepository.cs b/src/Core/Auth/Repositories/IEmergencyAccessRepository.cs index 63ec04106e..5b4ad47180 100644 --- a/src/Core/Auth/Repositories/IEmergencyAccessRepository.cs +++ b/src/Core/Auth/Repositories/IEmergencyAccessRepository.cs @@ -2,8 +2,6 @@ using Bit.Core.Auth.Models.Data; using Bit.Core.KeyManagement.UserKey; -#nullable enable - namespace Bit.Core.Repositories; public interface IEmergencyAccessRepository : IRepository @@ -11,7 +9,17 @@ public interface IEmergencyAccessRepository : IRepository Task GetCountByGrantorIdEmailAsync(Guid grantorId, string email, bool onlyRegisteredUsers); Task> GetManyDetailsByGrantorIdAsync(Guid grantorId); Task> GetManyDetailsByGranteeIdAsync(Guid granteeId); + /// + /// Fetches emergency access details by EmergencyAccess id and grantor id + /// + /// Emergency Access Id + /// Grantor Id + /// EmergencyAccessDetails or null Task GetDetailsByIdGrantorIdAsync(Guid id, Guid grantorId); + /// + /// Database call to fetch emergency accesses that need notification emails sent through a Job + /// + /// collection of EmergencyAccessNotify objects that require notification Task> GetManyToNotifyAsync(); Task> GetExpiredRecoveriesAsync(); @@ -22,4 +30,11 @@ public interface IEmergencyAccessRepository : IRepository /// A list of emergency access with updated keys UpdateEncryptedDataForKeyRotation UpdateForKeyRotation(Guid grantorId, IEnumerable emergencyAccessKeys); + + /// + /// Deletes multiple emergency access records by their IDs + /// + /// Ids of records to be deleted + /// void + Task DeleteManyAsync(ICollection emergencyAccessIds); } diff --git a/src/Core/Auth/Services/Implementations/SsoConfigService.cs b/src/Core/Auth/Services/Implementations/SsoConfigService.cs index 0cb8b68042..3c4f1ef85d 100644 --- a/src/Core/Auth/Services/Implementations/SsoConfigService.cs +++ b/src/Core/Auth/Services/Implementations/SsoConfigService.cs @@ -5,9 +5,9 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; -using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.Entities; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Repositories; @@ -21,7 +21,7 @@ namespace Bit.Core.Auth.Services; public class SsoConfigService : ISsoConfigService { private readonly ISsoConfigRepository _ssoConfigRepository; - private readonly IPolicyRepository _policyRepository; + private readonly IPolicyQuery _policyQuery; private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IEventService _eventService; @@ -29,14 +29,14 @@ public class SsoConfigService : ISsoConfigService public SsoConfigService( ISsoConfigRepository ssoConfigRepository, - IPolicyRepository policyRepository, + IPolicyQuery policyQuery, IOrganizationRepository organizationRepository, IOrganizationUserRepository organizationUserRepository, IEventService eventService, IVNextSavePolicyCommand vNextSavePolicyCommand) { _ssoConfigRepository = ssoConfigRepository; - _policyRepository = policyRepository; + _policyQuery = policyQuery; _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; _eventService = eventService; @@ -114,14 +114,14 @@ public class SsoConfigService : ISsoConfigService throw new BadRequestException("Organization cannot use Key Connector."); } - var singleOrgPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(config.OrganizationId, PolicyType.SingleOrg); - if (singleOrgPolicy is not { Enabled: true }) + var singleOrgPolicy = await _policyQuery.RunAsync(config.OrganizationId, PolicyType.SingleOrg); + if (!singleOrgPolicy.Enabled) { throw new BadRequestException("Key Connector requires the Single Organization policy to be enabled."); } - var ssoPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(config.OrganizationId, PolicyType.RequireSso); - if (ssoPolicy is not { Enabled: true }) + var ssoPolicy = await _policyQuery.RunAsync(config.OrganizationId, PolicyType.RequireSso); + if (!ssoPolicy.Enabled) { throw new BadRequestException("Key Connector requires the Single Sign-On Authentication policy to be enabled."); } diff --git a/src/Core/Auth/UserFeatures/EmergencyAccess/Commands/DeleteEmergencyAccessCommand.cs b/src/Core/Auth/UserFeatures/EmergencyAccess/Commands/DeleteEmergencyAccessCommand.cs new file mode 100644 index 0000000000..40779b266a --- /dev/null +++ b/src/Core/Auth/UserFeatures/EmergencyAccess/Commands/DeleteEmergencyAccessCommand.cs @@ -0,0 +1,107 @@ +using Bit.Core.Auth.Models.Data; +using Bit.Core.Auth.UserFeatures.EmergencyAccess.Interfaces; +using Bit.Core.Auth.UserFeatures.EmergencyAccess.Mail; +using Bit.Core.Exceptions; +using Bit.Core.Platform.Mail.Mailer; +using Bit.Core.Repositories; + +namespace Bit.Core.Auth.UserFeatures.EmergencyAccess.Commands; + +public class DeleteEmergencyAccessCommand( + IEmergencyAccessRepository _emergencyAccessRepository, + IMailer mailer) : IDeleteEmergencyAccessCommand +{ + /// + public async Task DeleteByIdGrantorIdAsync(Guid emergencyAccessId, Guid grantorId) + { + var emergencyAccessDetails = await _emergencyAccessRepository.GetDetailsByIdGrantorIdAsync(emergencyAccessId, grantorId); + + if (emergencyAccessDetails == null || emergencyAccessDetails.GrantorId != grantorId) + { + throw new BadRequestException("Emergency Access not valid."); + } + + var (grantorEmails, granteeEmails) = await DeleteEmergencyAccessAsync([emergencyAccessDetails]); + + // Send notification email to grantor + await SendEmergencyAccessRemoveGranteesEmailAsync(grantorEmails, granteeEmails); + return emergencyAccessDetails; + } + + /// + public async Task?> DeleteAllByGrantorIdAsync(Guid grantorId) + { + var emergencyAccessDetails = await _emergencyAccessRepository.GetManyDetailsByGrantorIdAsync(grantorId); + + // if there is nothing return an empty array and do not send an email + if (emergencyAccessDetails.Count == 0) + { + return emergencyAccessDetails; + } + + var (grantorEmails, granteeEmails) = await DeleteEmergencyAccessAsync(emergencyAccessDetails); + + // Send notification email to grantor + await SendEmergencyAccessRemoveGranteesEmailAsync(grantorEmails, granteeEmails); + + return emergencyAccessDetails; + } + + /// + public async Task?> DeleteAllByGranteeIdAsync(Guid granteeId) + { + var emergencyAccessDetails = await _emergencyAccessRepository.GetManyDetailsByGranteeIdAsync(granteeId); + + // if there is nothing return an empty array + if (emergencyAccessDetails == null || emergencyAccessDetails.Count == 0) + { + return emergencyAccessDetails; + } + + var (grantorEmails, granteeEmails) = await DeleteEmergencyAccessAsync(emergencyAccessDetails); + + // Send notification email to grantor(s) + await SendEmergencyAccessRemoveGranteesEmailAsync(grantorEmails, granteeEmails); + + return emergencyAccessDetails; + } + + private async Task<(HashSet grantorEmails, HashSet granteeEmails)> DeleteEmergencyAccessAsync(IEnumerable emergencyAccessDetails) + { + var grantorEmails = new HashSet(); + var granteeEmails = new HashSet(); + + await _emergencyAccessRepository.DeleteManyAsync([.. emergencyAccessDetails.Select(ea => ea.Id)]); + + foreach (var details in emergencyAccessDetails) + { + granteeEmails.Add(details.GranteeEmail ?? string.Empty); + grantorEmails.Add(details.GrantorEmail); + } + + return (grantorEmails, granteeEmails); + } + + /// + /// Sends an email notification to the grantor about removed grantees. + /// + /// The email addresses of the grantors to notify when deleting by grantee + /// The formatted identifiers of the removed grantees to include in the email + /// + private async Task SendEmergencyAccessRemoveGranteesEmailAsync(IEnumerable grantorEmails, IEnumerable formattedGranteeIdentifiers) + { + foreach (var email in grantorEmails) + { + var emailViewModel = new EmergencyAccessRemoveGranteesMail + { + ToEmails = [email], + View = new EmergencyAccessRemoveGranteesMailView + { + RemovedGranteeEmails = formattedGranteeIdentifiers + } + }; + + await mailer.SendEmail(emailViewModel); + } + } +} diff --git a/src/Core/Auth/Services/EmergencyAccess/EmergencyAccessService.cs b/src/Core/Auth/UserFeatures/EmergencyAccess/EmergencyAccessService.cs similarity index 95% rename from src/Core/Auth/Services/EmergencyAccess/EmergencyAccessService.cs rename to src/Core/Auth/UserFeatures/EmergencyAccess/EmergencyAccessService.cs index 0072f85e61..6552f4bc69 100644 --- a/src/Core/Auth/Services/EmergencyAccess/EmergencyAccessService.cs +++ b/src/Core/Auth/UserFeatures/EmergencyAccess/EmergencyAccessService.cs @@ -4,7 +4,6 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Auth.Entities; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Auth.Models.Data; @@ -19,7 +18,7 @@ using Bit.Core.Vault.Models.Data; using Bit.Core.Vault.Repositories; using Bit.Core.Vault.Services; -namespace Bit.Core.Auth.Services; +namespace Bit.Core.Auth.UserFeatures.EmergencyAccess; public class EmergencyAccessService : IEmergencyAccessService { @@ -61,7 +60,7 @@ public class EmergencyAccessService : IEmergencyAccessService _removeOrganizationUserCommand = removeOrganizationUserCommand; } - public async Task InviteAsync(User grantorUser, string emergencyContactEmail, EmergencyAccessType accessType, int waitTime) + public async Task InviteAsync(User grantorUser, string emergencyContactEmail, EmergencyAccessType accessType, int waitTime) { if (!await _userService.CanAccessPremium(grantorUser)) { @@ -73,7 +72,7 @@ public class EmergencyAccessService : IEmergencyAccessService throw new BadRequestException("You cannot use Emergency Access Takeover because you are using Key Connector."); } - var emergencyAccess = new EmergencyAccess + var emergencyAccess = new Entities.EmergencyAccess { GrantorId = grantorUser.Id, Email = emergencyContactEmail.ToLowerInvariant(), @@ -113,7 +112,7 @@ public class EmergencyAccessService : IEmergencyAccessService await SendInviteAsync(emergencyAccess, NameOrEmail(grantorUser)); } - public async Task AcceptUserAsync(Guid emergencyAccessId, User granteeUser, string token, IUserService userService) + public async Task AcceptUserAsync(Guid emergencyAccessId, User granteeUser, string token, IUserService userService) { var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId); if (emergencyAccess == null) @@ -175,7 +174,7 @@ public class EmergencyAccessService : IEmergencyAccessService await _emergencyAccessRepository.DeleteAsync(emergencyAccess); } - public async Task ConfirmUserAsync(Guid emergencyAccessId, string key, Guid grantorId) + public async Task ConfirmUserAsync(Guid emergencyAccessId, string key, Guid grantorId) { var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId); if (emergencyAccess == null || emergencyAccess.Status != EmergencyAccessStatusType.Accepted || @@ -201,7 +200,7 @@ public class EmergencyAccessService : IEmergencyAccessService return emergencyAccess; } - public async Task SaveAsync(EmergencyAccess emergencyAccess, User grantorUser) + public async Task SaveAsync(Entities.EmergencyAccess emergencyAccess, User grantorUser) { if (!await _userService.CanAccessPremium(grantorUser)) { @@ -311,7 +310,7 @@ public class EmergencyAccessService : IEmergencyAccessService } // TODO PM-21687: rename this to something like InitiateRecoveryTakeoverAsync - public async Task<(EmergencyAccess, User)> TakeoverAsync(Guid emergencyAccessId, User granteeUser) + public async Task<(Entities.EmergencyAccess, User)> TakeoverAsync(Guid emergencyAccessId, User granteeUser) { var emergencyAccess = await _emergencyAccessRepository.GetByIdAsync(emergencyAccessId); @@ -429,7 +428,7 @@ public class EmergencyAccessService : IEmergencyAccessService return await _cipherService.GetAttachmentDownloadDataAsync(cipher, attachmentId); } - private async Task SendInviteAsync(EmergencyAccess emergencyAccess, string invitingUsersName) + private async Task SendInviteAsync(Entities.EmergencyAccess emergencyAccess, string invitingUsersName) { var token = _dataProtectorTokenizer.Protect(new EmergencyAccessInviteTokenable(emergencyAccess, _globalSettings.OrganizationInviteExpirationHours)); await _mailService.SendEmergencyAccessInviteEmailAsync(emergencyAccess, invitingUsersName, token); @@ -449,7 +448,7 @@ public class EmergencyAccessService : IEmergencyAccessService */ //TODO PM-21687: this IsValidRequest() checks the validity based on the granteeUser. There should be a complementary method for the grantorUser private static bool IsValidRequest( - EmergencyAccess availableAccess, + Entities.EmergencyAccess availableAccess, User requestingUser, EmergencyAccessType requestedAccessType) { diff --git a/src/Core/Auth/Services/EmergencyAccess/IEmergencyAccessService.cs b/src/Core/Auth/UserFeatures/EmergencyAccess/IEmergencyAccessService.cs similarity index 93% rename from src/Core/Auth/Services/EmergencyAccess/IEmergencyAccessService.cs rename to src/Core/Auth/UserFeatures/EmergencyAccess/IEmergencyAccessService.cs index de695bbd7d..860ae8bfb6 100644 --- a/src/Core/Auth/Services/EmergencyAccess/IEmergencyAccessService.cs +++ b/src/Core/Auth/UserFeatures/EmergencyAccess/IEmergencyAccessService.cs @@ -1,5 +1,4 @@ using Bit.Core.AdminConsole.Entities; -using Bit.Core.Auth.Entities; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models.Data; using Bit.Core.Entities; @@ -7,7 +6,7 @@ using Bit.Core.Enums; using Bit.Core.Services; using Bit.Core.Vault.Models.Data; -namespace Bit.Core.Auth.Services; +namespace Bit.Core.Auth.UserFeatures.EmergencyAccess; public interface IEmergencyAccessService { @@ -20,7 +19,7 @@ public interface IEmergencyAccessService /// Type of emergency access allowed to the emergency contact /// The amount of time to pass before the invite is auto confirmed /// a new Emergency Access object - Task InviteAsync(User grantorUser, string emergencyContactEmail, EmergencyAccessType accessType, int waitTime); + Task InviteAsync(User grantorUser, string emergencyContactEmail, EmergencyAccessType accessType, int waitTime); /// /// Sends an invite to the emergency contact associated with the emergency access id. /// @@ -37,7 +36,7 @@ public interface IEmergencyAccessService /// the tokenable that was sent via email /// service dependency /// void - Task AcceptUserAsync(Guid emergencyAccessId, User granteeUser, string token, IUserService userService); + Task AcceptUserAsync(Guid emergencyAccessId, User granteeUser, string token, IUserService userService); /// /// The creator of the emergency access request can delete the request. /// @@ -53,7 +52,7 @@ public interface IEmergencyAccessService /// The grantor user key encrypted by the grantee public key; grantee.PubicKey(grantor.User.Key) /// Id of grantor user /// emergency access object associated with the Id passed in - Task ConfirmUserAsync(Guid emergencyAccessId, string key, Guid grantorId); + Task ConfirmUserAsync(Guid emergencyAccessId, string key, Guid grantorId); /// /// Fetches an emergency access object. The grantor user must own the object being fetched. /// @@ -67,7 +66,7 @@ public interface IEmergencyAccessService /// emergency access entity being updated /// grantor user /// void - Task SaveAsync(EmergencyAccess emergencyAccess, User grantorUser); + Task SaveAsync(Entities.EmergencyAccess emergencyAccess, User grantorUser); /// /// Initiates the recovery process. For either Takeover or view. Will send an email to the Grantor User notifying of the initiation. /// @@ -107,7 +106,7 @@ public interface IEmergencyAccessService /// Id of entity being accessed /// grantee user of the emergency access entity /// emergency access entity and the grantorUser - Task<(EmergencyAccess, User)> TakeoverAsync(Guid emergencyAccessId, User granteeUser); + Task<(Entities.EmergencyAccess, User)> TakeoverAsync(Guid emergencyAccessId, User granteeUser); /// /// Updates the grantor's password hash and updates the key for the EmergencyAccess entity. /// diff --git a/src/Core/Auth/UserFeatures/EmergencyAccess/Interfaces/IDeleteEmergencyAccessCommand.cs b/src/Core/Auth/UserFeatures/EmergencyAccess/Interfaces/IDeleteEmergencyAccessCommand.cs new file mode 100644 index 0000000000..efdd864d60 --- /dev/null +++ b/src/Core/Auth/UserFeatures/EmergencyAccess/Interfaces/IDeleteEmergencyAccessCommand.cs @@ -0,0 +1,35 @@ +using Bit.Core.Auth.Models.Data; +using Bit.Core.Exceptions; + +namespace Bit.Core.Auth.UserFeatures.EmergencyAccess.Interfaces; + +/// +/// Command for deleting emergency access records based on the grantor's user ID. +/// +public interface IDeleteEmergencyAccessCommand +{ + /// + /// Deletes a single emergency access record for the specified grantor. + /// + /// The ID of the emergency access record to delete. + /// The ID of the grantor user who owns the emergency access record. + /// A task representing the asynchronous operation. + /// + /// Thrown when the emergency access record is not found or does not belong to the specified grantor. + /// + Task DeleteByIdGrantorIdAsync(Guid emergencyAccessId, Guid grantorId); + + /// + /// Deletes all emergency access records for the specified grantor. + /// + /// The ID of the grantor user whose emergency access records should be deleted. + /// A collection of the deleted emergency access records. + Task?> DeleteAllByGrantorIdAsync(Guid grantorId); + + /// + /// Deletes all emergency access records for the specified grantee. + /// + /// The ID of the grantee user whose emergency access records should be deleted. + /// A collection of the deleted emergency access records. + Task?> DeleteAllByGranteeIdAsync(Guid granteeId); +} diff --git a/src/Core/Auth/UserFeatures/EmergencyAccess/Mail/EmergencyAccessRemoveGranteesMailView.cs b/src/Core/Auth/UserFeatures/EmergencyAccess/Mail/EmergencyAccessRemoveGranteesMailView.cs new file mode 100644 index 0000000000..52106a089f --- /dev/null +++ b/src/Core/Auth/UserFeatures/EmergencyAccess/Mail/EmergencyAccessRemoveGranteesMailView.cs @@ -0,0 +1,15 @@ +using Bit.Core.Platform.Mail.Mailer; + +namespace Bit.Core.Auth.UserFeatures.EmergencyAccess.Mail; + +public class EmergencyAccessRemoveGranteesMailView : BaseMailView +{ + + public required IEnumerable RemovedGranteeEmails { get; set; } + public static string EmergencyAccessHelpPageUrl => "https://bitwarden.com/help/emergency-access/"; +} + +public class EmergencyAccessRemoveGranteesMail : BaseMail +{ + public override string Subject { get; set; } = "Emergency contacts removed"; +} diff --git a/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmailv2.html.hbs b/src/Core/Auth/UserFeatures/EmergencyAccess/Mail/EmergencyAccessRemoveGranteesMailView.html.hbs similarity index 64% rename from src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmailv2.html.hbs rename to src/Core/Auth/UserFeatures/EmergencyAccess/Mail/EmergencyAccessRemoveGranteesMailView.html.hbs index 7d30fdcbe4..3512d2526e 100644 --- a/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmailv2.html.hbs +++ b/src/Core/Auth/UserFeatures/EmergencyAccess/Mail/EmergencyAccessRemoveGranteesMailView.html.hbs @@ -29,26 +29,24 @@ .mj-outlook-group-fix { width:100% !important; } - - + + - - - - + + + + - + - + - - + +
- + - - + +
- + -
+ - +
- - + + - - + +
- +
- +
- + - + - + - +
- +
- + - +
- +
- +

- Verify your email to access this Bitwarden Send +

- +
- +
- + - +
- + - + - +
- +
- - - + + +
- +
- +
- +
- +
- - + + - - + +
- +
- +
- - + + - + - + - - + +
- + -
+ - - + +
- + -
- - + + +
- - + +
- - - - -
- - - - + - - - - - - - - - - - - - - +
- -
Your verification code is:
- +
+ +
The following emergency contacts have been removed from your account: +
    + {{#each RemovedGranteeEmails}} +
  • {{this}}
  • + {{/each}} +
+ Learn more about emergency access.
+
- -
{{Token}}
- -
- -
- -
- -
This code expires in {{Expiry}} minutes. After that, you'll need - to verify your email again.
- -
- -
- +
- +
- +
- - - - - -
- - - - - - - -
- - -
- - - - - - - -
- - - - - - - - - -
- -

- Bitwarden Send transmits sensitive, temporary information to - others easily and securely. Learn more about - Bitwarden Send - or - sign up - to try it today. -

- -
- -
- -
- - -
- -
- - + +
- +
- - + + - - - - - - -
- - - - - - - -
- - - -
- - - - - - - -
- - -
- - - - - - - - - -
- -

- Learn more about Bitwarden -

- Find user guides, product documentation, and videos on the - Bitwarden Help Center.
- -
- -
- - - -
- - - - - - - - - -
- -
- - -
- -
- - - -
- -
- - - - + - + - - + +
- +
- +
- + - + - + - +
- - + + - + - + - +
@@ -501,15 +309,15 @@
- + - + - +
@@ -524,15 +332,15 @@
- + - + - +
@@ -547,15 +355,15 @@
- + - + - +
@@ -570,15 +378,15 @@
- + - + - +
@@ -593,15 +401,15 @@
- + - + - +
@@ -616,15 +424,15 @@
- + - + - +
@@ -639,22 +447,22 @@
- - + +
- +

- © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + © {{ CurrentYear }} Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa Barbara, CA, USA

@@ -663,29 +471,28 @@ bitwarden.com | Learn why we include this

- +
- +
- +
- +
- - + + - - + +
- + - \ No newline at end of file diff --git a/src/Core/Auth/UserFeatures/EmergencyAccess/Mail/EmergencyAccessRemoveGranteesMailView.text.hbs b/src/Core/Auth/UserFeatures/EmergencyAccess/Mail/EmergencyAccessRemoveGranteesMailView.text.hbs new file mode 100644 index 0000000000..0a8446dd17 --- /dev/null +++ b/src/Core/Auth/UserFeatures/EmergencyAccess/Mail/EmergencyAccessRemoveGranteesMailView.text.hbs @@ -0,0 +1,7 @@ +The following emergency contacts have been removed from your account: + +{{#each RemovedGranteeEmails}} + {{this}} +{{/each}} + +Learn more about emergency access at {{EmergencyAccessHelpPageUrl}} diff --git a/src/Core/Auth/Services/EmergencyAccess/readme.md b/src/Core/Auth/UserFeatures/EmergencyAccess/readme.md similarity index 100% rename from src/Core/Auth/Services/EmergencyAccess/readme.md rename to src/Core/Auth/UserFeatures/EmergencyAccess/readme.md diff --git a/src/Core/Auth/UserFeatures/Registration/Implementations/RegisterUserCommand.cs b/src/Core/Auth/UserFeatures/Registration/Implementations/RegisterUserCommand.cs index 4a0e9c2cf5..ba63afb54c 100644 --- a/src/Core/Auth/UserFeatures/Registration/Implementations/RegisterUserCommand.cs +++ b/src/Core/Auth/UserFeatures/Registration/Implementations/RegisterUserCommand.cs @@ -1,6 +1,6 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; -using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models; using Bit.Core.Auth.Models.Business.Tokenables; @@ -27,7 +27,7 @@ public class RegisterUserCommand : IRegisterUserCommand private readonly IGlobalSettings _globalSettings; private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IOrganizationRepository _organizationRepository; - private readonly IPolicyRepository _policyRepository; + private readonly IPolicyQuery _policyQuery; private readonly IOrganizationDomainRepository _organizationDomainRepository; private readonly IFeatureService _featureService; @@ -50,7 +50,7 @@ public class RegisterUserCommand : IRegisterUserCommand IGlobalSettings globalSettings, IOrganizationUserRepository organizationUserRepository, IOrganizationRepository organizationRepository, - IPolicyRepository policyRepository, + IPolicyQuery policyQuery, IOrganizationDomainRepository organizationDomainRepository, IFeatureService featureService, IDataProtectionProvider dataProtectionProvider, @@ -65,7 +65,7 @@ public class RegisterUserCommand : IRegisterUserCommand _globalSettings = globalSettings; _organizationUserRepository = organizationUserRepository; _organizationRepository = organizationRepository; - _policyRepository = policyRepository; + _policyQuery = policyQuery; _organizationDomainRepository = organizationDomainRepository; _featureService = featureService; @@ -246,9 +246,9 @@ public class RegisterUserCommand : IRegisterUserCommand var orgUser = await _organizationUserRepository.GetByIdAsync(orgUserId.Value); if (orgUser != null) { - var twoFactorPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(orgUser.OrganizationId, + var twoFactorPolicy = await _policyQuery.RunAsync(orgUser.OrganizationId, PolicyType.TwoFactorAuthentication); - if (twoFactorPolicy != null && twoFactorPolicy.Enabled) + if (twoFactorPolicy.Enabled) { user.SetTwoFactorProviders(new Dictionary { diff --git a/src/Core/Auth/UserFeatures/UserServiceCollectionExtensions.cs b/src/Core/Auth/UserFeatures/UserServiceCollectionExtensions.cs index 6249d1cb1c..356d5bf2bc 100644 --- a/src/Core/Auth/UserFeatures/UserServiceCollectionExtensions.cs +++ b/src/Core/Auth/UserFeatures/UserServiceCollectionExtensions.cs @@ -1,5 +1,7 @@ using Bit.Core.Auth.Sso; using Bit.Core.Auth.UserFeatures.DeviceTrust; +using Bit.Core.Auth.UserFeatures.EmergencyAccess.Commands; +using Bit.Core.Auth.UserFeatures.EmergencyAccess.Interfaces; using Bit.Core.Auth.UserFeatures.Registration; using Bit.Core.Auth.UserFeatures.Registration.Implementations; using Bit.Core.Auth.UserFeatures.TdeOffboardingPassword.Interfaces; @@ -23,6 +25,7 @@ public static class UserServiceCollectionExtensions { services.AddScoped(); services.AddDeviceTrustCommands(); + services.AddEmergencyAccessCommands(); services.AddUserPasswordCommands(); services.AddUserRegistrationCommands(); services.AddWebAuthnLoginCommands(); @@ -36,6 +39,11 @@ public static class UserServiceCollectionExtensions services.AddScoped(); } + private static void AddEmergencyAccessCommands(this IServiceCollection services) + { + services.AddScoped(); + } + public static void AddUserKeyCommands(this IServiceCollection services, IGlobalSettings globalSettings) { services.AddScoped(); diff --git a/src/Core/Billing/Caches/Implementations/SetupIntentDistributedCache.cs b/src/Core/Billing/Caches/Implementations/SetupIntentDistributedCache.cs index 8833c928fe..514898e53c 100644 --- a/src/Core/Billing/Caches/Implementations/SetupIntentDistributedCache.cs +++ b/src/Core/Billing/Caches/Implementations/SetupIntentDistributedCache.cs @@ -1,11 +1,13 @@ using Microsoft.Extensions.Caching.Distributed; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; namespace Bit.Core.Billing.Caches.Implementations; public class SetupIntentDistributedCache( [FromKeyedServices("persistent")] - IDistributedCache distributedCache) : ISetupIntentCache + IDistributedCache distributedCache, + ILogger logger) : ISetupIntentCache { public async Task GetSetupIntentIdForSubscriber(Guid subscriberId) { @@ -17,11 +19,12 @@ public class SetupIntentDistributedCache( { var cacheKey = GetCacheKeyBySetupIntentId(setupIntentId); var value = await distributedCache.GetStringAsync(cacheKey); - if (string.IsNullOrEmpty(value) || !Guid.TryParse(value, out var subscriberId)) + if (!string.IsNullOrEmpty(value) && Guid.TryParse(value, out var subscriberId)) { - return null; + return subscriberId; } - return subscriberId; + logger.LogError("Subscriber ID value ({Value}) cached for Setup Intent ({SetupIntentId}) is null or not a valid Guid", value, setupIntentId); + return null; } public async Task RemoveSetupIntentForSubscriber(Guid subscriberId) diff --git a/src/Core/Billing/Constants/StripeConstants.cs b/src/Core/Billing/Constants/StripeConstants.cs index e9c34d7e06..524a819b28 100644 --- a/src/Core/Billing/Constants/StripeConstants.cs +++ b/src/Core/Billing/Constants/StripeConstants.cs @@ -70,10 +70,6 @@ public static class StripeConstants public const string InvoiceApproved = "invoice_approved"; public const string OrganizationId = "organizationId"; public const string PayPalTransactionId = "btPayPalTransactionId"; - public const string PreviousAdditionalStorage = "previous_additional_storage"; - public const string PreviousPeriodEndDate = "previous_period_end_date"; - public const string PreviousPremiumPriceId = "previous_premium_price_id"; - public const string PreviousPremiumUserId = "previous_premium_user_id"; public const string ProviderId = "providerId"; public const string Region = "region"; public const string RetiredBraintreeCustomerId = "btCustomerId_old"; diff --git a/src/Core/Billing/Enums/DiscountAudienceType.cs b/src/Core/Billing/Enums/DiscountAudienceType.cs new file mode 100644 index 0000000000..98ebd9163d --- /dev/null +++ b/src/Core/Billing/Enums/DiscountAudienceType.cs @@ -0,0 +1,13 @@ +namespace Bit.Core.Billing.Enums; + +/// +/// Defines the target audience for subscription discounts using an extensible strategy pattern. +/// Each audience type maps to specific eligibility rules implemented via IDiscountAudienceFilter. +/// +public enum DiscountAudienceType +{ + /// + /// Discount applies to users who have never had a subscription before. + /// + UserHasNoPreviousSubscriptions = 0 +} diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs index c61c4e6279..ddf3479aa3 100644 --- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs +++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs @@ -59,6 +59,7 @@ public static class ServiceCollectionExtensions services.AddScoped(); services.AddScoped(); services.AddTransient(); + services.AddScoped(); services.AddScoped(); services.AddScoped(); } diff --git a/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs b/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs index a5a9e3e9c9..5734babc31 100644 --- a/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs +++ b/src/Core/Billing/Payment/Commands/UpdatePaymentMethodCommand.cs @@ -4,6 +4,7 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Services; using Bit.Core.Entities; +using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Utilities; using Braintree; @@ -23,6 +24,7 @@ public interface IUpdatePaymentMethodCommand public class UpdatePaymentMethodCommand( IBraintreeGateway braintreeGateway, + IBraintreeService braintreeService, IGlobalSettings globalSettings, ILogger logger, ISetupIntentCache setupIntentCache, @@ -94,6 +96,8 @@ public class UpdatePaymentMethodCommand( await setupIntentCache.Set(subscriber.Id, setupIntent.Id); + _logger.LogInformation("{Command}: Successfully cached Setup Intent ({SetupIntentId}) for subscriber ({SubscriberID})", CommandName, setupIntent.Id, subscriber.Id); + await UnlinkBraintreeCustomerAsync(customer); return MaskedPaymentMethod.From(setupIntent); @@ -121,12 +125,10 @@ public class UpdatePaymentMethodCommand( Customer customer, string token) { - Braintree.Customer braintreeCustomer; + var braintreeCustomer = await braintreeService.GetCustomer(customer); - if (customer.Metadata.TryGetValue(StripeConstants.MetadataKeys.BraintreeCustomerId, out var braintreeCustomerId)) + if (braintreeCustomer != null) { - braintreeCustomer = await braintreeGateway.Customer.FindAsync(braintreeCustomerId); - await ReplaceBraintreePaymentMethodAsync(braintreeCustomer, token); } else diff --git a/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs b/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs index e03a785278..58e9930b87 100644 --- a/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs +++ b/src/Core/Billing/Payment/Queries/GetPaymentMethodQuery.cs @@ -1,11 +1,10 @@ using Bit.Core.Billing.Caches; -using Bit.Core.Billing.Constants; using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Services; using Bit.Core.Entities; +using Bit.Core.Services; using Braintree; -using Microsoft.Extensions.Logging; using Stripe; namespace Bit.Core.Billing.Payment.Queries; @@ -16,8 +15,7 @@ public interface IGetPaymentMethodQuery } public class GetPaymentMethodQuery( - IBraintreeGateway braintreeGateway, - ILogger logger, + IBraintreeService braintreeService, ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, ISubscriberService subscriberService) : IGetPaymentMethodQuery @@ -32,19 +30,12 @@ public class GetPaymentMethodQuery( return null; } - // First check for PayPal - if (customer.Metadata.TryGetValue(StripeConstants.MetadataKeys.BraintreeCustomerId, out var braintreeCustomerId)) + // First check for a PayPal account + var braintreeCustomer = await braintreeService.GetCustomer(customer); + + if (braintreeCustomer is { DefaultPaymentMethod: PayPalAccount payPalAccount }) { - var braintreeCustomer = await braintreeGateway.Customer.FindAsync(braintreeCustomerId); - - if (braintreeCustomer.DefaultPaymentMethod is PayPalAccount payPalAccount) - { - return new MaskedPayPalAccount { Email = payPalAccount.Email }; - } - - logger.LogWarning("Subscriber ({SubscriberID}) has a linked Braintree customer ({BraintreeCustomerId}) with no PayPal account.", subscriber.Id, braintreeCustomerId); - - return null; + return new MaskedPayPalAccount { Email = payPalAccount.Email }; } // Then check for a bank account pending verification diff --git a/src/Core/Billing/Premium/Commands/PreviewPremiumUpgradeProrationCommand.cs b/src/Core/Billing/Premium/Commands/PreviewPremiumUpgradeProrationCommand.cs new file mode 100644 index 0000000000..af2a8bdacb --- /dev/null +++ b/src/Core/Billing/Premium/Commands/PreviewPremiumUpgradeProrationCommand.cs @@ -0,0 +1,166 @@ +using Bit.Core.Billing.Commands; +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Premium.Models; +using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; +using Bit.Core.Entities; +using Microsoft.Extensions.Logging; +using Stripe; + +namespace Bit.Core.Billing.Premium.Commands; + +/// +/// Previews the proration details for upgrading a Premium user subscription to an Organization +/// plan by using the Stripe API to create an invoice preview, prorated, for the upgrade. +/// +public interface IPreviewPremiumUpgradeProrationCommand +{ + /// + /// Calculates the tax, total cost, and proration credit for upgrading a Premium subscription to an Organization plan. + /// + /// The user with an active Premium subscription. + /// The target organization plan type. + /// The billing address for tax calculation. + /// The proration details for the upgrade including costs, credits, tax, and time remaining. + Task> Run( + User user, + PlanType targetPlanType, + BillingAddress billingAddress); +} + +public class PreviewPremiumUpgradeProrationCommand( + ILogger logger, + IPricingClient pricingClient, + IStripeAdapter stripeAdapter) + : BaseBillingCommand(logger), + IPreviewPremiumUpgradeProrationCommand +{ + public Task> Run( + User user, + PlanType targetPlanType, + BillingAddress billingAddress) => HandleAsync(async () => + { + if (user is not { Premium: true, GatewaySubscriptionId: not null and not "" }) + { + return new BadRequest("User does not have an active Premium subscription."); + } + + var currentSubscription = await stripeAdapter.GetSubscriptionAsync( + user.GatewaySubscriptionId, + new SubscriptionGetOptions { Expand = ["customer"] }); + var premiumPlans = await pricingClient.ListPremiumPlans(); + var passwordManagerItem = currentSubscription.Items.Data.FirstOrDefault(i => + premiumPlans.Any(p => p.Seat.StripePriceId == i.Price.Id)); + + if (passwordManagerItem == null) + { + return new BadRequest("Premium subscription password manager item not found."); + } + + var usersPremiumPlan = premiumPlans.First(p => p.Seat.StripePriceId == passwordManagerItem.Price.Id); + var targetPlan = await pricingClient.GetPlanOrThrow(targetPlanType); + var subscriptionItems = new List(); + var storageItem = currentSubscription.Items.Data.FirstOrDefault(i => + i.Price.Id == usersPremiumPlan.Storage.StripePriceId); + + // Delete the storage item if it exists for this user's plan + if (storageItem != null) + { + subscriptionItems.Add(new InvoiceSubscriptionDetailsItemOptions + { + Id = storageItem.Id, + Deleted = true + }); + } + + // Hardcode seats to 1 for upgrade flow + if (targetPlan.HasNonSeatBasedPasswordManagerPlan()) + { + subscriptionItems.Add(new InvoiceSubscriptionDetailsItemOptions + { + Id = passwordManagerItem.Id, + Price = targetPlan.PasswordManager.StripePlanId, + Quantity = 1 + }); + } + else + { + subscriptionItems.Add(new InvoiceSubscriptionDetailsItemOptions + { + Id = passwordManagerItem.Id, + Price = targetPlan.PasswordManager.StripeSeatPlanId, + Quantity = 1 + }); + } + + var options = new InvoiceCreatePreviewOptions + { + AutomaticTax = new InvoiceAutomaticTaxOptions { Enabled = true }, + Customer = user.GatewayCustomerId, + Subscription = user.GatewaySubscriptionId, + CustomerDetails = new InvoiceCustomerDetailsOptions + { + Address = new AddressOptions + { + Country = billingAddress.Country, + PostalCode = billingAddress.PostalCode + } + }, + SubscriptionDetails = new InvoiceSubscriptionDetailsOptions + { + Items = subscriptionItems, + ProrationBehavior = StripeConstants.ProrationBehavior.AlwaysInvoice + } + }; + + var invoicePreview = await stripeAdapter.CreateInvoicePreviewAsync(options); + var proration = GetProration(invoicePreview, passwordManagerItem); + + return proration; + }); + + private static PremiumUpgradeProration GetProration(Invoice invoicePreview, SubscriptionItem passwordManagerItem) => new() + { + NewPlanProratedAmount = GetNewPlanProratedAmountFromInvoice(invoicePreview), + Credit = GetProrationCreditFromInvoice(invoicePreview), + Tax = Convert.ToDecimal(invoicePreview.TotalTaxes.Sum(invoiceTotalTax => invoiceTotalTax.Amount)) / 100, + Total = Convert.ToDecimal(invoicePreview.Total) / 100, + // Use invoice periodEnd here instead of UtcNow so that testing with Stripe time clocks works correctly. And if there is no test clock, + // (like in production), the previewInvoice's periodEnd is the same as UtcNow anyway because of the proration behavior (always_invoice) + NewPlanProratedMonths = CalculateNewPlanProratedMonths(invoicePreview.PeriodEnd, passwordManagerItem.CurrentPeriodEnd) + }; + + private static decimal GetProrationCreditFromInvoice(Invoice invoicePreview) + { + // Extract proration credit from negative line items (credits are negative in Stripe) + var prorationCredit = invoicePreview.Lines?.Data? + .Where(line => line.Amount < 0) + .Sum(line => Math.Abs(line.Amount)) ?? 0; // Return the credit as positive number + + return Convert.ToDecimal(prorationCredit) / 100; + } + + private static decimal GetNewPlanProratedAmountFromInvoice(Invoice invoicePreview) + { + // The target plan's prorated upgrade amount should be the only positive-valued line item + var proratedTotal = invoicePreview.Lines?.Data? + .Where(line => line.Amount > 0) + .Sum(line => line.Amount) ?? 0; + + return Convert.ToDecimal(proratedTotal) / 100; + } + + private static int CalculateNewPlanProratedMonths(DateTime invoicePeriodEnd, DateTime currentPeriodEnd) + { + var daysInProratedPeriod = (currentPeriodEnd - invoicePeriodEnd).TotalDays; + + // Round to nearest month (30-day periods) + // 1-14 days = 1 month, 15-44 days = 1 month, 45-74 days = 2 months, etc. + // Minimum is always 1 month (never returns 0) + // Use MidpointRounding.AwayFromZero to round 0.5 up to 1 + var months = (int)Math.Round(daysInProratedPeriod / 30, MidpointRounding.AwayFromZero); + return Math.Max(1, months); + } +} diff --git a/src/Core/Billing/Premium/Commands/UpdatePremiumStorageCommand.cs b/src/Core/Billing/Premium/Commands/UpdatePremiumStorageCommand.cs index 176c77bf57..219f450f1d 100644 --- a/src/Core/Billing/Premium/Commands/UpdatePremiumStorageCommand.cs +++ b/src/Core/Billing/Premium/Commands/UpdatePremiumStorageCommand.cs @@ -2,6 +2,7 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; +using Bit.Core.Billing.Subscriptions.Models; using Bit.Core.Entities; using Bit.Core.Services; using Bit.Core.Utilities; @@ -29,6 +30,7 @@ public interface IUpdatePremiumStorageCommand } public class UpdatePremiumStorageCommand( + IBraintreeService braintreeService, IStripeAdapter stripeAdapter, IUserService userService, IPricingClient pricingClient, @@ -49,7 +51,10 @@ public class UpdatePremiumStorageCommand( // Fetch all premium plans and the user's subscription to find which plan they're on var premiumPlans = await pricingClient.ListPremiumPlans(); - var subscription = await stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId); + var subscription = await stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, new SubscriptionGetOptions + { + Expand = ["customer"] + }); // Find the password manager subscription item (seat, not storage) and match it to a plan var passwordManagerItem = subscription.Items.Data.FirstOrDefault(i => @@ -127,13 +132,41 @@ public class UpdatePremiumStorageCommand( }); } - var subscriptionUpdateOptions = new SubscriptionUpdateOptions - { - Items = subscriptionItemOptions, - ProrationBehavior = ProrationBehavior.AlwaysInvoice - }; + var usingPayPal = subscription.Customer.Metadata.ContainsKey(MetadataKeys.BraintreeCustomerId); - await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, subscriptionUpdateOptions); + if (usingPayPal) + { + var options = new SubscriptionUpdateOptions + { + Items = subscriptionItemOptions, + ProrationBehavior = ProrationBehavior.CreateProrations + }; + + await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, options); + + var draftInvoice = await stripeAdapter.CreateInvoiceAsync(new InvoiceCreateOptions + { + Customer = subscription.CustomerId, + Subscription = subscription.Id, + AutoAdvance = false, + CollectionMethod = CollectionMethod.ChargeAutomatically + }); + + var finalizedInvoice = await stripeAdapter.FinalizeInvoiceAsync(draftInvoice.Id, + new InvoiceFinalizeOptions { AutoAdvance = false, Expand = ["customer"] }); + + await braintreeService.PayInvoice(new UserId(user.Id), finalizedInvoice); + } + else + { + var options = new SubscriptionUpdateOptions + { + Items = subscriptionItemOptions, + ProrationBehavior = ProrationBehavior.AlwaysInvoice + }; + + await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, options); + } // Update the user's max storage user.MaxStorageGb = maxStorageGb; diff --git a/src/Core/Billing/Premium/Commands/UpgradePremiumToOrganizationCommand.cs b/src/Core/Billing/Premium/Commands/UpgradePremiumToOrganizationCommand.cs index 81bc5c9e2c..803674120a 100644 --- a/src/Core/Billing/Premium/Commands/UpgradePremiumToOrganizationCommand.cs +++ b/src/Core/Billing/Premium/Commands/UpgradePremiumToOrganizationCommand.cs @@ -2,7 +2,6 @@ using Bit.Core.Billing.Commands; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Enums; -using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Entities; @@ -28,12 +27,14 @@ public interface IUpgradePremiumToOrganizationCommand /// The name for the new organization. /// The encrypted organization key for the owner. /// The target organization plan type to upgrade to. + /// The billing address for tax calculation. /// A billing command result indicating success or failure with appropriate error details. Task> Run( User user, string organizationName, string key, - PlanType targetPlanType); + PlanType targetPlanType, + Payment.Models.BillingAddress billingAddress); } public class UpgradePremiumToOrganizationCommand( @@ -51,7 +52,8 @@ public class UpgradePremiumToOrganizationCommand( User user, string organizationName, string key, - PlanType targetPlanType) => HandleAsync(async () => + PlanType targetPlanType, + Payment.Models.BillingAddress billingAddress) => HandleAsync(async () => { // Validate that the user has an active Premium subscription if (user is not { Premium: true, GatewaySubscriptionId: not null and not "" }) @@ -74,7 +76,7 @@ public class UpgradePremiumToOrganizationCommand( if (passwordManagerItem == null) { - return new BadRequest("Premium subscription item not found."); + return new BadRequest("Premium subscription password manager item not found."); } var usersPremiumPlan = premiumPlans.First(p => p.Seat.StripePriceId == passwordManagerItem.Price.Id); @@ -85,20 +87,10 @@ public class UpgradePremiumToOrganizationCommand( // Build the list of subscription item updates var subscriptionItemOptions = new List(); - // Delete the user's specific password manager item - subscriptionItemOptions.Add(new SubscriptionItemOptions - { - Id = passwordManagerItem.Id, - Deleted = true - }); - // Delete the storage item if it exists for this user's plan var storageItem = currentSubscription.Items.Data.FirstOrDefault(i => i.Price.Id == usersPremiumPlan.Storage.StripePriceId); - // Capture the previous additional storage quantity for potential revert - var previousAdditionalStorage = storageItem?.Quantity ?? 0; - if (storageItem != null) { subscriptionItemOptions.Add(new SubscriptionItemOptions @@ -113,6 +105,7 @@ public class UpgradePremiumToOrganizationCommand( { subscriptionItemOptions.Add(new SubscriptionItemOptions { + Id = passwordManagerItem.Id, Price = targetPlan.PasswordManager.StripePlanId, Quantity = 1 }); @@ -121,6 +114,7 @@ public class UpgradePremiumToOrganizationCommand( { subscriptionItemOptions.Add(new SubscriptionItemOptions { + Id = passwordManagerItem.Id, Price = targetPlan.PasswordManager.StripeSeatPlanId, Quantity = seats }); @@ -133,14 +127,12 @@ public class UpgradePremiumToOrganizationCommand( var subscriptionUpdateOptions = new SubscriptionUpdateOptions { Items = subscriptionItemOptions, - ProrationBehavior = StripeConstants.ProrationBehavior.None, + ProrationBehavior = StripeConstants.ProrationBehavior.AlwaysInvoice, + BillingCycleAnchor = SubscriptionBillingCycleAnchor.Unchanged, + AutomaticTax = new SubscriptionAutomaticTaxOptions { Enabled = true }, Metadata = new Dictionary { [StripeConstants.MetadataKeys.OrganizationId] = organizationId.ToString(), - [StripeConstants.MetadataKeys.PreviousPremiumPriceId] = usersPremiumPlan.Seat.StripePriceId, - [StripeConstants.MetadataKeys.PreviousPeriodEndDate] = currentSubscription.GetCurrentPeriodEnd()?.ToString("O") ?? string.Empty, - [StripeConstants.MetadataKeys.PreviousAdditionalStorage] = previousAdditionalStorage.ToString(), - [StripeConstants.MetadataKeys.PreviousPremiumUserId] = user.Id.ToString(), [StripeConstants.MetadataKeys.UserId] = string.Empty // Remove userId to unlink subscription from User } }; @@ -152,7 +144,7 @@ public class UpgradePremiumToOrganizationCommand( Name = organizationName, BillingEmail = user.Email, PlanType = targetPlan.Type, - Seats = (short)seats, + Seats = seats, MaxCollections = targetPlan.PasswordManager.MaxCollections, MaxStorageGb = targetPlan.PasswordManager.BaseStorageGb, UsePolicies = targetPlan.HasPolicies, @@ -182,6 +174,16 @@ public class UpgradePremiumToOrganizationCommand( GatewaySubscriptionId = currentSubscription.Id }; + // Update customer billing address for tax calculation + await stripeAdapter.UpdateCustomerAsync(user.GatewayCustomerId, new CustomerUpdateOptions + { + Address = new AddressOptions + { + Country = billingAddress.Country, + PostalCode = billingAddress.PostalCode + } + }); + // Update the subscription in Stripe await stripeAdapter.UpdateSubscriptionAsync(currentSubscription.Id, subscriptionUpdateOptions); diff --git a/src/Core/Billing/Premium/Models/PremiumUpgradeProration.cs b/src/Core/Billing/Premium/Models/PremiumUpgradeProration.cs new file mode 100644 index 0000000000..d8acaa3170 --- /dev/null +++ b/src/Core/Billing/Premium/Models/PremiumUpgradeProration.cs @@ -0,0 +1,36 @@ +namespace Bit.Core.Billing.Premium.Models; + +/// +/// Represents the proration details for upgrading a Premium user subscription to an Organization plan. +/// +public class PremiumUpgradeProration +{ + /// + /// The prorated cost for the new organization plan, calculated from now until the end of the current billing period. + /// This represents what the user will pay for the upgraded plan for the remainder of the period. + /// + public decimal NewPlanProratedAmount { get; set; } + + /// + /// The credit amount for the unused portion of the current Premium subscription. + /// This credit is applied against the cost of the new organization plan. + /// + public decimal Credit { get; set; } + + /// + /// The tax amount calculated for the upgrade transaction. + /// + public decimal Tax { get; set; } + + /// + /// The total amount due for the upgrade after applying the credit and adding tax. + /// + public decimal Total { get; set; } + + /// + /// The number of months the user will be charged for the new organization plan in the prorated billing period. + /// Calculated by rounding the days remaining in the current billing cycle to the nearest month. + /// Minimum value is 1 month (never returns 0). + /// + public int NewPlanProratedMonths { get; set; } +} diff --git a/src/Core/Billing/Pricing/IPricingClient.cs b/src/Core/Billing/Pricing/IPricingClient.cs index 18588ae432..755a121832 100644 --- a/src/Core/Billing/Pricing/IPricingClient.cs +++ b/src/Core/Billing/Pricing/IPricingClient.cs @@ -1,7 +1,6 @@ using Bit.Core.Billing.Enums; using Bit.Core.Exceptions; using Bit.Core.Models.StaticStore; -using Bit.Core.Utilities; namespace Bit.Core.Billing.Pricing; @@ -12,8 +11,7 @@ public interface IPricingClient { // TODO: Rename with Organization focus. /// - /// Retrieve a Bitwarden plan by its . If the feature flag 'use-pricing-service' is enabled, - /// this will trigger a request to the Bitwarden Pricing Service. Otherwise, it will use the existing . + /// Retrieve a Bitwarden plan by its from the Bitwarden Pricing Service. /// /// The type of plan to retrieve. /// A Bitwarden record or null in the case the plan could not be found or the method was executed from a self-hosted instance. @@ -22,8 +20,7 @@ public interface IPricingClient // TODO: Rename with Organization focus. /// - /// Retrieve a Bitwarden plan by its . If the feature flag 'use-pricing-service' is enabled, - /// this will trigger a request to the Bitwarden Pricing Service. Otherwise, it will use the existing . + /// Retrieve a Bitwarden plan by its from the Bitwarden Pricing Service. /// /// The type of plan to retrieve. /// A Bitwarden record. @@ -33,8 +30,7 @@ public interface IPricingClient // TODO: Rename with Organization focus. /// - /// Retrieve all the Bitwarden plans. If the feature flag 'use-pricing-service' is enabled, - /// this will trigger a request to the Bitwarden Pricing Service. Otherwise, it will use the existing . + /// Retrieve all Bitwarden plans from the Pricing Service. /// /// A list of Bitwarden records or an empty list in the case the method is executed from a self-hosted instance. /// Thrown when the request to the Pricing Service fails unexpectedly. diff --git a/src/Core/Billing/Services/IStripeAdapter.cs b/src/Core/Billing/Services/IStripeAdapter.cs index 5ec732920e..12ea3d5a7c 100644 --- a/src/Core/Billing/Services/IStripeAdapter.cs +++ b/src/Core/Billing/Services/IStripeAdapter.cs @@ -24,6 +24,7 @@ public interface IStripeAdapter Task CancelSubscriptionAsync(string id, SubscriptionCancelOptions options = null); Task GetInvoiceAsync(string id, InvoiceGetOptions options); Task> ListInvoicesAsync(StripeInvoiceListOptions options); + Task CreateInvoiceAsync(InvoiceCreateOptions options); Task CreateInvoicePreviewAsync(InvoiceCreatePreviewOptions options); Task> SearchInvoiceAsync(InvoiceSearchOptions options); Task UpdateInvoiceAsync(string id, InvoiceUpdateOptions options); diff --git a/src/Core/Billing/Services/Implementations/StripeAdapter.cs b/src/Core/Billing/Services/Implementations/StripeAdapter.cs index cdc7645042..5b90500021 100644 --- a/src/Core/Billing/Services/Implementations/StripeAdapter.cs +++ b/src/Core/Billing/Services/Implementations/StripeAdapter.cs @@ -116,6 +116,9 @@ public class StripeAdapter : IStripeAdapter return invoices; } + public Task CreateInvoiceAsync(InvoiceCreateOptions options) => + _invoiceService.CreateAsync(options); + public Task CreateInvoicePreviewAsync(InvoiceCreatePreviewOptions options) => _invoiceService.CreatePreviewAsync(options); diff --git a/src/Core/Billing/Subscriptions/Entities/SubscriptionDiscount.cs b/src/Core/Billing/Subscriptions/Entities/SubscriptionDiscount.cs new file mode 100644 index 0000000000..03165f6b05 --- /dev/null +++ b/src/Core/Billing/Subscriptions/Entities/SubscriptionDiscount.cs @@ -0,0 +1,48 @@ +#nullable enable + +using System.ComponentModel.DataAnnotations; +using Bit.Core.Billing.Enums; +using Bit.Core.Entities; +using Bit.Core.Utilities; + +namespace Bit.Core.Billing.Subscriptions.Entities; + +public class SubscriptionDiscount : ITableObject, IRevisable, IValidatableObject +{ + public Guid Id { get; set; } + [MaxLength(50)] + public string StripeCouponId { get; set; } = null!; + public ICollection? StripeProductIds { get; set; } + public decimal? PercentOff { get; set; } + public long? AmountOff { get; set; } + [MaxLength(10)] + public string? Currency { get; set; } + [MaxLength(20)] + public string Duration { get; set; } = null!; + public int? DurationInMonths { get; set; } + [MaxLength(100)] + public string? Name { get; set; } + public DateTime StartDate { get; set; } + public DateTime EndDate { get; set; } + public DiscountAudienceType AudienceType { get; set; } + public DateTime CreationDate { get; set; } = DateTime.UtcNow; + public DateTime RevisionDate { get; set; } = DateTime.UtcNow; + + public void SetNewId() + { + if (Id == default) + { + Id = CoreHelpers.GenerateComb(); + } + } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (EndDate < StartDate) + { + yield return new ValidationResult( + "EndDate must be greater than or equal to StartDate.", + new[] { nameof(EndDate) }); + } + } +} diff --git a/src/Core/Billing/Subscriptions/Queries/GetBitwardenSubscriptionQuery.cs b/src/Core/Billing/Subscriptions/Queries/GetBitwardenSubscriptionQuery.cs index cd7fa91fff..51c51bd7b2 100644 --- a/src/Core/Billing/Subscriptions/Queries/GetBitwardenSubscriptionQuery.cs +++ b/src/Core/Billing/Subscriptions/Queries/GetBitwardenSubscriptionQuery.cs @@ -14,6 +14,7 @@ namespace Bit.Core.Billing.Subscriptions.Queries; using static StripeConstants; using static Utilities; +using PremiumPlan = Bit.Core.Billing.Pricing.Premium.Plan; public interface IGetBitwardenSubscriptionQuery { @@ -107,11 +108,28 @@ public class GetBitwardenSubscriptionQuery( var (cartLevelDiscount, productLevelDiscounts) = GetStripeDiscounts(subscription); + var availablePlan = plans.First(plan => plan.Available); + var onCurrentPricing = passwordManagerSeatsItem.Price.Id == availablePlan.Seat.StripePriceId; + + decimal seatCost; + decimal estimatedTax; + + if (onCurrentPricing) + { + seatCost = GetCost(passwordManagerSeatsItem); + estimatedTax = await EstimatePremiumTaxAsync(subscription); + } + else + { + seatCost = availablePlan.Seat.Price; + estimatedTax = await EstimatePremiumTaxAsync(subscription, plans, availablePlan); + } + var passwordManagerSeats = new CartItem { TranslationKey = "premiumMembership", Quantity = passwordManagerSeatsItem.Quantity, - Cost = GetCost(passwordManagerSeatsItem), + Cost = seatCost, Discount = productLevelDiscounts.FirstOrDefault(discount => discount.AppliesTo(passwordManagerSeatsItem)) }; @@ -125,8 +143,6 @@ public class GetBitwardenSubscriptionQuery( } : null; - var estimatedTax = await EstimateTaxAsync(subscription); - return new Cart { PasswordManager = new PasswordManagerCartItems @@ -142,15 +158,45 @@ public class GetBitwardenSubscriptionQuery( #region Utilities - private async Task EstimateTaxAsync(Subscription subscription) + private async Task EstimatePremiumTaxAsync( + Subscription subscription, + List? plans = null, + PremiumPlan? availablePlan = null) { try { - var invoice = await stripeAdapter.CreateInvoicePreviewAsync(new InvoiceCreatePreviewOptions + var options = new InvoiceCreatePreviewOptions { - Customer = subscription.Customer.Id, - Subscription = subscription.Id - }); + Customer = subscription.Customer.Id + }; + + if (plans != null && availablePlan != null) + { + options.AutomaticTax = new InvoiceAutomaticTaxOptions + { + Enabled = subscription.AutomaticTax?.Enabled ?? false + }; + + options.SubscriptionDetails = new InvoiceSubscriptionDetailsOptions + { + Items = subscription.Items.Select(item => + { + var isSeatItem = plans.Any(plan => plan.Seat.StripePriceId == item.Price.Id); + + return new InvoiceSubscriptionDetailsItemOptions + { + Price = isSeatItem ? availablePlan.Seat.StripePriceId : item.Price.Id, + Quantity = item.Quantity + }; + }).ToList() + }; + } + else + { + options.Subscription = subscription.Id; + } + + var invoice = await stripeAdapter.CreateInvoicePreviewAsync(options); return GetCost(invoice.TotalTaxes); } diff --git a/src/Core/Billing/Subscriptions/Repositories/ISubscriptionDiscountRepository.cs b/src/Core/Billing/Subscriptions/Repositories/ISubscriptionDiscountRepository.cs new file mode 100644 index 0000000000..93accbd9e7 --- /dev/null +++ b/src/Core/Billing/Subscriptions/Repositories/ISubscriptionDiscountRepository.cs @@ -0,0 +1,23 @@ +#nullable enable + +using Bit.Core.Billing.Subscriptions.Entities; +using Bit.Core.Repositories; + +namespace Bit.Core.Billing.Subscriptions.Repositories; + +public interface ISubscriptionDiscountRepository : IRepository +{ + /// + /// Retrieves all active subscription discounts that are currently within their valid date range. + /// A discount is considered active if the current UTC date falls between StartDate (inclusive) and EndDate (inclusive). + /// + /// A collection of active subscription discounts. + Task> GetActiveDiscountsAsync(); + + /// + /// Retrieves a subscription discount by its Stripe coupon ID. + /// + /// The Stripe coupon ID to search for. + /// The subscription discount if found; otherwise, null. + Task GetByStripeCouponIdAsync(string stripeCouponId); +} diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 6f42778b6b..e5148795f4 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -141,30 +141,33 @@ public static class FeatureFlagKeys public const string AutomaticConfirmUsers = "pm-19934-auto-confirm-organization-users"; public const string PM23845_VNextApplicationCache = "pm-24957-refactor-memory-application-cache"; public const string BlockClaimedDomainAccountCreation = "pm-28297-block-uninvited-claimed-domain-registration"; - public const string IncreaseBulkReinviteLimitForCloud = "pm-28251-increase-bulk-reinvite-limit-for-cloud"; + public const string DefaultUserCollectionRestore = "pm-30883-my-items-restored-users"; public const string PremiumAccessQuery = "pm-29495-refactor-premium-interface"; + public const string RefactorMembersComponent = "pm-29503-refactor-members-inheritance"; + public const string BulkReinviteUI = "pm-28416-bulk-reinvite-ux-improvements"; /* Architecture */ public const string DesktopMigrationMilestone1 = "desktop-ui-migration-milestone-1"; public const string DesktopMigrationMilestone2 = "desktop-ui-migration-milestone-2"; public const string DesktopMigrationMilestone3 = "desktop-ui-migration-milestone-3"; + public const string DesktopMigrationMilestone4 = "desktop-ui-migration-milestone-4"; /* Auth Team */ public const string TwoFactorExtensionDataPersistence = "pm-9115-two-factor-extension-data-persistence"; - public const string EmailVerification = "email-verification"; public const string BrowserExtensionLoginApproval = "pm-14938-browser-extension-login-approvals"; public const string SetInitialPasswordRefactor = "pm-16117-set-initial-password-refactor"; public const string ChangeExistingPasswordRefactor = "pm-16117-change-existing-password-refactor"; public const string Otp6Digits = "pm-18612-otp-6-digits"; public const string PM24579_PreventSsoOnExistingNonCompliantUsers = "pm-24579-prevent-sso-on-existing-non-compliant-users"; public const string DisableAlternateLoginMethods = "pm-22110-disable-alternate-login-methods"; - public const string MJMLBasedEmailTemplates = "mjml-based-email-templates"; + public const string PM2035PasskeyUnlock = "pm-2035-passkey-unlock"; public const string MjmlWelcomeEmailTemplates = "pm-21741-mjml-welcome-email"; public const string OrganizationConfirmationEmail = "pm-28402-update-confirmed-to-org-email-template"; public const string MarketingInitiatedPremiumFlow = "pm-26140-marketing-initiated-premium-flow"; - public const string RedirectOnSsoRequired = "pm-1632-redirect-on-sso-required"; public const string PrefetchPasswordPrelogin = "pm-23801-prefetch-password-prelogin"; + public const string SafariAccountSwitching = "pm-5594-safari-account-switching"; public const string PM27086_UpdateAuthenticationApisForInputPassword = "pm-27086-update-authentication-apis-for-input-password"; + public const string PM27044_UpdateRegistrationApis = "pm-27044-update-registration-apis"; /* Autofill Team */ public const string SSHAgent = "ssh-agent"; @@ -174,6 +177,7 @@ public static class FeatureFlagKeys public const string MacOsNativeCredentialSync = "macos-native-credential-sync"; public const string WindowsDesktopAutotype = "windows-desktop-autotype"; public const string WindowsDesktopAutotypeGA = "windows-desktop-autotype-ga"; + public const string NotificationUndeterminedCipherScenarioLogic = "undetermined-cipher-scenario-logic"; /* Billing Team */ public const string TrialPayment = "PM-8163-trial-payment"; @@ -221,9 +225,10 @@ public static class FeatureFlagKeys /* Platform Team */ public const string WebPush = "web-push"; - public const string IpcChannelFramework = "ipc-channel-framework"; + public const string ContentScriptIpcFramework = "content-script-ipc-channel-framework"; public const string PushNotificationsWhenLocked = "pm-19388-push-notifications-when-locked"; public const string PushNotificationsWhenInactive = "pm-25130-receive-push-notifications-for-inactive-users"; + public const string WebAuthnRelatedOrigins = "pm-30529-webauthn-related-origins"; /* Tools Team */ /// @@ -248,6 +253,8 @@ public static class FeatureFlagKeys public const string BrowserPremiumSpotlight = "pm-23384-browser-premium-spotlight"; public const string MigrateMyVaultToMyItems = "pm-20558-migrate-myvault-to-myitems"; public const string PM27632_CipherCrudOperationsToSdk = "pm-27632-cipher-crud-operations-to-sdk"; + public const string PM30521_AutofillButtonViewLoginScreen = "pm-30521-autofill-button-view-login-screen"; + public const string PM29438_WelcomeDialogWithExtensionPrompt = "pm-29438-welcome-dialog-with-extension-prompt"; /* Innovation Team */ public const string ArchiveVaultItems = "pm-19148-innovation-archive"; @@ -255,6 +262,8 @@ public static class FeatureFlagKeys /* DIRT Team */ public const string EventManagementForDataDogAndCrowdStrike = "event-management-for-datadog-and-crowdstrike"; public const string EventDiagnosticLogging = "pm-27666-siem-event-log-debugging"; + public const string EventManagementForHuntress = "event-management-for-huntress"; + public const string Milestone11AppPageImprovements = "pm-30538-dirt-milestone-11-app-page-improvements"; /* UIF Team */ public const string RouterFocusManagement = "router-focus-management"; diff --git a/src/Core/Core.csproj b/src/Core/Core.csproj index a423d9377d..54a8a0483f 100644 --- a/src/Core/Core.csproj +++ b/src/Core/Core.csproj @@ -62,7 +62,7 @@ - + diff --git a/src/Core/Entities/User.cs b/src/Core/Entities/User.cs index 669e32bcbe..422dc37c6e 100644 --- a/src/Core/Entities/User.cs +++ b/src/Core/Entities/User.cs @@ -7,8 +7,6 @@ using Bit.Core.KeyManagement.Models.Data; using Bit.Core.Utilities; using Microsoft.AspNetCore.Identity; -#nullable enable - namespace Bit.Core.Entities; public class User : ITableObject, IStorableSubscriber, IRevisable, ITwoFactorProvidersUser @@ -51,7 +49,7 @@ public class User : ITableObject, IStorableSubscriber, IRevisable, ITwoFac public string? Key { get; set; } /// /// The raw public key, without a signature from the user's signature key. - /// + /// public string? PublicKey { get; set; } /// /// User key wrapped private key. @@ -107,6 +105,8 @@ public class User : ITableObject, IStorableSubscriber, IRevisable, ITwoFac public DateTime? LastKeyRotationDate { get; set; } public DateTime? LastEmailChangeDate { get; set; } public bool VerifyDevices { get; set; } = true; + // PM-28827 Uncomment below line. + // public string? MasterPasswordSalt { get; set; } public string GetMasterPasswordSalt() { diff --git a/src/Api/KeyManagement/Models/Requests/KdfRequestModel.cs b/src/Core/KeyManagement/Models/Api/Request/KdfRequestModel.cs similarity index 59% rename from src/Api/KeyManagement/Models/Requests/KdfRequestModel.cs rename to src/Core/KeyManagement/Models/Api/Request/KdfRequestModel.cs index 904304a633..edcd7f760f 100644 --- a/src/Api/KeyManagement/Models/Requests/KdfRequestModel.cs +++ b/src/Core/KeyManagement/Models/Api/Request/KdfRequestModel.cs @@ -1,10 +1,11 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.Enums; using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.Utilities; -namespace Bit.Api.KeyManagement.Models.Requests; +namespace Bit.Core.KeyManagement.Models.Api.Request; -public class KdfRequestModel +public class KdfRequestModel : IValidatableObject { [Required] public required KdfType KdfType { get; init; } @@ -23,4 +24,10 @@ public class KdfRequestModel Parallelism = Parallelism }; } + + public IEnumerable Validate(ValidationContext validationContext) + { + // Generic per-request KDF validation for any request model embedding KdfRequestModel + return KdfSettingsValidator.Validate(ToData()); + } } diff --git a/src/Api/KeyManagement/Models/Requests/MasterPasswordAuthenticationDataRequestModel.cs b/src/Core/KeyManagement/Models/Api/Request/MasterPasswordAuthenticationDataRequestModel.cs similarity index 71% rename from src/Api/KeyManagement/Models/Requests/MasterPasswordAuthenticationDataRequestModel.cs rename to src/Core/KeyManagement/Models/Api/Request/MasterPasswordAuthenticationDataRequestModel.cs index 4f70a1135f..04c22cc3a6 100644 --- a/src/Api/KeyManagement/Models/Requests/MasterPasswordAuthenticationDataRequestModel.cs +++ b/src/Core/KeyManagement/Models/Api/Request/MasterPasswordAuthenticationDataRequestModel.cs @@ -1,8 +1,12 @@ using System.ComponentModel.DataAnnotations; using Bit.Core.KeyManagement.Models.Data; -namespace Bit.Api.KeyManagement.Models.Requests; +namespace Bit.Core.KeyManagement.Models.Api.Request; +/// +/// Use this datatype when interfacing with requests to create a separation of concern. +/// See to use for commands, queries, services. +/// public class MasterPasswordAuthenticationDataRequestModel { public required KdfRequestModel Kdf { get; init; } diff --git a/src/Api/KeyManagement/Models/Requests/MasterPasswordUnlockDataRequestModel.cs b/src/Core/KeyManagement/Models/Api/Request/MasterPasswordUnlockDataRequestModel.cs similarity index 71% rename from src/Api/KeyManagement/Models/Requests/MasterPasswordUnlockDataRequestModel.cs rename to src/Core/KeyManagement/Models/Api/Request/MasterPasswordUnlockDataRequestModel.cs index e1d7863cae..8d7df86374 100644 --- a/src/Api/KeyManagement/Models/Requests/MasterPasswordUnlockDataRequestModel.cs +++ b/src/Core/KeyManagement/Models/Api/Request/MasterPasswordUnlockDataRequestModel.cs @@ -2,8 +2,12 @@ using Bit.Core.KeyManagement.Models.Data; using Bit.Core.Utilities; -namespace Bit.Api.KeyManagement.Models.Requests; +namespace Bit.Core.KeyManagement.Models.Api.Request; +/// +/// Use this datatype when interfacing with requests to create a separation of concern. +/// See to use for commands, queries, services. +/// public class MasterPasswordUnlockDataRequestModel { public required KdfRequestModel Kdf { get; init; } diff --git a/src/Core/KeyManagement/Models/Api/Response/UserDecryptionResponseModel.cs b/src/Core/KeyManagement/Models/Api/Response/UserDecryptionResponseModel.cs index 536347cea9..9656c8a68b 100644 --- a/src/Core/KeyManagement/Models/Api/Response/UserDecryptionResponseModel.cs +++ b/src/Core/KeyManagement/Models/Api/Response/UserDecryptionResponseModel.cs @@ -1,4 +1,7 @@ -namespace Bit.Core.KeyManagement.Models.Api.Response; +using System.Text.Json.Serialization; +using Bit.Core.Auth.Models.Api.Response; + +namespace Bit.Core.KeyManagement.Models.Api.Response; public class UserDecryptionResponseModel { @@ -6,4 +9,10 @@ public class UserDecryptionResponseModel /// Returns the unlock data when the user has a master password that can be used to decrypt their vault. /// public MasterPasswordUnlockResponseModel? MasterPasswordUnlock { get; set; } + + /// + /// Gets or sets the WebAuthn PRF decryption keys. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public WebAuthnPrfDecryptionOption[]? WebAuthnPrfOptions { get; set; } } diff --git a/src/Core/KeyManagement/Models/Data/MasterPasswordAuthenticationData.cs b/src/Core/KeyManagement/Models/Data/MasterPasswordAuthenticationData.cs index 1bc7006cef..6e53dfa744 100644 --- a/src/Core/KeyManagement/Models/Data/MasterPasswordAuthenticationData.cs +++ b/src/Core/KeyManagement/Models/Data/MasterPasswordAuthenticationData.cs @@ -1,8 +1,13 @@ using Bit.Core.Entities; using Bit.Core.Exceptions; +using Bit.Core.KeyManagement.Models.Api.Request; namespace Bit.Core.KeyManagement.Models.Data; +/// +/// Use this datatype when interfacing with commands, queries, services to create a separation of concern. +/// See to use for requests. +/// public class MasterPasswordAuthenticationData { public required KdfSettings Kdf { get; init; } diff --git a/src/Core/KeyManagement/Models/Data/MasterPasswordUnlockAndAuthenticationData.cs b/src/Core/KeyManagement/Models/Data/MasterPasswordUnlockAndAuthenticationData.cs index ad3a0b692b..b79ce8bce1 100644 --- a/src/Core/KeyManagement/Models/Data/MasterPasswordUnlockAndAuthenticationData.cs +++ b/src/Core/KeyManagement/Models/Data/MasterPasswordUnlockAndAuthenticationData.cs @@ -1,5 +1,4 @@ -#nullable enable -using Bit.Core.Entities; +using Bit.Core.Entities; using Bit.Core.Enums; namespace Bit.Core.KeyManagement.Models.Data; diff --git a/src/Core/KeyManagement/Models/Data/MasterPasswordUnlockData.cs b/src/Core/KeyManagement/Models/Data/MasterPasswordUnlockData.cs index cb18ed2a78..f8139cba99 100644 --- a/src/Core/KeyManagement/Models/Data/MasterPasswordUnlockData.cs +++ b/src/Core/KeyManagement/Models/Data/MasterPasswordUnlockData.cs @@ -1,8 +1,13 @@ using Bit.Core.Entities; using Bit.Core.Exceptions; +using Bit.Core.KeyManagement.Models.Api.Request; namespace Bit.Core.KeyManagement.Models.Data; +/// +/// Use this datatype when interfacing with commands, queries, services to create a separation of concern. +/// See to use for requests. +/// public class MasterPasswordUnlockData { public required KdfSettings Kdf { get; init; } diff --git a/src/Core/MailTemplates/Handlebars/AdminConsole/DomainClaimedByOrganization.html.hbs b/src/Core/MailTemplates/Handlebars/AdminConsole/DomainClaimedByOrganization.html.hbs index f10c47c78f..18a4f93ac5 100644 --- a/src/Core/MailTemplates/Handlebars/AdminConsole/DomainClaimedByOrganization.html.hbs +++ b/src/Core/MailTemplates/Handlebars/AdminConsole/DomainClaimedByOrganization.html.hbs @@ -1,18 +1,23 @@ {{#>TitleContactUsHtmlLayout}} - + + + -
- Here's what that means: -
    -
  • Your Bitwarden account is owned by {{OrganizationName}}
  • -
  • Your administrators can delete your account at any time
  • -
  • You cannot leave the organization
  • +
+

An {{OrganizationName}} admin has claimed the domain @{{{DomainName}}}. Your email address {{{UserEmail}}} matches this, so your Bitwarden account is now managed by {{OrganizationName}}.

+
+

What this means for you

+
    +
  • Your day-to-day use of Bitwarden remains the same.
  • +
  • Only store work-related items in your {{OrganizationName}} vault.
  • +
  • {{OrganizationName}} admins now manage your account, meaning they can revoke or delete your account.
- For more information, please refer to the following help article: Claimed Accounts + +

For more information, please refer to the following help article: Claimed accounts

diff --git a/src/Core/MailTemplates/Handlebars/AdminConsole/DomainClaimedByOrganization.text.hbs b/src/Core/MailTemplates/Handlebars/AdminConsole/DomainClaimedByOrganization.text.hbs index b3041a21e9..4a87d30c34 100644 --- a/src/Core/MailTemplates/Handlebars/AdminConsole/DomainClaimedByOrganization.text.hbs +++ b/src/Core/MailTemplates/Handlebars/AdminConsole/DomainClaimedByOrganization.text.hbs @@ -1,7 +1,8 @@ -As a member of {{OrganizationName}}, your Bitwarden account is claimed and owned by your organization. +An {{OrganizationName}} admin has claimed the domain @{{{DomainName}}}. Your email address {{{UserEmail}}} matches this, so your Bitwarden account is now managed by {{OrganizationName}}. -Here's what that means: -- Your administrators can delete your account at any time -- You cannot leave the organization +What this means for you: +- Your day-to-day use of Bitwarden remains the same. +- Only store work-related items in your {{OrganizationName}} vault. +- {{OrganizationName}} admins now manage your account, meaning they can revoke or delete your account. -For more information, please refer to the following help article: Claimed Accounts (https://bitwarden.com/help/claimed-accounts) +For more information, please refer to the following help article: Claimed accounts (https://bitwarden.com/help/claimed-accounts) diff --git a/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmail.html.hbs b/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmail.html.hbs index 5bf1f24218..ec18f04af3 100644 --- a/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmail.html.hbs +++ b/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmail.html.hbs @@ -1,28 +1,691 @@ -{{#>FullHtmlLayout}} - - - - - - - - - - - - - -
- Verify your email to access this Bitwarden Send. -
-
- Your verification code is: {{Token}} -
-
- This code can only be used once and expires in 5 minutes. After that you'll need to verify your email again. -
-
-
- {{TheDate}} at {{TheTime}} {{TimeZone}} -
-{{/FullHtmlLayout}} \ No newline at end of file + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + + +
+ + + + + + + +
+ + + + + + + + +
+ + + + + +
+ + + + + + + +
+ + +
+ + + + + + + + + + + + + +
+ + + + + + + +
+ + + +
+ +
+ +

+ Verify your email to access this Bitwarden Send +

+ +
+ +
+ + + +
+ + + + + + + + + +
+ + + + + + + +
+ + + +
+ +
+ +
+ + +
+ +
+ + + + + +
+ + +
+ +
+ + + + + + + + + +
+ + + + + + + +
+ + + +
+ + + + + + + +
+ + +
+ + + + + + + +
+ + + + + + + + + + + + + + + + + + + + + +
+ +
Your verification code is:
+ +
+ +
{{Token}}
+ +
+ +
+ +
+ +
This code expires in {{Expiry}} minutes. After that, you'll need + to verify your email again.
+ +
+ +
+ +
+ + +
+ +
+ + + + + +
+ + + + + + + +
+ + +
+ + + + + + + +
+ + + + + + + + + +
+ +

+ Bitwarden Send transmits sensitive, temporary information to + others easily and securely. Learn more about + Bitwarden Send + or + sign up + to try it today. +

+ +
+ +
+ +
+ + +
+ +
+ + + +
+ +
+ + + + + + + + + +
+ + + + + + + +
+ + + +
+ + + + + + + +
+ + +
+ + + + + + + + + +
+ +

+ Learn more about Bitwarden +

+ Find user guides, product documentation, and videos on the + Bitwarden Help Center.
+ +
+ +
+ + + +
+ + + + + + + + + +
+ +
+ + +
+ +
+ + + +
+ +
+ + + + + + + + + +
+ + + + + + + +
+ + +
+ + + + + + + + + + + + + +
+ + + + + + + + + + + + +
+ + + + + + +
+ + + +
+
+ + + + + + + + + + +
+ + + + + + +
+ + + +
+
+ + + + + + + + + + +
+ + + + + + +
+ + + +
+
+ + + + + + + + + + +
+ + + + + + +
+ + + +
+
+ + + + + + + + + + +
+ + + + + + +
+ + + +
+
+ + + + + + + + + + +
+ + + + + + +
+ + + +
+
+ + + + + + + + + + +
+ + + + + + +
+ + + +
+
+ + + +
+ +

+ © {{ CurrentYear }} Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + Barbara, CA, USA +

+

+ Always confirm you are on a trusted Bitwarden domain before logging + in:
+ bitwarden.com | + Learn why we include this +

+ +
+ +
+ + +
+ +
+ + + + + +
+ + + + \ No newline at end of file diff --git a/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmail.text.hbs b/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmail.text.hbs index f83008c30b..7c9c1db527 100644 --- a/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmail.text.hbs +++ b/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmail.text.hbs @@ -3,7 +3,7 @@ Verify your email to access this Bitwarden Send. Your verification code is: {{Token}} -This code can only be used once and expires in 5 minutes. After that you'll need to verify your email again. +This code can only be used once and expires in {{Expiry}} minutes. After that you'll need to verify your email again. -Date : {{TheDate}} at {{TheTime}} {{TimeZone}} -{{/BasicTextLayout}} \ No newline at end of file +Bitwarden Send transmits sensitive, temporary information to others easily and securely. Learn more about Bitwarden Send or sign up to try it today. +{{/BasicTextLayout}} diff --git a/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmailv2.text.hbs b/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmailv2.text.hbs deleted file mode 100644 index 7c9c1db527..0000000000 --- a/src/Core/MailTemplates/Handlebars/Auth/SendAccessEmailOtpEmailv2.text.hbs +++ /dev/null @@ -1,9 +0,0 @@ -{{#>BasicTextLayout}} -Verify your email to access this Bitwarden Send. - -Your verification code is: {{Token}} - -This code can only be used once and expires in {{Expiry}} minutes. After that you'll need to verify your email again. - -Bitwarden Send transmits sensitive, temporary information to others easily and securely. Learn more about Bitwarden Send or sign up to try it today. -{{/BasicTextLayout}} diff --git a/src/Core/MailTemplates/Handlebars/Layouts/TitleContactUs.html.hbs b/src/Core/MailTemplates/Handlebars/Layouts/TitleContactUs.html.hbs index ed0d7cd9af..1fa6c014f0 100644 --- a/src/Core/MailTemplates/Handlebars/Layouts/TitleContactUs.html.hbs +++ b/src/Core/MailTemplates/Handlebars/Layouts/TitleContactUs.html.hbs @@ -1,11 +1,11 @@ -{{#>FullUpdatedHtmlLayout}} +{{#>FullUpdatedHtmlLayout}} @@ -271,12 +271,12 @@
- {{TitleFirst}}{{TitleSecondBold}}{{TitleThird}} + {{{TitleFirst}}}{{TitleSecondBold}}{{TitleThird}}
diff --git a/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-family-user.html.hbs b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-family-user.html.hbs index b2b957f849..4daebc6bbd 100644 --- a/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-family-user.html.hbs +++ b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-family-user.html.hbs @@ -867,7 +867,7 @@

- © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + © {{ CurrentYear }} Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa Barbara, CA, USA

diff --git a/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-individual-user.html.hbs b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-individual-user.html.hbs index 4cdf153c30..8a2cc4d9c7 100644 --- a/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-individual-user.html.hbs +++ b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-individual-user.html.hbs @@ -866,7 +866,7 @@

- © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + © {{ CurrentYear }} Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa Barbara, CA, USA

diff --git a/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-org-user.html.hbs b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-org-user.html.hbs index 5a8dfb7374..8d65da4188 100644 --- a/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-org-user.html.hbs +++ b/src/Core/MailTemplates/Handlebars/MJML/Auth/Onboarding/welcome-org-user.html.hbs @@ -867,7 +867,7 @@

- © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + © {{ CurrentYear }} Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa Barbara, CA, USA

diff --git a/src/Core/MailTemplates/Mjml/components/footer.mjml b/src/Core/MailTemplates/Mjml/components/footer.mjml index ddaf3f493b..94d93f4fb2 100644 --- a/src/Core/MailTemplates/Mjml/components/footer.mjml +++ b/src/Core/MailTemplates/Mjml/components/footer.mjml @@ -39,7 +39,7 @@

- © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + © {{ CurrentYear }} Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa Barbara, CA, USA

diff --git a/src/Core/MailTemplates/Mjml/emails/Auth/UserFeatures/EmergencyAccess/emergency-access-remove-grantees.mjml b/src/Core/MailTemplates/Mjml/emails/Auth/UserFeatures/EmergencyAccess/emergency-access-remove-grantees.mjml new file mode 100644 index 0000000000..0dc9f93e45 --- /dev/null +++ b/src/Core/MailTemplates/Mjml/emails/Auth/UserFeatures/EmergencyAccess/emergency-access-remove-grantees.mjml @@ -0,0 +1,31 @@ + + + + + + + + + + + + + + + + The following emergency contacts have been removed from your account: +

    + {{#each RemovedGranteeEmails}} +
  • {{this}}
  • + {{/each}} +
+ Learn more about emergency access. + + + + + + + + + diff --git a/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2019-renewal.mjml b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2019-renewal.mjml index 092ae303de..06f60e7724 100644 --- a/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2019-renewal.mjml +++ b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2019-renewal.mjml @@ -18,7 +18,7 @@ at {{BaseAnnualRenewalPrice}} + tax. - As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. + As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this year's renewal. This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax. diff --git a/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/premium-renewal.mjml b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/premium-renewal.mjml index a460442a7c..defec91f0e 100644 --- a/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/premium-renewal.mjml +++ b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/premium-renewal.mjml @@ -17,8 +17,8 @@ Your Bitwarden Premium subscription renews in 15 days. The price is updating to {{BaseMonthlyRenewalPrice}}/month, billed annually. - As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. - This renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually. + As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this year's renewal. + This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax. Questions? Contact diff --git a/src/Core/Models/Data/Organizations/ClaimedUserDomainClaimedEmails.cs b/src/Core/Models/Data/Organizations/ClaimedUserDomainClaimedEmails.cs index 2b73fc1525..f44274c6a5 100644 --- a/src/Core/Models/Data/Organizations/ClaimedUserDomainClaimedEmails.cs +++ b/src/Core/Models/Data/Organizations/ClaimedUserDomainClaimedEmails.cs @@ -2,4 +2,4 @@ namespace Bit.Core.Models.Data.Organizations; -public record ClaimedUserDomainClaimedEmails(IEnumerable EmailList, Organization Organization); +public record ClaimedUserDomainClaimedEmails(IEnumerable EmailList, Organization Organization, string DomainName); diff --git a/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.html.hbs b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.html.hbs index 227613999b..e9d7d406a0 100644 --- a/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.html.hbs +++ b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.html.hbs @@ -202,7 +202,7 @@
-
As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. +
As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this year's renewal. This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax.
- + -
+
- +
@@ -364,8 +364,8 @@ - -
- + +
@@ -381,13 +381,13 @@
- +
+ - @@ -404,13 +404,13 @@ -
+ - +
- +
+ - @@ -427,13 +427,13 @@ -
+ - +
- +
+ - @@ -450,13 +450,13 @@ -
+ - +
- +
+ - @@ -473,13 +473,13 @@ -
+ - +
- +
+ - @@ -496,13 +496,13 @@ -
+ - +
- +
+ - @@ -519,13 +519,13 @@ -
+ - +
- +
+ - @@ -546,15 +546,15 @@ diff --git a/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.text.hbs b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.text.hbs index 88d64f9acf..9f40c88329 100644 --- a/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.text.hbs +++ b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.text.hbs @@ -1,7 +1,7 @@ Your Bitwarden Families subscription renews in 15 days. The price is updating to {{BaseMonthlyRenewalPrice}}/month, billed annually at {{BaseAnnualRenewalPrice}} + tax. -As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. +As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this year's renewal. This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax. Questions? Contact support@bitwarden.com diff --git a/src/Core/Models/Mail/Billing/Renewal/Families2020Renewal/Families2020RenewalMailView.html.hbs b/src/Core/Models/Mail/Billing/Renewal/Families2020Renewal/Families2020RenewalMailView.html.hbs index ac6b80993c..d1e1dcec31 100644 --- a/src/Core/Models/Mail/Billing/Renewal/Families2020Renewal/Families2020RenewalMailView.html.hbs +++ b/src/Core/Models/Mail/Billing/Renewal/Families2020Renewal/Families2020RenewalMailView.html.hbs @@ -583,7 +583,7 @@ @@ -270,12 +270,12 @@
+ - +
-

- © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa +

+ © {{ CurrentYear }} Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa Barbara, CA, USA

Always confirm you are on a trusted Bitwarden domain before logging in:
- bitwarden.com | - Learn why we include this + bitwarden.com | + Learn why we include this

- © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa + © {{ CurrentYear }} Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa Barbara, CA, USA

diff --git a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.cs b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.cs index 4006c92a63..0798c7dbc8 100644 --- a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.cs +++ b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.cs @@ -5,7 +5,7 @@ namespace Bit.Core.Models.Mail.Billing.Renewal.Premium; public class PremiumRenewalMailView : BaseMailView { public required string BaseMonthlyRenewalPrice { get; set; } - public required string DiscountedMonthlyRenewalPrice { get; set; } + public required string DiscountedAnnualRenewalPrice { get; set; } public required string DiscountAmount { get; set; } } diff --git a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.html.hbs b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.html.hbs index a6b2fda0f7..182a24cde3 100644 --- a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.html.hbs +++ b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.html.hbs @@ -201,8 +201,8 @@

-
As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. - This renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually.
+
As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this year's renewal. + This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax.
- + -
+
- +
@@ -363,8 +363,8 @@ - -
- + +
@@ -380,13 +380,13 @@
- +
+ - @@ -403,13 +403,13 @@ -
+ - +
- +
+ - @@ -426,13 +426,13 @@ -
+ - +
- +
+ - @@ -449,13 +449,13 @@ -
+ - +
- +
+ - @@ -472,13 +472,13 @@ -
+ - +
- +
+ - @@ -495,13 +495,13 @@ -
+ - +
- +
+ - @@ -518,13 +518,13 @@ -
+ - +
- +
+ - @@ -545,15 +545,15 @@ diff --git a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.text.hbs b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.text.hbs index 41300d0f96..4b79826f71 100644 --- a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.text.hbs +++ b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.text.hbs @@ -1,6 +1,6 @@ Your Bitwarden Premium subscription renews in 15 days. The price is updating to {{BaseMonthlyRenewalPrice}}/month, billed annually. -As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. -This renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually. +As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this year's renewal. +This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax. Questions? Contact support@bitwarden.com diff --git a/src/Core/Models/Mail/ClaimedDomainUserNotificationViewModel.cs b/src/Core/Models/Mail/ClaimedDomainUserNotificationViewModel.cs index fa1ed5ab45..deb5571e96 100644 --- a/src/Core/Models/Mail/ClaimedDomainUserNotificationViewModel.cs +++ b/src/Core/Models/Mail/ClaimedDomainUserNotificationViewModel.cs @@ -6,4 +6,7 @@ namespace Bit.Core.Models.Mail; public class ClaimedDomainUserNotificationViewModel : BaseTitleContactUsMailModel { public string OrganizationName { get; init; } + public string DomainName { get; init; } + public string EmailDomain { get; init; } + public string UserEmail { get; init; } } diff --git a/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpgradeOrganizationPlanCommand.cs b/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpgradeOrganizationPlanCommand.cs index 4ad63bd8d7..9c06ce1709 100644 --- a/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpgradeOrganizationPlanCommand.cs +++ b/src/Core/OrganizationFeatures/OrganizationSubscriptions/UpgradeOrganizationPlanCommand.cs @@ -5,6 +5,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.OrganizationConnectionConfigs; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Repositories; @@ -30,6 +31,7 @@ public class UpgradeOrganizationPlanCommand : IUpgradeOrganizationPlanCommand private readonly IGroupRepository _groupRepository; private readonly IStripePaymentService _paymentService; private readonly IPolicyRepository _policyRepository; + private readonly IPolicyQuery _policyQuery; private readonly ISsoConfigRepository _ssoConfigRepository; private readonly IOrganizationConnectionRepository _organizationConnectionRepository; private readonly IServiceAccountRepository _serviceAccountRepository; @@ -45,6 +47,7 @@ public class UpgradeOrganizationPlanCommand : IUpgradeOrganizationPlanCommand IGroupRepository groupRepository, IStripePaymentService paymentService, IPolicyRepository policyRepository, + IPolicyQuery policyQuery, ISsoConfigRepository ssoConfigRepository, IOrganizationConnectionRepository organizationConnectionRepository, IServiceAccountRepository serviceAccountRepository, @@ -59,6 +62,7 @@ public class UpgradeOrganizationPlanCommand : IUpgradeOrganizationPlanCommand _groupRepository = groupRepository; _paymentService = paymentService; _policyRepository = policyRepository; + _policyQuery = policyQuery; _ssoConfigRepository = ssoConfigRepository; _organizationConnectionRepository = organizationConnectionRepository; _serviceAccountRepository = serviceAccountRepository; @@ -184,9 +188,8 @@ public class UpgradeOrganizationPlanCommand : IUpgradeOrganizationPlanCommand if (!newPlan.HasResetPassword && organization.UseResetPassword) { - var resetPasswordPolicy = - await _policyRepository.GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword); - if (resetPasswordPolicy != null && resetPasswordPolicy.Enabled) + var resetPasswordPolicy = await _policyQuery.RunAsync(organization.Id, PolicyType.ResetPassword); + if (resetPasswordPolicy.Enabled) { throw new BadRequestException("Your new plan does not allow the Password Reset feature. " + "Disable your Password Reset policy."); diff --git a/src/Core/Platform/Mail/HandlebarsMailService.cs b/src/Core/Platform/Mail/HandlebarsMailService.cs index d57ca400fd..298e335c9f 100644 --- a/src/Core/Platform/Mail/HandlebarsMailService.cs +++ b/src/Core/Platform/Mail/HandlebarsMailService.cs @@ -1,6 +1,4 @@ -#nullable enable - -using System.Diagnostics; +using System.Diagnostics; using System.Net; using System.Reflection; using System.Text.Json; @@ -13,6 +11,7 @@ using Bit.Core.Auth.Models.Mail; using Bit.Core.Billing.Enums; using Bit.Core.Billing.Models.Mail; using Bit.Core.Entities; +using Bit.Core.Exceptions; using Bit.Core.Models.Data.Organizations; using Bit.Core.Models.Mail; using Bit.Core.Models.Mail.Auth; @@ -209,26 +208,6 @@ public class HandlebarsMailService : IMailService } public async Task SendSendEmailOtpEmailAsync(string email, string token, string subject) - { - var message = CreateDefaultMessage(subject, email); - var requestDateTime = DateTime.UtcNow; - var model = new DefaultEmailOtpViewModel - { - Token = token, - TheDate = requestDateTime.ToLongDateString(), - TheTime = requestDateTime.ToShortTimeString(), - TimeZone = _utcTimeZoneDisplay, - WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, - SiteName = _globalSettings.SiteName, - }; - await AddMessageContentAsync(message, "Auth.SendAccessEmailOtpEmail", model); - message.MetaData.Add("SendGridBypassListManagement", true); - // TODO - PM-25380 change to string constant - message.Category = "SendEmailOtp"; - await _mailDeliveryService.SendEmailAsync(message); - } - - public async Task SendSendEmailOtpEmailv2Async(string email, string token, string subject) { var message = CreateDefaultMessage(subject, email); var requestDateTime = DateTime.UtcNow; @@ -242,7 +221,7 @@ public class HandlebarsMailService : IMailService WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, SiteName = _globalSettings.SiteName, }; - await AddMessageContentAsync(message, "Auth.SendAccessEmailOtpEmailv2", model); + await AddMessageContentAsync(message, "Auth.SendAccessEmailOtpEmail", model); message.MetaData.Add("SendGridBypassListManagement", true); // TODO - PM-25380 change to string constant message.Category = "SendEmailOtp"; @@ -653,16 +632,19 @@ public class HandlebarsMailService : IMailService public async Task SendClaimedDomainUserEmailAsync(ClaimedUserDomainClaimedEmails emailList) { await EnqueueMailAsync(emailList.EmailList.Select(email => - CreateMessage(email, emailList.Organization))); + CreateMessage(email, emailList.Organization, emailList.DomainName))); return; - MailQueueMessage CreateMessage(string emailAddress, Organization org) => - new(CreateDefaultMessage($"Your Bitwarden account is claimed by {org.DisplayName()}", emailAddress), + MailQueueMessage CreateMessage(string emailAddress, Organization org, string domainName) => + new(CreateDefaultMessage($"Important update to your Bitwarden account", emailAddress), "AdminConsole.DomainClaimedByOrganization", new ClaimedDomainUserNotificationViewModel { - TitleFirst = $"Your Bitwarden account is claimed by {org.DisplayName()}", - OrganizationName = CoreHelpers.SanitizeForEmail(org.DisplayName(), false) + TitleFirst = $"Important update to your
Bitwarden account", + OrganizationName = CoreHelpers.SanitizeForEmail(org.DisplayName(), false), + DomainName = domainName, + EmailDomain = emailAddress.Split('@').LastOrDefault() ?? "", + UserEmail = emailAddress }); } @@ -1040,6 +1022,11 @@ public class HandlebarsMailService : IMailService public async Task SendEmergencyAccessInviteEmailAsync(EmergencyAccess emergencyAccess, string name, string token) { + if (string.IsNullOrEmpty(emergencyAccess.Email)) + { + throw new BadRequestException("Emergency Access not valid."); + } + var message = CreateDefaultMessage($"Emergency Access Contact Invite", emergencyAccess.Email); var model = new EmergencyAccessInvitedViewModel { diff --git a/src/Core/Platform/Mail/IMailService.cs b/src/Core/Platform/Mail/IMailService.cs index e21e1a010f..e07e4bad29 100644 --- a/src/Core/Platform/Mail/IMailService.cs +++ b/src/Core/Platform/Mail/IMailService.cs @@ -51,17 +51,15 @@ public interface IMailService Task SendChangeEmailAlreadyExistsEmailAsync(string fromEmail, string toEmail); Task SendChangeEmailEmailAsync(string newEmailAddress, string token); Task SendTwoFactorEmailAsync(string email, string accountEmail, string token, string deviceIp, string deviceType, TwoFactorEmailPurpose purpose); - Task SendSendEmailOtpEmailAsync(string email, string token, string subject); /// - /// has a default expiry of 5 minutes so we set the expiry to that value int he view model. + /// has a default expiry of 5 minutes so we set the expiry to that value in the view model. /// Sends OTP code token to the specified email address. - /// will replace when MJML templates are fully accepted. /// /// Email address to send the OTP to /// Otp code token - /// subject line of the email + /// Subject line of the email /// Task - Task SendSendEmailOtpEmailv2Async(string email, string token, string subject); + Task SendSendEmailOtpEmailAsync(string email, string token, string subject); Task SendFailedTwoFactorAttemptEmailAsync(string email, TwoFactorProviderType type, DateTime utcNow, string ip); Task SendNoMasterPasswordHintEmailAsync(string email); Task SendMasterPasswordHintEmailAsync(string email, string hint); diff --git a/src/Core/Platform/Mail/NoopMailService.cs b/src/Core/Platform/Mail/NoopMailService.cs index 7de48e4619..0064058afb 100644 --- a/src/Core/Platform/Mail/NoopMailService.cs +++ b/src/Core/Platform/Mail/NoopMailService.cs @@ -99,11 +99,6 @@ public class NoopMailService : IMailService return Task.FromResult(0); } - public Task SendSendEmailOtpEmailv2Async(string email, string token, string subject) - { - return Task.FromResult(0); - } - public Task SendFailedTwoFactorAttemptEmailAsync(string email, TwoFactorProviderType failedType, DateTime utcNow, string ip) { return Task.FromResult(0); diff --git a/src/Core/Repositories/ICollectionRepository.cs b/src/Core/Repositories/ICollectionRepository.cs index f86147ca7d..2db809e3de 100644 --- a/src/Core/Repositories/ICollectionRepository.cs +++ b/src/Core/Repositories/ICollectionRepository.cs @@ -45,7 +45,7 @@ public interface ICollectionRepository : IRepository /// Optionally, you can include access relationships for other Groups/Users and the collections. /// Excludes default collections (My Items collections) - used by Admin Console Collections tab. /// - Task> GetManyByOrganizationIdWithPermissionsAsync(Guid organizationId, Guid userId, bool includeAccessRelationships); + Task> GetManySharedByOrganizationIdWithPermissionsAsync(Guid organizationId, Guid userId, bool includeAccessRelationships); /// /// Returns the collection by Id, including permission info for the specified user. @@ -64,11 +64,22 @@ public interface ICollectionRepository : IRepository IEnumerable users, IEnumerable groups); /// - /// Creates default user collections for the specified organization users if they do not already have one. + /// Creates default user collections for the specified organization users. + /// Filters internally to only create collections for users who don't already have one. /// /// The Organization ID. /// The Organization User IDs to create default collections for. /// The encrypted string to use as the default collection name. - /// - Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName); + Task CreateDefaultCollectionsAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName); + + /// + /// Creates default user collections for the specified organization users using bulk insert operations. + /// Use this if you need to create collections for > ~1k users. + /// Filters internally to only create collections for users who don't already have one. + /// + /// The Organization ID. + /// The Organization User IDs to create default collections for. + /// The encrypted string to use as the default collection name. + Task CreateDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName); + } diff --git a/src/Core/Services/IBraintreeService.cs b/src/Core/Services/IBraintreeService.cs index 166d285908..d4f5809f41 100644 --- a/src/Core/Services/IBraintreeService.cs +++ b/src/Core/Services/IBraintreeService.cs @@ -1,11 +1,14 @@ using Bit.Core.Billing.Subscriptions.Models; -using Stripe; +using Braintree; namespace Bit.Core.Services; public interface IBraintreeService { + Task GetCustomer( + Stripe.Customer customer); + Task PayInvoice( SubscriberId subscriberId, - Invoice invoice); + Stripe.Invoice invoice); } diff --git a/src/Core/Services/Implementations/BraintreeService.cs b/src/Core/Services/Implementations/BraintreeService.cs index e3630ed888..6ff3f5ce59 100644 --- a/src/Core/Services/Implementations/BraintreeService.cs +++ b/src/Core/Services/Implementations/BraintreeService.cs @@ -1,11 +1,10 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Services; using Bit.Core.Billing.Subscriptions.Models; -using Bit.Core.Exceptions; using Bit.Core.Settings; using Braintree; +using Braintree.Exceptions; using Microsoft.Extensions.Logging; -using Stripe; namespace Bit.Core.Services.Implementations; @@ -18,11 +17,34 @@ public class BraintreeService( IMailService mailService, IStripeAdapter stripeAdapter) : IBraintreeService { - private readonly ConflictException _problemPayingInvoice = new("There was a problem paying for your invoice. Please contact customer support."); + private readonly Exceptions.ConflictException _problemPayingInvoice = new("There was a problem paying for your invoice. Please contact customer support."); + + public async Task GetCustomer( + Stripe.Customer customer) + { + if (!customer.Metadata.TryGetValue(MetadataKeys.BraintreeCustomerId, out var braintreeCustomerId)) + { + return null; + } + + try + { + return await braintreeGateway.Customer.FindAsync(braintreeCustomerId); + } + catch (NotFoundException) + { + logger.LogWarning( + "Stripe customer ({CustomerId}) is linked to a Braintree Customer ({BraintreeCustomerId}) that does not exist.", + customer.Id, + braintreeCustomerId); + + return null; + } + } public async Task PayInvoice( SubscriberId subscriberId, - Invoice invoice) + Stripe.Invoice invoice) { if (invoice.Customer == null) { @@ -93,7 +115,7 @@ public class BraintreeService( return; } - await stripeAdapter.UpdateInvoiceAsync(invoice.Id, new InvoiceUpdateOptions + await stripeAdapter.UpdateInvoiceAsync(invoice.Id, new Stripe.InvoiceUpdateOptions { Metadata = new Dictionary { @@ -102,6 +124,6 @@ public class BraintreeService( } }); - await stripeAdapter.PayInvoiceAsync(invoice.Id, new InvoicePayOptions { PaidOutOfBand = true }); + await stripeAdapter.PayInvoiceAsync(invoice.Id, new Stripe.InvoicePayOptions { PaidOutOfBand = true }); } } diff --git a/src/Core/Services/Implementations/UserService.cs b/src/Core/Services/Implementations/UserService.cs index 64caf1d462..5f87ee85d2 100644 --- a/src/Core/Services/Implementations/UserService.cs +++ b/src/Core/Services/Implementations/UserService.cs @@ -61,7 +61,7 @@ public class UserService : UserManager, IUserService private readonly IEventService _eventService; private readonly IApplicationCacheService _applicationCacheService; private readonly IStripePaymentService _paymentService; - private readonly IPolicyRepository _policyRepository; + private readonly IPolicyQuery _policyQuery; private readonly IPolicyService _policyService; private readonly IFido2 _fido2; private readonly ICurrentContext _currentContext; @@ -98,7 +98,7 @@ public class UserService : UserManager, IUserService IEventService eventService, IApplicationCacheService applicationCacheService, IStripePaymentService paymentService, - IPolicyRepository policyRepository, + IPolicyQuery policyQuery, IPolicyService policyService, IFido2 fido2, ICurrentContext currentContext, @@ -139,7 +139,7 @@ public class UserService : UserManager, IUserService _eventService = eventService; _applicationCacheService = applicationCacheService; _paymentService = paymentService; - _policyRepository = policyRepository; + _policyQuery = policyQuery; _policyService = policyService; _fido2 = fido2; _currentContext = currentContext; @@ -722,9 +722,8 @@ public class UserService : UserManager, IUserService } // Enterprise policy must be enabled - var resetPasswordPolicy = - await _policyRepository.GetByOrganizationIdTypeAsync(orgId, PolicyType.ResetPassword); - if (resetPasswordPolicy == null || !resetPasswordPolicy.Enabled) + var resetPasswordPolicy = await _policyQuery.RunAsync(orgId, PolicyType.ResetPassword); + if (!resetPasswordPolicy.Enabled) { throw new BadRequestException("Organization does not have the password reset policy enabled."); } diff --git a/src/Core/Settings/GlobalSettings.cs b/src/Core/Settings/GlobalSettings.cs index 1f4fa6104b..6ccbd1ee85 100644 --- a/src/Core/Settings/GlobalSettings.cs +++ b/src/Core/Settings/GlobalSettings.cs @@ -83,7 +83,6 @@ public class GlobalSettings : IGlobalSettings public virtual ILaunchDarklySettings LaunchDarkly { get; set; } = new LaunchDarklySettings(); public virtual string DevelopmentDirectory { get; set; } public virtual IWebPushSettings WebPush { get; set; } = new WebPushSettings(); - public virtual int SendAccessTokenLifetimeInMinutes { get; set; } = 5; public virtual bool EnableEmailVerification { get; set; } public virtual string KdfDefaultHashKey { get; set; } @@ -93,6 +92,7 @@ public class GlobalSettings : IGlobalSettings public virtual string SendDefaultHashKey { get; set; } public virtual string PricingUri { get; set; } public virtual Fido2Settings Fido2 { get; set; } = new Fido2Settings(); + public virtual ICommunicationSettings Communication { get; set; } = new CommunicationSettings(); public string BuildExternalUri(string explicitValue, string name) { @@ -776,4 +776,17 @@ public class GlobalSettings : IGlobalSettings { public HashSet Origins { get; set; } } + + public class CommunicationSettings : ICommunicationSettings + { + public string Bootstrap { get; set; } = "none"; + public ISsoCookieVendorSettings SsoCookieVendor { get; set; } = new SsoCookieVendorSettings(); + } + + public class SsoCookieVendorSettings : ISsoCookieVendorSettings + { + public string IdpLoginUrl { get; set; } + public string CookieName { get; set; } + public string CookieDomain { get; set; } + } } diff --git a/src/Core/Settings/ICommunicationSettings.cs b/src/Core/Settings/ICommunicationSettings.cs new file mode 100644 index 0000000000..26259a8448 --- /dev/null +++ b/src/Core/Settings/ICommunicationSettings.cs @@ -0,0 +1,7 @@ +namespace Bit.Core.Settings; + +public interface ICommunicationSettings +{ + string Bootstrap { get; set; } + ISsoCookieVendorSettings SsoCookieVendor { get; set; } +} diff --git a/src/Core/Settings/IGlobalSettings.cs b/src/Core/Settings/IGlobalSettings.cs index c316836d09..7f5323fac0 100644 --- a/src/Core/Settings/IGlobalSettings.cs +++ b/src/Core/Settings/IGlobalSettings.cs @@ -29,4 +29,5 @@ public interface IGlobalSettings IWebPushSettings WebPush { get; set; } GlobalSettings.EventLoggingSettings EventLogging { get; set; } GlobalSettings.WebAuthnSettings WebAuthn { get; set; } + ICommunicationSettings Communication { get; set; } } diff --git a/src/Core/Settings/ISsoCookieVendorSettings.cs b/src/Core/Settings/ISsoCookieVendorSettings.cs new file mode 100644 index 0000000000..a9f2169b13 --- /dev/null +++ b/src/Core/Settings/ISsoCookieVendorSettings.cs @@ -0,0 +1,8 @@ +namespace Bit.Core.Settings; + +public interface ISsoCookieVendorSettings +{ + string IdpLoginUrl { get; set; } + string CookieName { get; set; } + string CookieDomain { get; set; } +} diff --git a/src/Core/Tools/ImportFeatures/ImportCiphersCommand.cs b/src/Core/Tools/ImportFeatures/ImportCiphersCommand.cs index fa558f5963..3f856e96fc 100644 --- a/src/Core/Tools/ImportFeatures/ImportCiphersCommand.cs +++ b/src/Core/Tools/ImportFeatures/ImportCiphersCommand.cs @@ -74,7 +74,13 @@ public class ImportCiphersCommand : IImportCiphersCommand if (cipher.UserId.HasValue && cipher.Favorite) { - cipher.Favorites = $"{{\"{cipher.UserId.ToString().ToUpperInvariant()}\":\"true\"}}"; + cipher.Favorites = $"{{\"{cipher.UserId.ToString().ToUpperInvariant()}\":true}}"; + } + + if (cipher.UserId.HasValue && cipher.ArchivedDate.HasValue) + { + cipher.Archives = $"{{\"{cipher.UserId.Value.ToString().ToUpperInvariant()}\":\"" + + $"{cipher.ArchivedDate.Value:yyyy-MM-ddTHH:mm:ss.fffffffZ}\"}}"; } } @@ -135,10 +141,16 @@ public class ImportCiphersCommand : IImportCiphersCommand } } - // Init. ids for ciphers foreach (var cipher in ciphers) { + // Init. ids for ciphers cipher.SetNewId(); + + if (cipher.ArchivedDate.HasValue) + { + cipher.Archives = $"{{\"{importingUserId.ToString().ToUpperInvariant()}\":\"" + + $"{cipher.ArchivedDate.Value:yyyy-MM-ddTHH:mm:ss.fffffffZ}\"}}"; + } } var organizationCollectionsIds = (await _collectionRepository.GetManyByOrganizationIdAsync(org.Id)).Select(c => c.Id).ToList(); diff --git a/src/Core/Tools/Models/Data/SendAuthenticationTypes.cs b/src/Core/Tools/Models/Data/SendAuthenticationTypes.cs index 9ce477ed0c..21d0822c90 100644 --- a/src/Core/Tools/Models/Data/SendAuthenticationTypes.cs +++ b/src/Core/Tools/Models/Data/SendAuthenticationTypes.cs @@ -44,7 +44,7 @@ public record ResourcePassword(string Hash) : SendAuthenticationMethod; /// /// Create a send claim by requesting a one time password (OTP) confirmation code. /// -/// +/// /// The list of email addresses permitted access to the send. /// -public record EmailOtp(string[] Emails) : SendAuthenticationMethod; +public record EmailOtp(string[] emails) : SendAuthenticationMethod; diff --git a/src/Core/Tools/SendFeatures/Commands/Interfaces/INonAnonymousSendCommand.cs b/src/Core/Tools/SendFeatures/Commands/Interfaces/INonAnonymousSendCommand.cs index 58693e619c..5ecf056268 100644 --- a/src/Core/Tools/SendFeatures/Commands/Interfaces/INonAnonymousSendCommand.cs +++ b/src/Core/Tools/SendFeatures/Commands/Interfaces/INonAnonymousSendCommand.cs @@ -47,7 +47,45 @@ public interface INonAnonymousSendCommand /// when the file is confirmed, otherwise /// /// When a file size cannot be confirmed, we assume we're working with a rogue client. The send is deleted out of - /// an abundance of caution. + /// an abundance of caution. /// Task ConfirmFileSize(Send send); + + /// + /// If a File type Send can be downloaded, retrieves the download URL. + /// + /// The this command acts upon + /// The fileId to be downloaded + /// + /// A tuple wrapping the download URL string and indicating whether access was granted + /// + /// + /// This method is intended for authenticated endpoints where authentication has already been validated. + /// Returns when the Send is disabled, MaxAccessCount has been reached, + /// expiration date has passed, or deletion date has been reached. + /// + Task<(string, SendAccessResult)> GetSendFileDownloadUrlAsync(Send send, string fileId); + + /// + /// Determines whether a can be accessed based on its current state. + /// + /// The to evaluate for access + /// if the Send can be accessed, otherwise + /// + /// This method checks if the Send is disabled, if MaxAccessCount has been reached, + /// if the expiration date has passed, or if the deletion date has been reached. + /// + static bool SendCanBeAccessed(Send send) + { + var now = DateTime.UtcNow; + if (send.MaxAccessCount.GetValueOrDefault(int.MaxValue) <= send.AccessCount || + send.ExpirationDate.GetValueOrDefault(DateTime.MaxValue) < now || + send.Disabled || + send.DeletionDate < now) + { + return false; + } + + return true; + } } diff --git a/src/Core/Tools/SendFeatures/Commands/NonAnonymousSendCommand.cs b/src/Core/Tools/SendFeatures/Commands/NonAnonymousSendCommand.cs index 9655d155e6..21ca1ca3fb 100644 --- a/src/Core/Tools/SendFeatures/Commands/NonAnonymousSendCommand.cs +++ b/src/Core/Tools/SendFeatures/Commands/NonAnonymousSendCommand.cs @@ -27,7 +27,6 @@ public class NonAnonymousSendCommand : INonAnonymousSendCommand public NonAnonymousSendCommand(ISendRepository sendRepository, ISendFileStorageService sendFileStorageService, IPushNotificationService pushNotificationService, - ISendAuthorizationService sendAuthorizationService, ISendValidationService sendValidationService, ISendCoreHelperService sendCoreHelperService, ILogger logger) @@ -181,4 +180,21 @@ public class NonAnonymousSendCommand : INonAnonymousSendCommand return valid; } + public async Task<(string, SendAccessResult)> GetSendFileDownloadUrlAsync(Send send, string fileId) + { + if (send.Type != SendType.File) + { + throw new BadRequestException("Can only get a download URL for a file type of Send"); + } + + if (!INonAnonymousSendCommand.SendCanBeAccessed(send)) + { + return (null, SendAccessResult.Denied); + } + + send.AccessCount++; + await _sendRepository.ReplaceAsync(send); + await _pushNotificationService.PushSyncSendUpdateAsync(send); + return (await _sendFileStorageService.GetSendFileDownloadUrlAsync(send, fileId), SendAccessResult.Granted); + } } diff --git a/src/Core/Tools/SendFeatures/Queries/SendAuthenticationQuery.cs b/src/Core/Tools/SendFeatures/Queries/SendAuthenticationQuery.cs index 97c2e64dc5..6c7b965ef2 100644 --- a/src/Core/Tools/SendFeatures/Queries/SendAuthenticationQuery.cs +++ b/src/Core/Tools/SendFeatures/Queries/SendAuthenticationQuery.cs @@ -37,8 +37,11 @@ public class SendAuthenticationQuery : ISendAuthenticationQuery SendAuthenticationMethod method = send switch { null => NEVER_AUTHENTICATE, - var s when s.AccessCount >= s.MaxAccessCount => NEVER_AUTHENTICATE, - var s when s.AuthType == AuthType.Email && s.Emails is not null => emailOtp(s.Emails), + var s when s.Disabled => NEVER_AUTHENTICATE, + var s when s.AccessCount >= s.MaxAccessCount.GetValueOrDefault(int.MaxValue) => NEVER_AUTHENTICATE, + var s when s.ExpirationDate.GetValueOrDefault(DateTime.MaxValue) < DateTime.UtcNow => NEVER_AUTHENTICATE, + var s when s.DeletionDate <= DateTime.UtcNow => NEVER_AUTHENTICATE, + var s when s.AuthType == AuthType.Email && s.Emails is not null => EmailOtp(s.Emails), var s when s.AuthType == AuthType.Password && s.Password is not null => new ResourcePassword(s.Password), _ => NOT_AUTHENTICATED }; @@ -46,8 +49,12 @@ public class SendAuthenticationQuery : ISendAuthenticationQuery return method; } - private EmailOtp emailOtp(string emails) + private static EmailOtp EmailOtp(string? emails) { + if (string.IsNullOrWhiteSpace(emails)) + { + return new EmailOtp([]); + } var list = emails.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); return new EmailOtp(list); } diff --git a/src/Core/Tools/SendFeatures/Services/SendValidationService.cs b/src/Core/Tools/SendFeatures/Services/SendValidationService.cs index c545c8b35f..bd987bb396 100644 --- a/src/Core/Tools/SendFeatures/Services/SendValidationService.cs +++ b/src/Core/Tools/SendFeatures/Services/SendValidationService.cs @@ -6,6 +6,7 @@ using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Services; +using Bit.Core.Billing.Pricing; using Bit.Core.Context; using Bit.Core.Exceptions; using Bit.Core.Repositories; @@ -27,6 +28,7 @@ public class SendValidationService : ISendValidationService private readonly GlobalSettings _globalSettings; private readonly ICurrentContext _currentContext; private readonly IPolicyRequirementQuery _policyRequirementQuery; + private readonly IPricingClient _pricingClient; @@ -38,7 +40,7 @@ public class SendValidationService : ISendValidationService IUserService userService, IPolicyRequirementQuery policyRequirementQuery, GlobalSettings globalSettings, - + IPricingClient pricingClient, ICurrentContext currentContext) { _userRepository = userRepository; @@ -48,6 +50,7 @@ public class SendValidationService : ISendValidationService _userService = userService; _policyRequirementQuery = policyRequirementQuery; _globalSettings = globalSettings; + _pricingClient = pricingClient; _currentContext = currentContext; } @@ -123,10 +126,19 @@ public class SendValidationService : ISendValidationService } else { - // Users that get access to file storage/premium from their organization get the default - // 1 GB max storage. - short limit = _globalSettings.SelfHosted ? Constants.SelfHostedMaxStorageGb : (short)1; - storageBytesRemaining = user.StorageBytesRemaining(limit); + // Users that get access to file storage/premium from their organization get storage + // based on the current premium plan from the pricing service + short provided; + if (_globalSettings.SelfHosted) + { + provided = Constants.SelfHostedMaxStorageGb; + } + else + { + var premiumPlan = await _pricingClient.GetAvailablePremiumPlan(); + provided = (short)premiumPlan.Storage.Provided; + } + storageBytesRemaining = user.StorageBytesRemaining(provided); } } else if (send.OrganizationId.HasValue) diff --git a/src/Core/Utilities/DomainNameAttribute.cs b/src/Core/Utilities/DomainNameAttribute.cs new file mode 100644 index 0000000000..9b571e96d7 --- /dev/null +++ b/src/Core/Utilities/DomainNameAttribute.cs @@ -0,0 +1,64 @@ +using System.ComponentModel.DataAnnotations; +using System.Text.RegularExpressions; + +namespace Bit.Core.Utilities; + +/// +/// https://bitwarden.atlassian.net/browse/VULN-376 +/// Domain names are vulnerable to XSS attacks if not properly validated. +/// Domain names can contain letters, numbers, dots, and hyphens. +/// Domain names maybe internationalized (IDN) and contain unicode characters. +/// +public class DomainNameValidatorAttribute : ValidationAttribute +{ + // RFC 1123 compliant domain name regex + // - Allows alphanumeric characters and hyphens + // - Cannot start or end with a hyphen + // - Each label (part between dots) must be 1-63 characters + // - Total length should not exceed 253 characters + // - Supports internationalized domain names (IDN) - which is why this regex includes unicode ranges + private static readonly Regex _domainNameRegex = new( + @"^(?:[a-zA-Z0-9\u00A0-\uD7FF\uF900-\uFDCF\uFDF0-\uFFEF](?:[a-zA-Z0-9\-\u00A0-\uD7FF\uF900-\uFDCF\uFDF0-\uFFEF]{0,61}[a-zA-Z0-9\u00A0-\uD7FF\uF900-\uFDCF\uFDF0-\uFFEF])?\.)*[a-zA-Z0-9\u00A0-\uD7FF\uF900-\uFDCF\uFDF0-\uFFEF](?:[a-zA-Z0-9\-\u00A0-\uD7FF\uF900-\uFDCF\uFDF0-\uFFEF]{0,61}[a-zA-Z0-9\u00A0-\uD7FF\uF900-\uFDCF\uFDF0-\uFFEF])?$", + RegexOptions.Compiled | RegexOptions.IgnoreCase + ); + + public DomainNameValidatorAttribute() + : base("The {0} field is not a valid domain name.") + { } + + public override bool IsValid(object? value) + { + if (value == null) + { + return true; // Use [Required] for null checks + } + + var domainName = value.ToString(); + + if (string.IsNullOrWhiteSpace(domainName)) + { + return false; + } + + // Reject if contains any whitespace (including leading/trailing spaces, tabs, newlines) + if (domainName.Any(char.IsWhiteSpace)) + { + return false; + } + + // Check length constraints + if (domainName.Length > 253) + { + return false; + } + + // Check for control characters or other dangerous characters + if (domainName.Any(c => char.IsControl(c) || c == '<' || c == '>' || c == '"' || c == '\'' || c == '&')) + { + return false; + } + + // Validate against domain name regex + return _domainNameRegex.IsMatch(domainName); + } +} diff --git a/src/Core/Utilities/KdfSettingsValidator.cs b/src/Core/Utilities/KdfSettingsValidator.cs index f89e8ddb66..e5690ad469 100644 --- a/src/Core/Utilities/KdfSettingsValidator.cs +++ b/src/Core/Utilities/KdfSettingsValidator.cs @@ -6,6 +6,7 @@ namespace Bit.Core.Utilities; public static class KdfSettingsValidator { + // PM-28143 - Remove below when fixing ticket public static IEnumerable Validate(KdfType kdfType, int kdfIterations, int? kdfMemory, int? kdfParallelism) { switch (kdfType) diff --git a/src/Core/Utilities/LoggerFactoryExtensions.cs b/src/Core/Utilities/LoggerFactoryExtensions.cs index b950e30d5d..f3330f0792 100644 --- a/src/Core/Utilities/LoggerFactoryExtensions.cs +++ b/src/Core/Utilities/LoggerFactoryExtensions.cs @@ -1,4 +1,5 @@ -using Microsoft.AspNetCore.Hosting; +using System.Globalization; +using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; @@ -8,7 +9,7 @@ namespace Bit.Core.Utilities; public static class LoggerFactoryExtensions { /// - /// + /// /// /// /// @@ -21,10 +22,12 @@ public static class LoggerFactoryExtensions return; } + IConfiguration loggingConfiguration; + // If they have begun using the new settings location, use that if (!string.IsNullOrEmpty(context.Configuration["Logging:PathFormat"])) { - logging.AddFile(context.Configuration.GetSection("Logging")); + loggingConfiguration = context.Configuration.GetSection("Logging"); } else { @@ -40,28 +43,35 @@ public static class LoggerFactoryExtensions var projectName = loggingOptions.ProjectName ?? context.HostingEnvironment.ApplicationName; + string pathFormat; + if (loggingOptions.LogRollBySizeLimit.HasValue) { - var pathFormat = loggingOptions.LogDirectoryByProject + pathFormat = loggingOptions.LogDirectoryByProject ? Path.Combine(loggingOptions.LogDirectory, projectName, "log.txt") : Path.Combine(loggingOptions.LogDirectory, $"{projectName.ToLowerInvariant()}.log"); - - logging.AddFile( - pathFormat: pathFormat, - fileSizeLimitBytes: loggingOptions.LogRollBySizeLimit.Value - ); } else { - var pathFormat = loggingOptions.LogDirectoryByProject + pathFormat = loggingOptions.LogDirectoryByProject ? Path.Combine(loggingOptions.LogDirectory, projectName, "{Date}.txt") : Path.Combine(loggingOptions.LogDirectory, $"{projectName.ToLowerInvariant()}_{{Date}}.log"); - - logging.AddFile( - pathFormat: pathFormat - ); } + + // We want to rely on Serilog using the configuration section to have customization of the log levels + // so we make a custom configuration source for them based on the legacy values and allow overrides from + // the new location. + loggingConfiguration = new ConfigurationBuilder() + .AddInMemoryCollection(new Dictionary + { + {"PathFormat", pathFormat}, + {"FileSizeLimitBytes", loggingOptions.LogRollBySizeLimit?.ToString(CultureInfo.InvariantCulture)} + }) + .AddConfiguration(context.Configuration.GetSection("Logging")) + .Build(); } + + logging.AddFile(loggingConfiguration); }); } diff --git a/src/Core/Vault/Services/Implementations/CipherService.cs b/src/Core/Vault/Services/Implementations/CipherService.cs index fa2cfbb209..3a970d82bd 100644 --- a/src/Core/Vault/Services/Implementations/CipherService.cs +++ b/src/Core/Vault/Services/Implementations/CipherService.cs @@ -7,6 +7,7 @@ using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Services; +using Bit.Core.Billing.Pricing; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Platform.Push; @@ -46,6 +47,7 @@ public class CipherService : ICipherService private readonly IPolicyRequirementQuery _policyRequirementQuery; private readonly IApplicationCacheService _applicationCacheService; private readonly IFeatureService _featureService; + private readonly IPricingClient _pricingClient; public CipherService( ICipherRepository cipherRepository, @@ -65,7 +67,8 @@ public class CipherService : ICipherService IGetCipherPermissionsForUserQuery getCipherPermissionsForUserQuery, IPolicyRequirementQuery policyRequirementQuery, IApplicationCacheService applicationCacheService, - IFeatureService featureService) + IFeatureService featureService, + IPricingClient pricingClient) { _cipherRepository = cipherRepository; _folderRepository = folderRepository; @@ -85,6 +88,7 @@ public class CipherService : ICipherService _policyRequirementQuery = policyRequirementQuery; _applicationCacheService = applicationCacheService; _featureService = featureService; + _pricingClient = pricingClient; } public async Task SaveAsync(Cipher cipher, Guid savingUserId, DateTime? lastKnownRevisionDate, @@ -871,7 +875,7 @@ public class CipherService : ICipherService if ((cipher.RevisionDate - lastKnownRevisionDate.Value).Duration() > TimeSpan.FromSeconds(1)) { throw new BadRequestException( - "The cipher you are updating is out of date. Please save your work, sync your vault, and try again." + "The item cannot be saved because it is out of date. To edit this item, first sync your vault, or log out and back in." ); } } @@ -943,10 +947,19 @@ public class CipherService : ICipherService } else { - // Users that get access to file storage/premium from their organization get the default - // 1 GB max storage. - storageBytesRemaining = user.StorageBytesRemaining( - _globalSettings.SelfHosted ? Constants.SelfHostedMaxStorageGb : (short)1); + // Users that get access to file storage/premium from their organization get storage + // based on the current premium plan from the pricing service + short provided; + if (_globalSettings.SelfHosted) + { + provided = Constants.SelfHostedMaxStorageGb; + } + else + { + var premiumPlan = await _pricingClient.GetAvailablePremiumPlan(); + provided = (short)premiumPlan.Storage.Provided; + } + storageBytesRemaining = user.StorageBytesRemaining(provided); } } else if (cipher.OrganizationId.HasValue) diff --git a/src/Events/Events.csproj b/src/Events/Events.csproj index dcd66892ed..dc1df1d587 100644 --- a/src/Events/Events.csproj +++ b/src/Events/Events.csproj @@ -1,4 +1,5 @@  + bitwarden-Events diff --git a/src/Events/Program.cs b/src/Events/Program.cs index 1a00549005..78a3cfcdc0 100644 --- a/src/Events/Program.cs +++ b/src/Events/Program.cs @@ -8,7 +8,7 @@ public class Program { Host .CreateDefaultBuilder(args) - .ConfigureCustomAppConfiguration(args) + .UseBitwardenSdk() .ConfigureWebHostDefaults(webBuilder => { webBuilder.UseStartup(); diff --git a/src/Events/appsettings.Production.json b/src/Events/appsettings.Production.json index 010f02f8cd..9a10621264 100644 --- a/src/Events/appsettings.Production.json +++ b/src/Events/appsettings.Production.json @@ -17,11 +17,9 @@ } }, "Logging": { - "IncludeScopes": false, "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" + "Default": "Information", + "Microsoft.AspNetCore": "Warning" }, "Console": { "IncludeScopes": true, diff --git a/src/EventsProcessor/EventsProcessor.csproj b/src/EventsProcessor/EventsProcessor.csproj index 2f1aeaef54..9c128aa606 100644 --- a/src/EventsProcessor/EventsProcessor.csproj +++ b/src/EventsProcessor/EventsProcessor.csproj @@ -1,4 +1,5 @@  + bitwarden-EventsProcessor diff --git a/src/EventsProcessor/Startup.cs b/src/EventsProcessor/Startup.cs index 239393a693..78f99058ad 100644 --- a/src/EventsProcessor/Startup.cs +++ b/src/EventsProcessor/Startup.cs @@ -1,7 +1,6 @@ using System.Globalization; using Bit.Core.Utilities; using Bit.SharedWeb.Utilities; -using Microsoft.IdentityModel.Logging; namespace Bit.EventsProcessor; @@ -40,7 +39,6 @@ public class Startup public void Configure(IApplicationBuilder app) { - IdentityModelEventSource.ShowPII = true; // Add general security headers app.UseMiddleware(); app.UseRouting(); diff --git a/src/EventsProcessor/appsettings.Production.json b/src/EventsProcessor/appsettings.Production.json index 1cce4a9ed3..d57bf98b55 100644 --- a/src/EventsProcessor/appsettings.Production.json +++ b/src/EventsProcessor/appsettings.Production.json @@ -1,10 +1,8 @@ { "Logging": { - "IncludeScopes": false, "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" + "Default": "Information", + "Microsoft.AspNetCore": "Warning" }, "Console": { "IncludeScopes": true, diff --git a/src/Icons/Icons.csproj b/src/Icons/Icons.csproj index 97e9562183..adc1f6d557 100644 --- a/src/Icons/Icons.csproj +++ b/src/Icons/Icons.csproj @@ -1,4 +1,5 @@  + bitwarden-Icons @@ -9,7 +10,7 @@ - + diff --git a/src/Icons/Program.cs b/src/Icons/Program.cs index 80c1b5728e..5ca5723a52 100644 --- a/src/Icons/Program.cs +++ b/src/Icons/Program.cs @@ -8,6 +8,7 @@ public class Program { Host .CreateDefaultBuilder(args) + .UseBitwardenSdk() .ConfigureWebHostDefaults(webBuilder => { webBuilder.UseStartup(); diff --git a/src/Icons/appsettings.Production.json b/src/Icons/appsettings.Production.json index 828e8c61cc..19d21f7260 100644 --- a/src/Icons/appsettings.Production.json +++ b/src/Icons/appsettings.Production.json @@ -17,11 +17,9 @@ } }, "Logging": { - "IncludeScopes": false, "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" + "Default": "Information", + "Microsoft.AspNetCore": "Warning" }, "Console": { "IncludeScopes": true, diff --git a/src/Identity/Controllers/AccountsController.cs b/src/Identity/Controllers/AccountsController.cs index b7d4342c1b..e9807fb1fc 100644 --- a/src/Identity/Controllers/AccountsController.cs +++ b/src/Identity/Controllers/AccountsController.cs @@ -1,8 +1,4 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Diagnostics; -using System.Text; +using System.Text; using Bit.Core; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models.Api.Request.Accounts; @@ -42,7 +38,7 @@ public class AccountsController : Controller private readonly IFeatureService _featureService; private readonly IDataProtectorTokenFactory _registrationEmailVerificationTokenDataFactory; - private readonly byte[] _defaultKdfHmacKey = null; + private readonly byte[]? _defaultKdfHmacKey = null; private static readonly List _defaultKdfResults = [ // The first result (index 0) should always return the "normal" default. @@ -145,40 +141,55 @@ public class AccountsController : Controller [HttpPost("register/finish")] public async Task PostRegisterFinish([FromBody] RegisterFinishRequestModel model) { - var user = model.ToUser(); + User user = model.ToUser(); // Users will either have an emailed token or an email verification token - not both. - IdentityResult identityResult = null; + IdentityResult? identityResult = null; + + // PM-28143 - Just use the MasterPasswordAuthenticationData.MasterPasswordAuthenticationHash + string masterPasswordAuthenticationHash = model.MasterPasswordAuthentication?.MasterPasswordAuthenticationHash + ?? model.MasterPasswordHash!; switch (model.GetTokenType()) { case RegisterFinishTokenType.EmailVerification: - identityResult = - await _registerUserCommand.RegisterUserViaEmailVerificationToken(user, model.MasterPasswordHash, - model.EmailVerificationToken); - + identityResult = await _registerUserCommand.RegisterUserViaEmailVerificationToken( + user, + masterPasswordAuthenticationHash, + model.EmailVerificationToken!); return ProcessRegistrationResult(identityResult, user); + case RegisterFinishTokenType.OrganizationInvite: - identityResult = await _registerUserCommand.RegisterUserViaOrganizationInviteToken(user, model.MasterPasswordHash, - model.OrgInviteToken, model.OrganizationUserId); - + identityResult = await _registerUserCommand.RegisterUserViaOrganizationInviteToken( + user, + masterPasswordAuthenticationHash, + model.OrgInviteToken!, + model.OrganizationUserId); return ProcessRegistrationResult(identityResult, user); + case RegisterFinishTokenType.OrgSponsoredFreeFamilyPlan: - identityResult = await _registerUserCommand.RegisterUserViaOrganizationSponsoredFreeFamilyPlanInviteToken(user, model.MasterPasswordHash, model.OrgSponsoredFreeFamilyPlanToken); - + identityResult = await _registerUserCommand.RegisterUserViaOrganizationSponsoredFreeFamilyPlanInviteToken( + user, + masterPasswordAuthenticationHash, + model.OrgSponsoredFreeFamilyPlanToken!); return ProcessRegistrationResult(identityResult, user); + case RegisterFinishTokenType.EmergencyAccessInvite: - Debug.Assert(model.AcceptEmergencyAccessId.HasValue); - identityResult = await _registerUserCommand.RegisterUserViaAcceptEmergencyAccessInviteToken(user, model.MasterPasswordHash, - model.AcceptEmergencyAccessInviteToken, model.AcceptEmergencyAccessId.Value); - + identityResult = await _registerUserCommand.RegisterUserViaAcceptEmergencyAccessInviteToken( + user, + masterPasswordAuthenticationHash, + model.AcceptEmergencyAccessInviteToken!, + (Guid)model.AcceptEmergencyAccessId!); return ProcessRegistrationResult(identityResult, user); + case RegisterFinishTokenType.ProviderInvite: - Debug.Assert(model.ProviderUserId.HasValue); - identityResult = await _registerUserCommand.RegisterUserViaProviderInviteToken(user, model.MasterPasswordHash, - model.ProviderInviteToken, model.ProviderUserId.Value); - + identityResult = await _registerUserCommand.RegisterUserViaProviderInviteToken( + user, + masterPasswordAuthenticationHash, + model.ProviderInviteToken!, + (Guid)model.ProviderUserId!); return ProcessRegistrationResult(identityResult, user); + default: throw new BadRequestException("Invalid registration finish request"); } diff --git a/src/Identity/Identity.csproj b/src/Identity/Identity.csproj index db49f8c856..f31d8c005e 100644 --- a/src/Identity/Identity.csproj +++ b/src/Identity/Identity.csproj @@ -1,4 +1,5 @@  + bitwarden-Identity diff --git a/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs index e07446d49f..289feebdb2 100644 --- a/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs @@ -4,7 +4,6 @@ using System.Security.Claims; using Bit.Core; -using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.Entities; @@ -233,56 +232,14 @@ public abstract class BaseRequestValidator where T : class private async Task ValidateSsoAsync(T context, ValidatedTokenRequest request, CustomValidatorRequestContext validatorContext) { - // TODO: Clean up Feature Flag: Remove this if block: PM-28281 - if (!_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired)) + var ssoValid = await _ssoRequestValidator.ValidateAsync(validatorContext.User, request, validatorContext); + if (ssoValid) { - validatorContext.SsoRequired = await RequireSsoLoginAsync(validatorContext.User, request.GrantType); - if (!validatorContext.SsoRequired) - { - return true; - } - - // Users without SSO requirement requesting 2FA recovery will be fast-forwarded through login and are - // presented with their 2FA management area as a reminder to re-evaluate their 2FA posture after recovery and - // review their new recovery token if desired. - // SSO users cannot be assumed to be authenticated, and must prove authentication with their IdP after recovery. - // As described in validation order determination, if TwoFactorRequired, the 2FA validation scheme will have been - // evaluated, and recovery will have been performed if requested. - // We will send a descriptive message in these cases so clients can give the appropriate feedback and redirect - // to /login. - if (validatorContext.TwoFactorRequired && - validatorContext.TwoFactorRecoveryRequested) - { - SetSsoResult(context, - new Dictionary - { - { - "ErrorModel", - new ErrorResponseModel( - "Two-factor recovery has been performed. SSO authentication is required.") - } - }); - return false; - } - - SetSsoResult(context, - new Dictionary - { - { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") } - }); - return false; + return true; } - else - { - var ssoValid = await _ssoRequestValidator.ValidateAsync(validatorContext.User, request, validatorContext); - if (ssoValid) - { - return true; - } - SetValidationErrorResult(context, validatorContext); - return ssoValid; - } + SetValidationErrorResult(context, validatorContext); + return ssoValid; } /// @@ -521,9 +478,6 @@ public abstract class BaseRequestValidator where T : class [Obsolete("Consider using SetValidationErrorResult instead.")] protected abstract void SetTwoFactorResult(T context, Dictionary customResponse); - [Obsolete("Consider using SetValidationErrorResult instead.")] - protected abstract void SetSsoResult(T context, Dictionary customResponse); - [Obsolete("Consider using SetValidationErrorResult instead.")] protected abstract void SetErrorResult(T context, Dictionary customResponse); @@ -540,41 +494,6 @@ public abstract class BaseRequestValidator where T : class protected abstract ClaimsPrincipal GetSubject(T context); - /// - /// Check if the user is required to authenticate via SSO. If the user requires SSO, but they are - /// logging in using an API Key (client_credentials) then they are allowed to bypass the SSO requirement. - /// If the GrantType is authorization_code or client_credentials we know the user is trying to login - /// using the SSO flow so they are allowed to continue. - /// - /// user trying to login - /// magic string identifying the grant type requested - /// true if sso required; false if not required or already in process - [Obsolete( - "This method is deprecated and will be removed in future versions, PM-28281. Please use the SsoRequestValidator scheme instead.")] - private async Task RequireSsoLoginAsync(User user, string grantType) - { - if (grantType == "authorization_code" || grantType == "client_credentials") - { - // Already using SSO to authenticate, or logging-in via api key to skip SSO requirement - // allow to authenticate successfully - return false; - } - - // Check if user belongs to any organization with an active SSO policy - var ssoRequired = _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements) - ? (await PolicyRequirementQuery.GetAsync(user.Id)) - .SsoRequired - : await PolicyService.AnyPoliciesApplicableToUserAsync( - user.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed); - if (ssoRequired) - { - return true; - } - - // Default - SSO is not required - return false; - } - private async Task ResetFailedAuthDetailsAsync(User user) { // Early escape if db hit not necessary diff --git a/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs index 38a4813ecd..2412c52308 100644 --- a/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs @@ -194,17 +194,6 @@ public class CustomTokenRequestValidator : BaseRequestValidator customResponse) - { - Debug.Assert(context.Result is not null); - context.Result.Error = "invalid_grant"; - context.Result.ErrorDescription = "Sso authentication required."; - context.Result.IsError = true; - context.Result.CustomResponse = customResponse; - } - [Obsolete("Consider using SetGrantValidationErrorResult instead.")] protected override void SetErrorResult(CustomTokenRequestValidationContext context, Dictionary customResponse) diff --git a/src/Identity/IdentityServer/RequestValidators/ResourceOwnerPasswordValidator.cs b/src/Identity/IdentityServer/RequestValidators/ResourceOwnerPasswordValidator.cs index ea2c021f63..8bfddf24f3 100644 --- a/src/Identity/IdentityServer/RequestValidators/ResourceOwnerPasswordValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/ResourceOwnerPasswordValidator.cs @@ -152,14 +152,6 @@ public class ResourceOwnerPasswordValidator : BaseRequestValidator customResponse) - { - context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, "Sso authentication required.", - customResponse); - } - [Obsolete("Consider using SetGrantValidationErrorResult instead.")] protected override void SetErrorResult(ResourceOwnerPasswordValidationContext context, Dictionary customResponse) diff --git a/src/Identity/IdentityServer/RequestValidators/SendAccess/SendAccessConstants.cs b/src/Identity/IdentityServer/RequestValidators/SendAccess/SendAccessConstants.cs index 1f5bfba244..f38a4a880f 100644 --- a/src/Identity/IdentityServer/RequestValidators/SendAccess/SendAccessConstants.cs +++ b/src/Identity/IdentityServer/RequestValidators/SendAccess/SendAccessConstants.cs @@ -69,13 +69,9 @@ public static class SendAccessConstants /// public const string EmailRequired = "email_required"; /// - /// Represents the error code indicating that an email address is invalid. - /// - public const string EmailInvalid = "email_invalid"; - /// /// Represents the status indicating that both email and OTP are required, and the OTP has been sent. /// - public const string EmailOtpSent = "email_and_otp_required_otp_sent"; + public const string EmailAndOtpRequired = "email_and_otp_required"; /// /// Represents the status indicating that both email and OTP are required, and the OTP is invalid. /// diff --git a/src/Identity/IdentityServer/RequestValidators/SendAccess/SendEmailOtpRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/SendAccess/SendEmailOtpRequestValidator.cs index 34a7a6f6e7..02442d8c7e 100644 --- a/src/Identity/IdentityServer/RequestValidators/SendAccess/SendEmailOtpRequestValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/SendAccess/SendEmailOtpRequestValidator.cs @@ -1,5 +1,4 @@ using System.Security.Claims; -using Bit.Core; using Bit.Core.Auth.Identity; using Bit.Core.Auth.Identity.TokenProviders; using Bit.Core.Services; @@ -11,7 +10,6 @@ using Duende.IdentityServer.Validation; namespace Bit.Identity.IdentityServer.RequestValidators.SendAccess; public class SendEmailOtpRequestValidator( - IFeatureService featureService, IOtpTokenProvider otpTokenProvider, IMailService mailService) : ISendAuthenticationMethodValidator { @@ -22,8 +20,7 @@ public class SendEmailOtpRequestValidator( private static readonly Dictionary _sendEmailOtpValidatorErrorDescriptions = new() { { SendAccessConstants.EmailOtpValidatorResults.EmailRequired, $"{SendAccessConstants.TokenRequest.Email} is required." }, - { SendAccessConstants.EmailOtpValidatorResults.EmailOtpSent, "email otp sent." }, - { SendAccessConstants.EmailOtpValidatorResults.EmailInvalid, $"{SendAccessConstants.TokenRequest.Email} is invalid." }, + { SendAccessConstants.EmailOtpValidatorResults.EmailAndOtpRequired, $"{SendAccessConstants.TokenRequest.Email} and {SendAccessConstants.TokenRequest.Otp} are required." }, { SendAccessConstants.EmailOtpValidatorResults.EmailOtpInvalid, $"{SendAccessConstants.TokenRequest.Email} otp is invalid." }, }; @@ -33,17 +30,23 @@ public class SendEmailOtpRequestValidator( // get email var email = request.Get(SendAccessConstants.TokenRequest.Email); - // It is an invalid request if the email is missing which indicated bad shape. + // It is an invalid request if the email is missing. if (string.IsNullOrEmpty(email)) { // Request is the wrong shape and doesn't contain an email field. return BuildErrorResult(SendAccessConstants.EmailOtpValidatorResults.EmailRequired); } - // email must be in the list of emails in the EmailOtp array - if (!authMethod.Emails.Contains(email)) + /* + * This is somewhat contradictory to our process where a poor shape means invalid_request and invalid + * data is invalid_grant. + * In this case the shape is correct and the data is invalid but to protect against enumeration we treat incorrect emails + * as invalid requests. The response for a request with a correct email which needs an OTP and a request + * that has an invalid email need to be the same otherwise an attacker could enumerate until a valid email is found. + */ + if (!authMethod.emails.Contains(email, StringComparer.OrdinalIgnoreCase)) { - return BuildErrorResult(SendAccessConstants.EmailOtpValidatorResults.EmailInvalid); + return BuildErrorResult(SendAccessConstants.EmailOtpValidatorResults.EmailAndOtpRequired); } // get otp from request @@ -62,21 +65,13 @@ public class SendEmailOtpRequestValidator( { return BuildErrorResult(SendAccessConstants.EmailOtpValidatorResults.OtpGenerationFailed); } - if (featureService.IsEnabled(FeatureFlagKeys.MJMLBasedEmailTemplates)) - { - await mailService.SendSendEmailOtpEmailv2Async( - email, - token, - string.Format(SendAccessConstants.OtpEmail.Subject, token)); - } - else - { - await mailService.SendSendEmailOtpEmailAsync( - email, - token, - string.Format(SendAccessConstants.OtpEmail.Subject, token)); - } - return BuildErrorResult(SendAccessConstants.EmailOtpValidatorResults.EmailOtpSent); + + await mailService.SendSendEmailOtpEmailAsync( + email, + token, + string.Format(SendAccessConstants.OtpEmail.Subject, token)); + + return BuildErrorResult(SendAccessConstants.EmailOtpValidatorResults.EmailAndOtpRequired); } // validate request otp @@ -100,7 +95,7 @@ public class SendEmailOtpRequestValidator( switch (error) { case SendAccessConstants.EmailOtpValidatorResults.EmailRequired: - case SendAccessConstants.EmailOtpValidatorResults.EmailOtpSent: + case SendAccessConstants.EmailOtpValidatorResults.EmailAndOtpRequired: return new GrantValidationResult(TokenRequestErrors.InvalidRequest, errorDescription: _sendEmailOtpValidatorErrorDescriptions[error], new Dictionary @@ -108,7 +103,6 @@ public class SendEmailOtpRequestValidator( { SendAccessConstants.SendAccessError, error } }); case SendAccessConstants.EmailOtpValidatorResults.EmailOtpInvalid: - case SendAccessConstants.EmailOtpValidatorResults.EmailInvalid: return new GrantValidationResult( TokenRequestErrors.InvalidGrant, errorDescription: _sendEmailOtpValidatorErrorDescriptions[error], diff --git a/src/Identity/IdentityServer/RequestValidators/SendAccess/SendNeverAuthenticateRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/SendAccess/SendNeverAuthenticateRequestValidator.cs index 36e033360f..aabafaafd8 100644 --- a/src/Identity/IdentityServer/RequestValidators/SendAccess/SendNeverAuthenticateRequestValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/SendAccess/SendNeverAuthenticateRequestValidator.cs @@ -38,7 +38,7 @@ public class SendNeverAuthenticateRequestValidator(GlobalSettings globalSettings break; case SendAccessConstants.EnumerationProtection.Email: var hasEmail = request.Get(SendAccessConstants.TokenRequest.Email) is not null; - errorType = hasEmail ? SendAccessConstants.EmailOtpValidatorResults.EmailInvalid + errorType = hasEmail ? SendAccessConstants.EmailOtpValidatorResults.EmailAndOtpRequired : SendAccessConstants.EmailOtpValidatorResults.EmailRequired; break; case SendAccessConstants.EnumerationProtection.Password: @@ -64,7 +64,7 @@ public class SendNeverAuthenticateRequestValidator(GlobalSettings globalSettings SendAccessConstants.EnumerationProtection.Guid => TokenRequestErrors.InvalidGrant, SendAccessConstants.PasswordValidatorResults.RequestPasswordIsRequired => TokenRequestErrors.InvalidGrant, SendAccessConstants.PasswordValidatorResults.RequestPasswordDoesNotMatch => TokenRequestErrors.InvalidRequest, - SendAccessConstants.EmailOtpValidatorResults.EmailInvalid => TokenRequestErrors.InvalidGrant, + SendAccessConstants.EmailOtpValidatorResults.EmailAndOtpRequired => TokenRequestErrors.InvalidRequest, SendAccessConstants.EmailOtpValidatorResults.EmailRequired => TokenRequestErrors.InvalidRequest, _ => TokenRequestErrors.InvalidGrant }; diff --git a/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs b/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs index e4cd60827e..1563831b81 100644 --- a/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs @@ -142,14 +142,6 @@ public class WebAuthnGrantValidator : BaseRequestValidator customResponse) - { - context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, "Sso authentication required.", - customResponse); - } - [Obsolete("Consider using SetValidationErrorResult instead.")] protected override void SetErrorResult(ExtensionGrantValidationContext context, Dictionary customResponse) { diff --git a/src/Identity/IdentityServer/UserDecryptionOptionsBuilder.cs b/src/Identity/IdentityServer/UserDecryptionOptionsBuilder.cs index 56b4bb0dcf..003e9a032e 100644 --- a/src/Identity/IdentityServer/UserDecryptionOptionsBuilder.cs +++ b/src/Identity/IdentityServer/UserDecryptionOptionsBuilder.cs @@ -64,8 +64,12 @@ public class UserDecryptionOptionsBuilder : IUserDecryptionOptionsBuilder { if (credential.GetPrfStatus() == WebAuthnPrfStatus.Enabled) { - _options.WebAuthnPrfOption = - new WebAuthnPrfDecryptionOption(credential.EncryptedPrivateKey, credential.EncryptedUserKey); + _options.WebAuthnPrfOption = new WebAuthnPrfDecryptionOption( + credential.EncryptedPrivateKey, + credential.EncryptedUserKey, + credential.CredentialId, + [] // Stored credentials currently lack Transports, just send an empty array for now + ); } return this; diff --git a/src/Identity/Program.cs b/src/Identity/Program.cs index 238ad8ce3a..ae284c86f2 100644 --- a/src/Identity/Program.cs +++ b/src/Identity/Program.cs @@ -15,7 +15,7 @@ public class Program { return Host .CreateDefaultBuilder(args) - .ConfigureCustomAppConfiguration(args) + .UseBitwardenSdk() .ConfigureWebHostDefaults(webBuilder => { webBuilder.UseStartup(); diff --git a/src/Identity/Startup.cs b/src/Identity/Startup.cs index 9d5536fd10..bb1a974d82 100644 --- a/src/Identity/Startup.cs +++ b/src/Identity/Startup.cs @@ -14,8 +14,7 @@ using Bit.SharedWeb.Swagger; using Bit.SharedWeb.Utilities; using Duende.IdentityServer.Services; using Microsoft.Extensions.DependencyInjection.Extensions; -using Microsoft.IdentityModel.Logging; -using Microsoft.OpenApi.Models; +using Microsoft.OpenApi; namespace Bit.Identity; @@ -170,16 +169,14 @@ public class Startup public void Configure( IApplicationBuilder app, - IWebHostEnvironment env, + IWebHostEnvironment environment, GlobalSettings globalSettings, ILogger logger) { - IdentityModelEventSource.ShowPII = true; - // Add general security headers app.UseMiddleware(); - if (!env.IsDevelopment()) + if (!environment.IsDevelopment()) { var uri = new Uri(globalSettings.BaseServiceUri.Identity); app.Use(async (ctx, next) => @@ -196,7 +193,7 @@ public class Startup } // Default Middleware - app.UseDefaultMiddleware(env, globalSettings); + app.UseDefaultMiddleware(environment, globalSettings); if (!globalSettings.SelfHosted) { @@ -204,7 +201,7 @@ public class Startup app.UseMiddleware(); } - if (env.IsDevelopment()) + if (environment.IsDevelopment()) { app.UseSwagger(); app.UseDeveloperExceptionPage(); diff --git a/src/Identity/appsettings.Production.json b/src/Identity/appsettings.Production.json index 4897a7d8b1..14471b5fb6 100644 --- a/src/Identity/appsettings.Production.json +++ b/src/Identity/appsettings.Production.json @@ -20,11 +20,9 @@ } }, "Logging": { - "IncludeScopes": false, "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" + "Default": "Information", + "Microsoft.AspNetCore": "Warning" }, "Console": { "IncludeScopes": true, diff --git a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs index bd670347a9..ff2488d084 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs +++ b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs @@ -187,12 +187,12 @@ public class OrganizationUserRepository : Repository, IO return results.SingleOrDefault(); } } - public async Task<(OrganizationUserUserDetails? OrganizationUser, ICollection Collections)> GetDetailsByIdWithCollectionsAsync(Guid id) + public async Task<(OrganizationUserUserDetails? OrganizationUser, ICollection Collections)> GetDetailsByIdWithSharedCollectionsAsync(Guid id) { using (var connection = new SqlConnection(ConnectionString)) { var results = await connection.QueryMultipleAsync( - "[dbo].[OrganizationUserUserDetails_ReadWithCollectionsById]", + "[dbo].[OrganizationUserUserDetails_ReadWithSharedCollectionsById]", new { Id = id }, commandType: CommandType.StoredProcedure); @@ -202,7 +202,7 @@ public class OrganizationUserRepository : Repository, IO } } - public async Task> GetManyDetailsByOrganizationAsync(Guid organizationId, bool includeGroups, bool includeCollections) + public async Task> GetManyDetailsByOrganizationAsync(Guid organizationId, bool includeGroups, bool includeSharedCollections) { using (var connection = new SqlConnection(ConnectionString)) { @@ -216,7 +216,7 @@ public class OrganizationUserRepository : Repository, IO var users = results.ToList(); - if (!includeCollections && !includeGroups) + if (!includeSharedCollections && !includeGroups) { return users; } @@ -231,10 +231,10 @@ public class OrganizationUserRepository : Repository, IO commandType: CommandType.StoredProcedure)).GroupBy(u => u.OrganizationUserId).ToList(); } - if (includeCollections) + if (includeSharedCollections) { userCollections = (await connection.QueryAsync( - "[dbo].[CollectionUser_ReadByOrganizationUserIds]", + "[dbo].[CollectionUser_ReadSharedCollectionsByOrganizationUserIds]", new { OrganizationUserIds = orgUserIds }, commandType: CommandType.StoredProcedure)).GroupBy(u => u.OrganizationUserId).ToList(); } @@ -267,7 +267,7 @@ public class OrganizationUserRepository : Repository, IO } } - public async Task> GetManyDetailsByOrganizationAsync_vNext(Guid organizationId, bool includeGroups, bool includeCollections) + public async Task> GetManyDetailsByOrganizationAsync_vNext(Guid organizationId, bool includeGroups, bool includeSharedCollections) { using (var connection = new SqlConnection(ConnectionString)) { @@ -278,7 +278,7 @@ public class OrganizationUserRepository : Repository, IO { OrganizationId = organizationId, IncludeGroups = includeGroups, - IncludeCollections = includeCollections + IncludeCollections = includeSharedCollections }, commandType: CommandType.StoredProcedure); @@ -297,7 +297,7 @@ public class OrganizationUserRepository : Repository, IO // Read collection associations (third result set, if requested) Dictionary>? userCollectionMap = null; - if (includeCollections) + if (includeSharedCollections) { var collectionUsers = await results.ReadAsync(); userCollectionMap = collectionUsers diff --git a/src/Infrastructure.Dapper/Auth/Repositories/EmergencyAccessRepository.cs b/src/Infrastructure.Dapper/Auth/Repositories/EmergencyAccessRepository.cs index 4d597ab045..f7dd17784e 100644 --- a/src/Infrastructure.Dapper/Auth/Repositories/EmergencyAccessRepository.cs +++ b/src/Infrastructure.Dapper/Auth/Repositories/EmergencyAccessRepository.cs @@ -9,8 +9,6 @@ using Bit.Infrastructure.Dapper.Repositories; using Dapper; using Microsoft.Data.SqlClient; -#nullable enable - namespace Bit.Infrastructure.Dapper.Auth.Repositories; public class EmergencyAccessRepository : Repository, IEmergencyAccessRepository @@ -152,4 +150,14 @@ public class EmergencyAccessRepository : Repository, IEme } }; } + + /// + public async Task DeleteManyAsync(ICollection emergencyAccessIds) + { + using var connection = new SqlConnection(ConnectionString); + await connection.ExecuteAsync( + "[dbo].[EmergencyAccess_DeleteManyById]", + new { EmergencyAccessIds = emergencyAccessIds.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); + } } diff --git a/src/Infrastructure.Dapper/Billing/Repositories/SubscriptionDiscountRepository.cs b/src/Infrastructure.Dapper/Billing/Repositories/SubscriptionDiscountRepository.cs new file mode 100644 index 0000000000..72fa8d7d4e --- /dev/null +++ b/src/Infrastructure.Dapper/Billing/Repositories/SubscriptionDiscountRepository.cs @@ -0,0 +1,39 @@ +using System.Data; +using Bit.Core.Billing.Subscriptions.Entities; +using Bit.Core.Billing.Subscriptions.Repositories; +using Bit.Core.Settings; +using Bit.Infrastructure.Dapper.Repositories; +using Dapper; +using Microsoft.Data.SqlClient; + +namespace Bit.Infrastructure.Dapper.Billing.Repositories; + +public class SubscriptionDiscountRepository( + GlobalSettings globalSettings) + : Repository( + globalSettings.SqlServer.ConnectionString, + globalSettings.SqlServer.ReadOnlyConnectionString), ISubscriptionDiscountRepository +{ + public async Task> GetActiveDiscountsAsync() + { + using var sqlConnection = new SqlConnection(ReadOnlyConnectionString); + + var results = await sqlConnection.QueryAsync( + "[dbo].[SubscriptionDiscount_ReadActive]", + commandType: CommandType.StoredProcedure); + + return results.ToArray(); + } + + public async Task GetByStripeCouponIdAsync(string stripeCouponId) + { + using var sqlConnection = new SqlConnection(ReadOnlyConnectionString); + + var result = await sqlConnection.QueryFirstOrDefaultAsync( + "[dbo].[SubscriptionDiscount_ReadByStripeCouponId]", + new { StripeCouponId = stripeCouponId }, + commandType: CommandType.StoredProcedure); + + return result; + } +} diff --git a/src/Infrastructure.Dapper/DapperHelpers.cs b/src/Infrastructure.Dapper/DapperHelpers.cs index 9a119e1e32..4384a6f752 100644 --- a/src/Infrastructure.Dapper/DapperHelpers.cs +++ b/src/Infrastructure.Dapper/DapperHelpers.cs @@ -160,6 +160,21 @@ public static class DapperHelpers return ids.ToArrayTVP("GuidId"); } + public static DataTable ToTwoGuidIdArrayTVP(this IEnumerable<(Guid id1, Guid id2)> values) + { + var table = new DataTable(); + table.SetTypeName("[dbo].[TwoGuidIdArray]"); + table.Columns.Add("Id1", typeof(Guid)); + table.Columns.Add("Id2", typeof(Guid)); + + foreach (var value in values) + { + table.Rows.Add(value.id1, value.id2); + } + + return table; + } + public static DataTable ToArrayTVP(this IEnumerable values, string columnName) { var table = new DataTable(); diff --git a/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs b/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs index dcb0dc1306..4055281352 100644 --- a/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs +++ b/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs @@ -2,6 +2,7 @@ using Bit.Core.Auth.Repositories; using Bit.Core.Billing.Organizations.Repositories; using Bit.Core.Billing.Providers.Repositories; +using Bit.Core.Billing.Subscriptions.Repositories; using Bit.Core.Dirt.Reports.Repositories; using Bit.Core.Dirt.Repositories; using Bit.Core.KeyManagement.Repositories; @@ -65,6 +66,7 @@ public static class DapperServiceCollectionExtensions services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); services diff --git a/src/Infrastructure.Dapper/Repositories/BaseRepository.cs b/src/Infrastructure.Dapper/Repositories/BaseRepository.cs index a5a8cd0ee1..317e7ebbb3 100644 --- a/src/Infrastructure.Dapper/Repositories/BaseRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/BaseRepository.cs @@ -9,6 +9,7 @@ public abstract class BaseRepository static BaseRepository() { SqlMapper.AddTypeHandler(new DateTimeHandler()); + SqlMapper.AddTypeHandler(new JsonCollectionTypeHandler()); } public BaseRepository(string connectionString, string readOnlyConnectionString) diff --git a/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs b/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs index 9985b41d56..2c733956c0 100644 --- a/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs @@ -1,6 +1,7 @@ using System.Data; using System.Diagnostics.CodeAnalysis; using System.Text.Json; +using Bit.Core.AdminConsole.OrganizationFeatures.Collections; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Data; @@ -152,12 +153,12 @@ public class CollectionRepository : Repository, ICollectionRep } } - public async Task> GetManyByOrganizationIdWithPermissionsAsync(Guid organizationId, Guid userId, bool includeAccessRelationships) + public async Task> GetManySharedByOrganizationIdWithPermissionsAsync(Guid organizationId, Guid userId, bool includeAccessRelationships) { using (var connection = new SqlConnection(ConnectionString)) { var results = await connection.QueryMultipleAsync( - $"[{Schema}].[Collection_ReadByOrganizationIdWithPermissions]", + $"[{Schema}].[Collection_ReadSharedCollectionsByOrganizationIdWithPermissions]", new { OrganizationId = organizationId, UserId = userId, IncludeAccessRelationships = includeAccessRelationships }, commandType: CommandType.StoredProcedure); @@ -360,7 +361,45 @@ public class CollectionRepository : Repository, ICollectionRep } } - public async Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName) + public async Task CreateDefaultCollectionsAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName) + { + organizationUserIds = organizationUserIds.ToList(); + if (!organizationUserIds.Any()) + { + return; + } + + var organizationUserCollectionIds = organizationUserIds + .Select(ou => (ou, CoreHelpers.GenerateComb())) + .ToTwoGuidIdArrayTVP(); + + await using var connection = new SqlConnection(ConnectionString); + await connection.OpenAsync(); + await using var transaction = connection.BeginTransaction(); + + try + { + await connection.ExecuteAsync( + "[dbo].[Collection_CreateDefaultCollections]", + new + { + OrganizationId = organizationId, + DefaultCollectionName = defaultCollectionName, + OrganizationUserCollectionIds = organizationUserCollectionIds + }, + commandType: CommandType.StoredProcedure, + transaction: transaction); + + await transaction.CommitAsync(); + } + catch + { + await transaction.RollbackAsync(); + throw; + } + } + + public async Task CreateDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName) { organizationUserIds = organizationUserIds.ToList(); if (!organizationUserIds.Any()) @@ -377,7 +416,8 @@ public class CollectionRepository : Repository, ICollectionRep var missingDefaultCollectionUserIds = organizationUserIds.Except(orgUserIdWithDefaultCollection); - var (collectionUsers, collections) = BuildDefaultCollectionForUsers(organizationId, missingDefaultCollectionUserIds, defaultCollectionName); + var (collections, collectionUsers) = + CollectionUtils.BuildDefaultUserCollections(organizationId, missingDefaultCollectionUserIds, defaultCollectionName); if (!collectionUsers.Any() || !collections.Any()) { @@ -387,11 +427,11 @@ public class CollectionRepository : Repository, ICollectionRep await BulkResourceCreationService.CreateCollectionsAsync(connection, transaction, collections); await BulkResourceCreationService.CreateCollectionsUsersAsync(connection, transaction, collectionUsers); - transaction.Commit(); + await transaction.CommitAsync(); } catch { - transaction.Rollback(); + await transaction.RollbackAsync(); throw; } } @@ -421,40 +461,6 @@ public class CollectionRepository : Repository, ICollectionRep return organizationUserIds.ToHashSet(); } - private (List collectionUser, List collection) BuildDefaultCollectionForUsers(Guid organizationId, IEnumerable missingDefaultCollectionUserIds, string defaultCollectionName) - { - var collectionUsers = new List(); - var collections = new List(); - - foreach (var orgUserId in missingDefaultCollectionUserIds) - { - var collectionId = CoreHelpers.GenerateComb(); - - collections.Add(new Collection - { - Id = collectionId, - OrganizationId = organizationId, - Name = defaultCollectionName, - CreationDate = DateTime.UtcNow, - RevisionDate = DateTime.UtcNow, - Type = CollectionType.DefaultUserCollection, - DefaultUserCollectionEmail = null - - }); - - collectionUsers.Add(new CollectionUser - { - CollectionId = collectionId, - OrganizationUserId = orgUserId, - ReadOnly = false, - HidePasswords = false, - Manage = true, - }); - } - - return (collectionUsers, collections); - } - public class CollectionWithGroupsAndUsers : Collection { public CollectionWithGroupsAndUsers() { } diff --git a/src/Infrastructure.Dapper/Repositories/JsonCollectionTypeHandler.cs b/src/Infrastructure.Dapper/Repositories/JsonCollectionTypeHandler.cs new file mode 100644 index 0000000000..1f5455f7b7 --- /dev/null +++ b/src/Infrastructure.Dapper/Repositories/JsonCollectionTypeHandler.cs @@ -0,0 +1,31 @@ +using System.Data; +using System.Text.Json; +using Dapper; + +#nullable enable + +namespace Bit.Infrastructure.Dapper.Repositories; + +public class JsonCollectionTypeHandler : SqlMapper.TypeHandler?> +{ + public override void SetValue(IDbDataParameter parameter, ICollection? value) + { + parameter.Value = value == null ? (object)DBNull.Value : JsonSerializer.Serialize(value); + } + + public override ICollection? Parse(object value) + { + if (value == null || value is DBNull) + { + return null; + } + + var json = value.ToString(); + if (string.IsNullOrWhiteSpace(json)) + { + return null; + } + + return JsonSerializer.Deserialize>(json); + } +} diff --git a/src/Infrastructure.Dapper/Repositories/UserRepository.cs b/src/Infrastructure.Dapper/Repositories/UserRepository.cs index 920145f2f2..8d94ddae53 100644 --- a/src/Infrastructure.Dapper/Repositories/UserRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/UserRepository.cs @@ -404,6 +404,9 @@ public class UserRepository : Repository, IUserRepository public UpdateUserData SetKeyConnectorUserKey(Guid userId, string keyConnectorWrappedUserKey) { + var protectedKeyConnectorWrappedUserKey = string.Concat(Constants.DatabaseFieldProtectedPrefix, + _dataProtector.Protect(keyConnectorWrappedUserKey)); + return async (connection, transaction) => { var timestamp = DateTime.UtcNow; @@ -413,7 +416,7 @@ public class UserRepository : Repository, IUserRepository new { Id = userId, - Key = keyConnectorWrappedUserKey, + Key = protectedKeyConnectorWrappedUserKey, // Key Connector does not use KDF, so we set some defaults Kdf = KdfType.Argon2id, KdfIterations = AuthConstants.ARGON2_ITERATIONS.Default, @@ -431,6 +434,13 @@ public class UserRepository : Repository, IUserRepository public UpdateUserData SetMasterPassword(Guid userId, MasterPasswordUnlockData masterPasswordUnlockData, string serverSideHashedMasterPasswordAuthenticationHash, string? masterPasswordHint) { + var protectedMasterKeyWrappedUserKey = string.Concat(Constants.DatabaseFieldProtectedPrefix, + _dataProtector.Protect(masterPasswordUnlockData.MasterKeyWrappedUserKey)); + + var protectedServerSideHashedMasterPasswordAuthenticationHash = string.Concat( + Constants.DatabaseFieldProtectedPrefix, + _dataProtector.Protect(serverSideHashedMasterPasswordAuthenticationHash)); + return async (connection, transaction) => { var timestamp = DateTime.UtcNow; @@ -440,9 +450,9 @@ public class UserRepository : Repository, IUserRepository new { Id = userId, - MasterPassword = serverSideHashedMasterPasswordAuthenticationHash, + MasterPassword = protectedServerSideHashedMasterPasswordAuthenticationHash, MasterPasswordHint = masterPasswordHint, - Key = masterPasswordUnlockData.MasterKeyWrappedUserKey, + Key = protectedMasterKeyWrappedUserKey, Kdf = masterPasswordUnlockData.Kdf.KdfType, KdfIterations = masterPasswordUnlockData.Kdf.Iterations, KdfMemory = masterPasswordUnlockData.Kdf.Memory, diff --git a/src/Infrastructure.Dapper/Tools/Repositories/SendRepository.cs b/src/Infrastructure.Dapper/Tools/Repositories/SendRepository.cs index 81a94f0f7c..144e08021d 100644 --- a/src/Infrastructure.Dapper/Tools/Repositories/SendRepository.cs +++ b/src/Infrastructure.Dapper/Tools/Repositories/SendRepository.cs @@ -1,6 +1,7 @@ #nullable enable using System.Data; +using Bit.Core; using Bit.Core.KeyManagement.UserKey; using Bit.Core.Settings; using Bit.Core.Tools.Entities; @@ -8,6 +9,7 @@ using Bit.Core.Tools.Repositories; using Bit.Infrastructure.Dapper.Repositories; using Bit.Infrastructure.Dapper.Tools.Helpers; using Dapper; +using Microsoft.AspNetCore.DataProtection; using Microsoft.Data.SqlClient; namespace Bit.Infrastructure.Dapper.Tools.Repositories; @@ -15,13 +17,24 @@ namespace Bit.Infrastructure.Dapper.Tools.Repositories; /// public class SendRepository : Repository, ISendRepository { - public SendRepository(GlobalSettings globalSettings) - : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + private readonly IDataProtector _dataProtector; + + public SendRepository(GlobalSettings globalSettings, IDataProtectionProvider dataProtectionProvider) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString, dataProtectionProvider) { } - public SendRepository(string connectionString, string readOnlyConnectionString) + public SendRepository(string connectionString, string readOnlyConnectionString, IDataProtectionProvider dataProtectionProvider) : base(connectionString, readOnlyConnectionString) - { } + { + _dataProtector = dataProtectionProvider.CreateProtector(Constants.DatabaseFieldProtectorPurpose); + } + + public override async Task GetByIdAsync(Guid id) + { + var send = await base.GetByIdAsync(id); + UnprotectData(send); + return send; + } /// public async Task> GetManyByUserIdAsync(Guid userId) @@ -33,7 +46,9 @@ public class SendRepository : Repository, ISendRepository new { UserId = userId }, commandType: CommandType.StoredProcedure); - return results.ToList(); + var sends = results.ToList(); + UnprotectData(sends); + return sends; } } @@ -47,15 +62,35 @@ public class SendRepository : Repository, ISendRepository new { DeletionDate = deletionDateBefore }, commandType: CommandType.StoredProcedure); - return results.ToList(); + var sends = results.ToList(); + UnprotectData(sends); + return sends; } } + public override async Task CreateAsync(Send send) + { + await ProtectDataAndSaveAsync(send, async () => await base.CreateAsync(send)); + return send; + } + + public override async Task ReplaceAsync(Send send) + { + await ProtectDataAndSaveAsync(send, async () => await base.ReplaceAsync(send)); + } + /// public UpdateEncryptedDataForKeyRotation UpdateForKeyRotation(Guid userId, IEnumerable sends) { return async (connection, transaction) => { + // Protect all sends before bulk update + var sendsList = sends.ToList(); + foreach (var send in sendsList) + { + ProtectData(send); + } + // Create temp table var sqlCreateTemp = @" SELECT TOP 0 * @@ -71,7 +106,7 @@ public class SendRepository : Repository, ISendRepository using (var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction)) { bulkCopy.DestinationTableName = "#TempSend"; - var sendsTable = sends.ToDataTable(); + var sendsTable = sendsList.ToDataTable(); foreach (DataColumn col in sendsTable.Columns) { bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); @@ -101,6 +136,69 @@ public class SendRepository : Repository, ISendRepository cmd.Parameters.Add("@UserId", SqlDbType.UniqueIdentifier).Value = userId; cmd.ExecuteNonQuery(); } + + // Unprotect after save + foreach (var send in sendsList) + { + UnprotectData(send); + } }; } + + private async Task ProtectDataAndSaveAsync(Send send, Func saveTask) + { + if (send == null) + { + await saveTask(); + return; + } + + // Capture original value + var emails = send.Emails; + + // Protect value + ProtectData(send); + + // Save + await saveTask(); + + // Restore original value + send.Emails = emails; + } + + private void ProtectData(Send send) + { + if (!send.Emails?.StartsWith(Constants.DatabaseFieldProtectedPrefix) ?? false) + { + send.Emails = string.Concat(Constants.DatabaseFieldProtectedPrefix, + _dataProtector.Protect(send.Emails!)); + } + } + + private void UnprotectData(Send? send) + { + if (send == null) + { + return; + } + + if (send.Emails?.StartsWith(Constants.DatabaseFieldProtectedPrefix) ?? false) + { + send.Emails = _dataProtector.Unprotect( + send.Emails.Substring(Constants.DatabaseFieldProtectedPrefix.Length)); + } + } + + private void UnprotectData(IEnumerable sends) + { + if (sends == null) + { + return; + } + + foreach (var send in sends) + { + UnprotectData(send); + } + } } diff --git a/src/Infrastructure.Dapper/Vault/Repositories/CipherRepository.cs b/src/Infrastructure.Dapper/Vault/Repositories/CipherRepository.cs index 48232ef484..ecf6d8e4e7 100644 --- a/src/Infrastructure.Dapper/Vault/Repositories/CipherRepository.cs +++ b/src/Infrastructure.Dapper/Vault/Repositories/CipherRepository.cs @@ -248,7 +248,7 @@ public class CipherRepository : Repository, ICipherRepository new { Ids = ids.ToGuidIdArrayTVP(), UserId = userId }, commandType: CommandType.StoredProcedure); - return results; + return DateTime.SpecifyKind(results, DateTimeKind.Utc); } } @@ -595,7 +595,7 @@ public class CipherRepository : Repository, ICipherRepository new { Ids = ids.ToGuidIdArrayTVP(), UserId = userId }, commandType: CommandType.StoredProcedure); - return results; + return DateTime.SpecifyKind(results, DateTimeKind.Utc); } } @@ -608,7 +608,7 @@ public class CipherRepository : Repository, ICipherRepository new { Ids = ids.ToGuidIdArrayTVP(), UserId = userId }, commandType: CommandType.StoredProcedure); - return results; + return DateTime.SpecifyKind(results, DateTimeKind.Utc); } } @@ -621,7 +621,7 @@ public class CipherRepository : Repository, ICipherRepository new { Ids = ids.ToGuidIdArrayTVP(), OrganizationId = organizationId }, commandType: CommandType.StoredProcedure); - return results; + return DateTime.SpecifyKind(results, DateTimeKind.Utc); } } diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs index 88410facf5..93c8cd304c 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs @@ -325,7 +325,8 @@ public class OrganizationRepository : Repository Collections)> GetDetailsByIdWithCollectionsAsync(Guid id) + public async Task<(OrganizationUserUserDetails? OrganizationUser, ICollection Collections)> GetDetailsByIdWithSharedCollectionsAsync(Guid id) { var organizationUserUserDetails = await GetDetailsByIdAsync(id); using (var scope = ServiceScopeFactory.CreateScope()) @@ -359,7 +359,7 @@ public class OrganizationUserRepository : Repository new CollectionAccessSelection { @@ -438,7 +438,7 @@ public class OrganizationUserRepository : Repository> GetManyDetailsByOrganizationAsync(Guid organizationId, bool includeGroups, bool includeCollections) + public async Task> GetManyDetailsByOrganizationAsync(Guid organizationId, bool includeGroups, bool includeSharedCollections) { using (var scope = ServiceScopeFactory.CreateScope()) { @@ -448,7 +448,7 @@ public class OrganizationUserRepository : Repository g.OrganizationUserId).ToList(); } - if (includeCollections) + if (includeSharedCollections) { collections = (await (from cu in dbContext.CollectionUsers join ou in userIdEntities on cu.OrganizationUserId equals ou.Id join c in dbContext.Collections on cu.CollectionId equals c.Id - where c.Type != CollectionType.DefaultUserCollection + where c.Type == CollectionType.SharedCollection select cu).ToListAsync()) .GroupBy(c => c.OrganizationUserId).ToList(); } @@ -506,7 +506,7 @@ public class OrganizationUserRepository : Repository> GetManyDetailsByOrganizationAsync_vNext( - Guid organizationId, bool includeGroups, bool includeCollections) + Guid organizationId, bool includeGroups, bool includeSharedCollections) { using var scope = ServiceScopeFactory.CreateScope(); var dbContext = GetDatabaseContext(scope); @@ -541,7 +541,7 @@ public class OrganizationUserRepository : Repository gu.GroupId).ToList() : new List(), - Collections = includeCollections + Collections = includeSharedCollections ? ou.CollectionUsers .Where(cu => cu.Collection.Type == CollectionType.SharedCollection) .Select(cu => new CollectionAccessSelection diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationUserReadByClaimedOrganizationDomainsQuery.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationUserReadByClaimedOrganizationDomainsQuery.cs index d328691df0..643a2c684a 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationUserReadByClaimedOrganizationDomainsQuery.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/Queries/OrganizationUserReadByClaimedOrganizationDomainsQuery.cs @@ -1,4 +1,5 @@ using Bit.Core.Entities; +using Bit.Core.Enums; namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; @@ -16,6 +17,7 @@ public class OrganizationUserReadByClaimedOrganizationDomainsQuery : IQuery od.OrganizationId == _organizationId && od.VerifiedDate != null && diff --git a/src/Infrastructure.EntityFramework/Auth/Repositories/EmergencyAccessRepository.cs b/src/Infrastructure.EntityFramework/Auth/Repositories/EmergencyAccessRepository.cs index e1ea9bc03f..66cf1e55e6 100644 --- a/src/Infrastructure.EntityFramework/Auth/Repositories/EmergencyAccessRepository.cs +++ b/src/Infrastructure.EntityFramework/Auth/Repositories/EmergencyAccessRepository.cs @@ -10,8 +10,6 @@ using Microsoft.Data.SqlClient; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; -#nullable enable - namespace Bit.Infrastructure.EntityFramework.Auth.Repositories; public class EmergencyAccessRepository : Repository, IEmergencyAccessRepository @@ -146,4 +144,23 @@ public class EmergencyAccessRepository : Repository + public async Task DeleteManyAsync(ICollection emergencyAccessIds) + { + using var scope = ServiceScopeFactory.CreateScope(); + var dbContext = GetDatabaseContext(scope); + var entitiesToRemove = from ea in dbContext.EmergencyAccesses + where emergencyAccessIds.Contains(ea.Id) + select ea; + + var granteeIds = entitiesToRemove + .Where(ea => ea.Status == EmergencyAccessStatusType.Confirmed) + .Where(ea => ea.GranteeId.HasValue) + .Select(ea => ea.GranteeId!.Value) // .Value is safe here due to the Where above + .Distinct(); + + dbContext.EmergencyAccesses.RemoveRange(entitiesToRemove); + await dbContext.UserBumpManyAccountRevisionDatesAsync([.. granteeIds]); + await dbContext.SaveChangesAsync(); + } } diff --git a/src/Infrastructure.EntityFramework/Auth/Repositories/Queries/EmergencyAccessDetailsViewQuery.cs b/src/Infrastructure.EntityFramework/Auth/Repositories/Queries/EmergencyAccessDetailsViewQuery.cs index d666df76cf..7ddbcc346a 100644 --- a/src/Infrastructure.EntityFramework/Auth/Repositories/Queries/EmergencyAccessDetailsViewQuery.cs +++ b/src/Infrastructure.EntityFramework/Auth/Repositories/Queries/EmergencyAccessDetailsViewQuery.cs @@ -2,8 +2,6 @@ using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Infrastructure.EntityFramework.Repositories.Queries; -#nullable enable - namespace Bit.Infrastructure.EntityFramework.Auth.Repositories.Queries; public class EmergencyAccessDetailsViewQuery : IQuery diff --git a/src/Infrastructure.EntityFramework/Billing/Configurations/SubscriptionDiscountEntityTypeConfiguration.cs b/src/Infrastructure.EntityFramework/Billing/Configurations/SubscriptionDiscountEntityTypeConfiguration.cs new file mode 100644 index 0000000000..c2c7eb86f2 --- /dev/null +++ b/src/Infrastructure.EntityFramework/Billing/Configurations/SubscriptionDiscountEntityTypeConfiguration.cs @@ -0,0 +1,42 @@ +using System.Text.Json; +using Bit.Infrastructure.EntityFramework.Billing.Models; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.ChangeTracking; +using Microsoft.EntityFrameworkCore.Metadata.Builders; + +namespace Bit.Infrastructure.EntityFramework.Billing.Configurations; + +public class SubscriptionDiscountEntityTypeConfiguration : IEntityTypeConfiguration +{ + public void Configure(EntityTypeBuilder builder) + { + builder + .Property(t => t.Id) + .ValueGeneratedNever(); + + builder + .HasIndex(sd => sd.StripeCouponId) + .IsUnique(); + + builder + .Property(sd => sd.StripeProductIds) + .HasConversion( + v => v == null ? null : JsonSerializer.Serialize(v, (JsonSerializerOptions?)null), + v => v == null ? null : JsonSerializer.Deserialize>(v, (JsonSerializerOptions?)null), + new ValueComparer?>( + (c1, c2) => (c1 == null && c2 == null) || (c1 != null && c2 != null && c1.SequenceEqual(c2)), + c => c == null ? 0 : c.Aggregate(0, (a, v) => HashCode.Combine(a, v.GetHashCode())), + c => c == null ? null : c.ToList())); + + builder + .Property(sd => sd.PercentOff) + .HasPrecision(5, 2); + + builder + .HasIndex(sd => new { sd.StartDate, sd.EndDate }) + .IsClustered(false) + .HasDatabaseName("IX_SubscriptionDiscount_DateRange"); + + builder.ToTable(nameof(SubscriptionDiscount)); + } +} diff --git a/src/Infrastructure.EntityFramework/Billing/Models/SubscriptionDiscount.cs b/src/Infrastructure.EntityFramework/Billing/Models/SubscriptionDiscount.cs new file mode 100644 index 0000000000..0cb5d9532a --- /dev/null +++ b/src/Infrastructure.EntityFramework/Billing/Models/SubscriptionDiscount.cs @@ -0,0 +1,18 @@ +#nullable enable + +using AutoMapper; + +namespace Bit.Infrastructure.EntityFramework.Billing.Models; + +// ReSharper disable once ClassWithVirtualMembersNeverInherited.Global +public class SubscriptionDiscount : Core.Billing.Subscriptions.Entities.SubscriptionDiscount +{ +} + +public class SubscriptionDiscountMapperProfile : Profile +{ + public SubscriptionDiscountMapperProfile() + { + CreateMap().ReverseMap(); + } +} diff --git a/src/Infrastructure.EntityFramework/Billing/Repositories/SubscriptionDiscountRepository.cs b/src/Infrastructure.EntityFramework/Billing/Repositories/SubscriptionDiscountRepository.cs new file mode 100644 index 0000000000..6ddcd65f27 --- /dev/null +++ b/src/Infrastructure.EntityFramework/Billing/Repositories/SubscriptionDiscountRepository.cs @@ -0,0 +1,51 @@ +using AutoMapper; +using Bit.Core.Billing.Subscriptions.Entities; +using Bit.Core.Billing.Subscriptions.Repositories; +using Bit.Infrastructure.EntityFramework.Repositories; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; +using EFSubscriptionDiscount = Bit.Infrastructure.EntityFramework.Billing.Models.SubscriptionDiscount; + +namespace Bit.Infrastructure.EntityFramework.Billing.Repositories; + +public class SubscriptionDiscountRepository( + IMapper mapper, + IServiceScopeFactory serviceScopeFactory) + : Repository( + serviceScopeFactory, + mapper, + context => context.SubscriptionDiscounts), ISubscriptionDiscountRepository +{ + public async Task> GetActiveDiscountsAsync() + { + using var serviceScope = ServiceScopeFactory.CreateScope(); + + var databaseContext = GetDatabaseContext(serviceScope); + + var query = + from subscriptionDiscount in databaseContext.SubscriptionDiscounts + where subscriptionDiscount.StartDate <= DateTime.UtcNow + && subscriptionDiscount.EndDate >= DateTime.UtcNow + select subscriptionDiscount; + + var results = await query.ToArrayAsync(); + + return Mapper.Map>(results); + } + + public async Task GetByStripeCouponIdAsync(string stripeCouponId) + { + using var serviceScope = ServiceScopeFactory.CreateScope(); + + var databaseContext = GetDatabaseContext(serviceScope); + + var query = + from subscriptionDiscount in databaseContext.SubscriptionDiscounts + where subscriptionDiscount.StripeCouponId == stripeCouponId + select subscriptionDiscount; + + var result = await query.FirstOrDefaultAsync(); + + return result == null ? null : Mapper.Map(result); + } +} diff --git a/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs b/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs index 320cb9436d..84a370b723 100644 --- a/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs +++ b/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs @@ -2,6 +2,7 @@ using Bit.Core.Auth.Repositories; using Bit.Core.Billing.Organizations.Repositories; using Bit.Core.Billing.Providers.Repositories; +using Bit.Core.Billing.Subscriptions.Repositories; using Bit.Core.Dirt.Reports.Repositories; using Bit.Core.Dirt.Repositories; using Bit.Core.Enums; @@ -102,6 +103,7 @@ public static class EntityFrameworkServiceCollectionExtensions services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); services diff --git a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs index 5aa156d1f8..141928b78a 100644 --- a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs @@ -1,11 +1,10 @@ using AutoMapper; +using Bit.Core.AdminConsole.OrganizationFeatures.Collections; using Bit.Core.Enums; using Bit.Core.Models.Data; using Bit.Core.Repositories; -using Bit.Core.Utilities; using Bit.Infrastructure.EntityFramework.Models; using Bit.Infrastructure.EntityFramework.Repositories.Queries; -using LinqToDB.EntityFrameworkCore; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; @@ -304,7 +303,7 @@ public class CollectionRepository : Repository> GetManyByOrganizationIdWithPermissionsAsync( + public async Task> GetManySharedByOrganizationIdWithPermissionsAsync( Guid organizationId, Guid userId, bool includeAccessRelationships) { using (var scope = ServiceScopeFactory.CreateScope()) @@ -794,7 +793,7 @@ public class CollectionRepository : Repository organizationUserIds, string defaultCollectionName) + public async Task CreateDefaultCollectionsAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName) { organizationUserIds = organizationUserIds.ToList(); if (!organizationUserIds.Any()) @@ -808,15 +807,15 @@ public class CollectionRepository : Repository>(collections)); + await dbContext.CollectionUsers.AddRangeAsync(Mapper.Map>(collectionUsers)); await dbContext.SaveChangesAsync(); } @@ -844,37 +843,7 @@ public class CollectionRepository : Repository collectionUser, List collection) BuildDefaultCollectionForUsers(Guid organizationId, IEnumerable missingDefaultCollectionUserIds, string defaultCollectionName) - { - var collectionUsers = new List(); - var collections = new List(); - - foreach (var orgUserId in missingDefaultCollectionUserIds) - { - var collectionId = CoreHelpers.GenerateComb(); - - collections.Add(new Collection - { - Id = collectionId, - OrganizationId = organizationId, - Name = defaultCollectionName, - CreationDate = DateTime.UtcNow, - RevisionDate = DateTime.UtcNow, - Type = CollectionType.DefaultUserCollection, - DefaultUserCollectionEmail = null - - }); - - collectionUsers.Add(new CollectionUser - { - CollectionId = collectionId, - OrganizationUserId = orgUserId, - ReadOnly = false, - HidePasswords = false, - Manage = true, - }); - } - - return (collectionUsers, collections); - } + public Task CreateDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable organizationUserIds, + string defaultCollectionName) => + CreateDefaultCollectionsAsync(organizationId, organizationUserIds, defaultCollectionName); } diff --git a/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs b/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs index 7b67a63912..025fae802b 100644 --- a/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs +++ b/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs @@ -17,8 +17,6 @@ using Microsoft.EntityFrameworkCore.Storage.ValueConversion; using DP = Microsoft.AspNetCore.DataProtection; -#nullable enable - namespace Bit.Infrastructure.EntityFramework.Repositories; public class DatabaseContext : DbContext @@ -81,6 +79,7 @@ public class DatabaseContext : DbContext public DbSet WebAuthnCredentials { get; set; } public DbSet ProviderPlans { get; set; } public DbSet ProviderInvoiceItems { get; set; } + public DbSet SubscriptionDiscounts { get; set; } public DbSet Notifications { get; set; } public DbSet NotificationStatuses { get; set; } public DbSet ClientOrganizationMigrationRecords { get; set; } @@ -121,6 +120,7 @@ public class DatabaseContext : DbContext var eOrganizationDomain = builder.Entity(); var aWebAuthnCredential = builder.Entity(); var eOrganizationMemberBaseDetail = builder.Entity(); + var eSend = builder.Entity(); // Shadow property configurations go here @@ -150,6 +150,7 @@ public class DatabaseContext : DbContext var dataProtectionConverter = new DataProtectionConverter(dataProtector); eUser.Property(c => c.Key).HasConversion(dataProtectionConverter); eUser.Property(c => c.MasterPassword).HasConversion(dataProtectionConverter); + eSend.Property(c => c.Emails).HasConversion(dataProtectionConverter); if (Database.IsNpgsql()) { diff --git a/src/Infrastructure.EntityFramework/Repositories/DatabaseContextExtensions.cs b/src/Infrastructure.EntityFramework/Repositories/DatabaseContextExtensions.cs index 40f2a79887..8c15cc17fa 100644 --- a/src/Infrastructure.EntityFramework/Repositories/DatabaseContextExtensions.cs +++ b/src/Infrastructure.EntityFramework/Repositories/DatabaseContextExtensions.cs @@ -1,14 +1,10 @@ -#nullable enable - -using System.Diagnostics; +using System.Diagnostics; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.Auth.Enums; using Bit.Core.Enums; using Bit.Infrastructure.EntityFramework.Repositories.Queries; using Microsoft.EntityFrameworkCore; -#nullable enable - namespace Bit.Infrastructure.EntityFramework.Repositories; public static class DatabaseContextExtensions diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionAdminDetailsQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionAdminDetailsQuery.cs index 2b6e61d056..2ec671d20b 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionAdminDetailsQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/CollectionAdminDetailsQuery.cs @@ -62,7 +62,7 @@ public class CollectionAdminDetailsQuery : IQuery { baseCollectionQuery = baseCollectionQuery.Where(x => x.c.OrganizationId == _organizationId && - x.c.Type != CollectionType.DefaultUserCollection); + x.c.Type == CollectionType.SharedCollection); } else if (_collectionId.HasValue) { diff --git a/src/Notifications/Notifications.csproj b/src/Notifications/Notifications.csproj index 4d19f7faf9..76278fdea8 100644 --- a/src/Notifications/Notifications.csproj +++ b/src/Notifications/Notifications.csproj @@ -1,4 +1,5 @@  + bitwarden-Notifications diff --git a/src/Notifications/Program.cs b/src/Notifications/Program.cs index 2792391729..ec7ea67fda 100644 --- a/src/Notifications/Program.cs +++ b/src/Notifications/Program.cs @@ -8,7 +8,7 @@ public class Program { Host .CreateDefaultBuilder(args) - .ConfigureCustomAppConfiguration(args) + .UseBitwardenSdk() .ConfigureWebHostDefaults(webBuilder => { webBuilder.UseStartup(); diff --git a/src/Notifications/Startup.cs b/src/Notifications/Startup.cs index 65904ea698..3a4dc2d447 100644 --- a/src/Notifications/Startup.cs +++ b/src/Notifications/Startup.cs @@ -5,7 +5,6 @@ using Bit.Core.Utilities; using Bit.SharedWeb.Utilities; using Duende.IdentityModel; using Microsoft.AspNetCore.SignalR; -using Microsoft.IdentityModel.Logging; namespace Bit.Notifications; @@ -84,8 +83,6 @@ public class Startup IWebHostEnvironment env, GlobalSettings globalSettings) { - IdentityModelEventSource.ShowPII = true; - // Add general security headers app.UseMiddleware(); diff --git a/src/Notifications/appsettings.Production.json b/src/Notifications/appsettings.Production.json index 010f02f8cd..735c70e481 100644 --- a/src/Notifications/appsettings.Production.json +++ b/src/Notifications/appsettings.Production.json @@ -17,11 +17,9 @@ } }, "Logging": { - "IncludeScopes": false, "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" + "Default": "Information", + "Microsoft": "Warning" }, "Console": { "IncludeScopes": true, diff --git a/src/SharedWeb/SharedWeb.csproj b/src/SharedWeb/SharedWeb.csproj index b6036845b0..17121913c1 100644 --- a/src/SharedWeb/SharedWeb.csproj +++ b/src/SharedWeb/SharedWeb.csproj @@ -13,7 +13,7 @@ - + diff --git a/src/SharedWeb/Swagger/ActionNameOperationFilter.cs b/src/SharedWeb/Swagger/ActionNameOperationFilter.cs index b76e8864ba..23602ca495 100644 --- a/src/SharedWeb/Swagger/ActionNameOperationFilter.cs +++ b/src/SharedWeb/Swagger/ActionNameOperationFilter.cs @@ -1,6 +1,5 @@ using System.Text.Json; -using Microsoft.OpenApi.Any; -using Microsoft.OpenApi.Models; +using Microsoft.OpenApi; using Swashbuckle.AspNetCore.SwaggerGen; namespace Bit.SharedWeb.Swagger; @@ -18,8 +17,9 @@ public class ActionNameOperationFilter : IOperationFilter if (!context.ApiDescription.ActionDescriptor.RouteValues.TryGetValue("action", out var action)) return; if (string.IsNullOrEmpty(action)) return; - operation.Extensions.Add("x-action-name", new OpenApiString(action)); + operation.Extensions ??= new Dictionary(); + operation.Extensions.Add("x-action-name", new JsonNodeExtension(action)); // We can't do case changes in the codegen templates, so we also add the snake_case version of the action name - operation.Extensions.Add("x-action-name-snake-case", new OpenApiString(JsonNamingPolicy.SnakeCaseLower.ConvertName(action))); + operation.Extensions.Add("x-action-name-snake-case", new JsonNodeExtension(JsonNamingPolicy.SnakeCaseLower.ConvertName(action))); } } diff --git a/src/SharedWeb/Swagger/CheckDuplicateOperationIdsDocumentFilter.cs b/src/SharedWeb/Swagger/CheckDuplicateOperationIdsDocumentFilter.cs index 3079a9171a..97cd04ac43 100644 --- a/src/SharedWeb/Swagger/CheckDuplicateOperationIdsDocumentFilter.cs +++ b/src/SharedWeb/Swagger/CheckDuplicateOperationIdsDocumentFilter.cs @@ -1,4 +1,4 @@ -using Microsoft.OpenApi.Models; +using Microsoft.OpenApi; using Swashbuckle.AspNetCore.SwaggerGen; namespace Bit.SharedWeb.Swagger; @@ -15,19 +15,22 @@ public class CheckDuplicateOperationIdsDocumentFilter(bool printDuplicates = tru public void Apply(OpenApiDocument swaggerDoc, DocumentFilterContext context) { - var operationIdMap = new Dictionary>(); + var operationIdMap = new Dictionary>(); foreach (var (path, pathItem) in swaggerDoc.Paths) { - foreach (var operation in pathItem.Operations) + if (pathItem.Operations is null) continue; + + foreach (var (method, operation) in pathItem.Operations) { - if (!operationIdMap.TryGetValue(operation.Value.OperationId, out var list)) + var operationId = operation.OperationId ?? string.Empty; + if (!operationIdMap.TryGetValue(operationId, out var list)) { list = []; - operationIdMap[operation.Value.OperationId] = list; + operationIdMap[operationId] = list; } - list.Add((path, pathItem, operation.Key, operation.Value)); + list.Add((path, pathItem, method, operation)); } } @@ -57,11 +60,15 @@ public class CheckDuplicateOperationIdsDocumentFilter(bool printDuplicates = tru { Console.Write($" {method.ToString().ToUpper()} {path}"); + if (operation.Extensions is null) continue; - if (operation.Extensions.TryGetValue("x-source-file", out var sourceFile) && operation.Extensions.TryGetValue("x-source-line", out var sourceLine)) + if (operation.Extensions.TryGetValue("x-source-file", out var sourceFile) + && operation.Extensions.TryGetValue("x-source-line", out var sourceLine) + && sourceFile is JsonNodeExtension sourceFileNodeExt + && sourceLine is JsonNodeExtension sourceLineNodeExt) { - var sourceFileString = ((Microsoft.OpenApi.Any.OpenApiString)sourceFile).Value; - var sourceLineString = ((Microsoft.OpenApi.Any.OpenApiInteger)sourceLine).Value; + var sourceFileString = sourceFileNodeExt.Node.ToString(); + var sourceLineString = sourceLineNodeExt.Node.ToString(); Console.WriteLine($" {sourceFileString}:{sourceLineString}"); } diff --git a/src/SharedWeb/Swagger/EncryptedStringSchemaFilter.cs b/src/SharedWeb/Swagger/EncryptedStringSchemaFilter.cs index d26ae58e59..aade750205 100644 --- a/src/SharedWeb/Swagger/EncryptedStringSchemaFilter.cs +++ b/src/SharedWeb/Swagger/EncryptedStringSchemaFilter.cs @@ -2,7 +2,7 @@ using System.Text.Json; using Bit.Core.Utilities; -using Microsoft.OpenApi.Models; +using Microsoft.OpenApi; using Swashbuckle.AspNetCore.SwaggerGen; namespace Bit.SharedWeb.Swagger; @@ -13,7 +13,7 @@ namespace Bit.SharedWeb.Swagger; /// public class EncryptedStringSchemaFilter : ISchemaFilter { - public void Apply(OpenApiSchema schema, SchemaFilterContext context) + public void Apply(IOpenApiSchema schema, SchemaFilterContext context) { if (context.Type == null || schema.Properties == null) return; @@ -30,9 +30,9 @@ public class EncryptedStringSchemaFilter : ISchemaFilter // Convert prop.Name to camelCase for JSON schema property lookup var jsonPropName = JsonNamingPolicy.CamelCase.ConvertName(prop.Name); - if (schema.Properties.TryGetValue(jsonPropName, out var value)) + if (schema.Properties.TryGetValue(jsonPropName, out var value) && value is OpenApiSchema innerSchema) { - value.Format = "x-enc-string"; + innerSchema.Format = "x-enc-string"; } } } diff --git a/src/SharedWeb/Swagger/EnumSchemaFilter.cs b/src/SharedWeb/Swagger/EnumSchemaFilter.cs index 301fbfeca8..1b6fd6df57 100644 --- a/src/SharedWeb/Swagger/EnumSchemaFilter.cs +++ b/src/SharedWeb/Swagger/EnumSchemaFilter.cs @@ -1,5 +1,5 @@ -using Microsoft.OpenApi.Any; -using Microsoft.OpenApi.Models; +using System.Text.Json.Nodes; +using Microsoft.OpenApi; using Swashbuckle.AspNetCore.SwaggerGen; namespace Bit.SharedWeb.Swagger; @@ -14,13 +14,15 @@ namespace Bit.SharedWeb.Swagger; /// public class EnumSchemaFilter : ISchemaFilter { - public void Apply(OpenApiSchema schema, SchemaFilterContext context) + public void Apply(IOpenApiSchema schema, SchemaFilterContext context) { - if (context.Type.IsEnum) + if (context.Type.IsEnum && schema is OpenApiSchema openApiSchema) { - var array = new OpenApiArray(); - array.AddRange(Enum.GetNames(context.Type).Select(n => new OpenApiString(n))); - schema.Extensions.Add("x-enum-varnames", array); + var array = new JsonArray(); + foreach (var name in Enum.GetNames(context.Type)) array.Add(name); + + openApiSchema.Extensions ??= new Dictionary(); + openApiSchema.Extensions.Add("x-enum-varnames", new JsonNodeExtension(array)); } } } diff --git a/src/SharedWeb/Swagger/GitCommitDocumentFilter.cs b/src/SharedWeb/Swagger/GitCommitDocumentFilter.cs index 86678722ce..fe51c5e588 100644 --- a/src/SharedWeb/Swagger/GitCommitDocumentFilter.cs +++ b/src/SharedWeb/Swagger/GitCommitDocumentFilter.cs @@ -1,7 +1,7 @@ #nullable enable using System.Diagnostics; -using Microsoft.OpenApi.Models; +using Microsoft.OpenApi; using Swashbuckle.AspNetCore.SwaggerGen; namespace Bit.SharedWeb.Swagger; @@ -16,7 +16,8 @@ public class GitCommitDocumentFilter : IDocumentFilter { if (!string.IsNullOrEmpty(GitCommit)) { - swaggerDoc.Extensions.Add("x-git-commit", new Microsoft.OpenApi.Any.OpenApiString(GitCommit)); + swaggerDoc.Extensions ??= new Dictionary(); + swaggerDoc.Extensions.Add("x-git-commit", new JsonNodeExtension(GitCommit)); } } diff --git a/src/SharedWeb/Swagger/SourceFileLineOperationFilter.cs b/src/SharedWeb/Swagger/SourceFileLineOperationFilter.cs index 68c0b5145a..b88744028f 100644 --- a/src/SharedWeb/Swagger/SourceFileLineOperationFilter.cs +++ b/src/SharedWeb/Swagger/SourceFileLineOperationFilter.cs @@ -4,8 +4,7 @@ using System.Reflection; using System.Reflection.Metadata; using System.Reflection.Metadata.Ecma335; using System.Runtime.CompilerServices; -using Microsoft.OpenApi.Any; -using Microsoft.OpenApi.Models; +using Microsoft.OpenApi; using Swashbuckle.AspNetCore.SwaggerGen; namespace Bit.SharedWeb.Swagger; @@ -24,8 +23,9 @@ public class SourceFileLineOperationFilter : IOperationFilter if (fileName != null && lineNumber > 0) { // Also add the information as extensions, so other tools can use it in the future - operation.Extensions.Add("x-source-file", new OpenApiString(fileName)); - operation.Extensions.Add("x-source-line", new OpenApiInteger(lineNumber)); + operation.Extensions ??= new Dictionary(); + operation.Extensions.Add("x-source-file", new JsonNodeExtension(fileName)); + operation.Extensions.Add("x-source-line", new JsonNodeExtension(lineNumber)); } } diff --git a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs index 1bb9cb6c7a..4139b59fa5 100644 --- a/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs +++ b/src/SharedWeb/Utilities/ServiceCollectionExtensions.cs @@ -22,6 +22,7 @@ using Bit.Core.Auth.Repositories; using Bit.Core.Auth.Services; using Bit.Core.Auth.Services.Implementations; using Bit.Core.Auth.UserFeatures; +using Bit.Core.Auth.UserFeatures.EmergencyAccess; using Bit.Core.Auth.UserFeatures.PasswordValidation; using Bit.Core.Billing.Services; using Bit.Core.Billing.Services.Implementations; @@ -79,7 +80,7 @@ using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -using Microsoft.OpenApi.Models; +using Microsoft.OpenApi; using StackExchange.Redis; using Swashbuckle.AspNetCore.SwaggerGen; using NoopRepos = Bit.Core.Repositories.Noop; @@ -471,11 +472,6 @@ public static class ServiceCollectionExtensions addAuthorization.Invoke(config); }); } - - if (environment.IsDevelopment()) - { - Microsoft.IdentityModel.Logging.IdentityModelEventSource.ShowPII = true; - } } public static void AddCustomDataProtectionServices( @@ -665,7 +661,6 @@ public static class ServiceCollectionExtensions Constants.BrowserExtensions.OperaId }; } - }); } @@ -852,19 +847,9 @@ public static class ServiceCollectionExtensions }); // Add security requirement - config.AddSecurityRequirement(new OpenApiSecurityRequirement + config.AddSecurityRequirement((document) => new OpenApiSecurityRequirement { - { - new OpenApiSecurityScheme - { - Reference = new OpenApiReference - { - Type = ReferenceType.SecurityScheme, - Id = serverId - }, - }, - [ApiScopes.ApiOrganization] - } + [new OpenApiSecuritySchemeReference(serverId, document)] = [ApiScopes.ApiOrganization] }); } } diff --git a/src/Sql/dbo/AdminConsole/Stored Procedures/Collection_CreateDefaultCollections.sql b/src/Sql/dbo/AdminConsole/Stored Procedures/Collection_CreateDefaultCollections.sql new file mode 100644 index 0000000000..4e671bd1e4 --- /dev/null +++ b/src/Sql/dbo/AdminConsole/Stored Procedures/Collection_CreateDefaultCollections.sql @@ -0,0 +1,69 @@ +-- Creates default user collections for organization users +-- Filters out existing default collections at database level +CREATE PROCEDURE [dbo].[Collection_CreateDefaultCollections] + @OrganizationId UNIQUEIDENTIFIER, + @DefaultCollectionName VARCHAR(MAX), + @OrganizationUserCollectionIds AS [dbo].[TwoGuidIdArray] READONLY -- OrganizationUserId, CollectionId +AS +BEGIN + SET NOCOUNT ON + + DECLARE @Now DATETIME2(7) = GETUTCDATE() + + -- Filter to only users who don't have default collections + SELECT ids.Id1, ids.Id2 + INTO #FilteredIds + FROM @OrganizationUserCollectionIds ids + WHERE NOT EXISTS ( + SELECT 1 + FROM [dbo].[CollectionUser] cu + INNER JOIN [dbo].[Collection] c ON c.Id = cu.CollectionId + WHERE c.OrganizationId = @OrganizationId + AND c.[Type] = 1 -- CollectionType.DefaultUserCollection + AND cu.OrganizationUserId = ids.Id1 + ); + + -- Insert collections only for users who don't have default collections yet + INSERT INTO [dbo].[Collection] + ( + [Id], + [OrganizationId], + [Name], + [CreationDate], + [RevisionDate], + [Type], + [ExternalId], + [DefaultUserCollectionEmail] + ) + SELECT + ids.Id2, -- CollectionId + @OrganizationId, + @DefaultCollectionName, + @Now, + @Now, + 1, -- CollectionType.DefaultUserCollection + NULL, + NULL + FROM + #FilteredIds ids; + + -- Insert collection user mappings + INSERT INTO [dbo].[CollectionUser] + ( + [CollectionId], + [OrganizationUserId], + [ReadOnly], + [HidePasswords], + [Manage] + ) + SELECT + ids.Id2, -- CollectionId + ids.Id1, -- OrganizationUserId + 0, -- ReadOnly = false + 0, -- HidePasswords = false + 1 -- Manage = true + FROM + #FilteredIds ids; + + DROP TABLE #FilteredIds; +END diff --git a/src/Sql/dbo/Auth/Stored Procedures/EmergencyAccess_DeleteManyById.sql b/src/Sql/dbo/Auth/Stored Procedures/EmergencyAccess_DeleteManyById.sql new file mode 100644 index 0000000000..75677ebbd9 --- /dev/null +++ b/src/Sql/dbo/Auth/Stored Procedures/EmergencyAccess_DeleteManyById.sql @@ -0,0 +1,41 @@ +CREATE PROCEDURE [dbo].[EmergencyAccess_DeleteManyById] + @EmergencyAccessIds [dbo].[GuidIdArray] READONLY +AS +BEGIN + SET NOCOUNT ON + + DECLARE @UserIds AS [GuidIdArray]; + DECLARE @BatchSize INT = 100 + + INSERT INTO @UserIds + SELECT DISTINCT + [GranteeId] + FROM + [dbo].[EmergencyAccess] EA + INNER JOIN + @EmergencyAccessIds EAI ON EAI.[Id] = EA.[Id] + WHERE + EA.[Status] = 2 -- 2 = Bit.Core.Auth.Enums.EmergencyAccessStatusType.Confirmed + AND + EA.[GranteeId] IS NOT NULL + + + -- Delete EmergencyAccess Records + WHILE @BatchSize > 0 + BEGIN + + DELETE TOP(@BatchSize) EA + FROM + [dbo].[EmergencyAccess] EA + INNER JOIN + @EmergencyAccessIds EAI ON EAI.[Id] = EA.[Id] + + SET @BatchSize = @@ROWCOUNT + + END + + -- Bump AccountRevisionDate for affected users after deletions + Exec [dbo].[User_BumpManyAccountRevisionDates] @UserIds + +END +GO diff --git a/src/Sql/dbo/Stored Procedures/CollectionUser_ReadSharedCollectionsByOrganizationUserIds.sql b/src/Sql/dbo/Stored Procedures/CollectionUser_ReadSharedCollectionsByOrganizationUserIds.sql new file mode 100644 index 0000000000..55cd477ab0 --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/CollectionUser_ReadSharedCollectionsByOrganizationUserIds.sql @@ -0,0 +1,19 @@ +CREATE PROCEDURE [dbo].[CollectionUser_ReadSharedCollectionsByOrganizationUserIds] + @OrganizationUserIds [dbo].[GuidIdArray] READONLY +AS +BEGIN + SET NOCOUNT ON + + SELECT + CU.* + FROM + [dbo].[OrganizationUser] OU + INNER JOIN + [dbo].[CollectionUser] CU ON CU.[OrganizationUserId] = OU.[Id] + INNER JOIN + [dbo].[Collection] C ON CU.[CollectionId] = C.[Id] + INNER JOIN + @OrganizationUserIds OUI ON OUI.[Id] = OU.[Id] + WHERE + C.[Type] = 0 -- Only SharedCollection +END diff --git a/src/Sql/dbo/Stored Procedures/OrganizationUserUserDetails_ReadWithSharedCollectionsById.sql b/src/Sql/dbo/Stored Procedures/OrganizationUserUserDetails_ReadWithSharedCollectionsById.sql new file mode 100644 index 0000000000..96db73c80a --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/OrganizationUserUserDetails_ReadWithSharedCollectionsById.sql @@ -0,0 +1,23 @@ +CREATE PROCEDURE [dbo].[OrganizationUserUserDetails_ReadWithSharedCollectionsById] + @Id UNIQUEIDENTIFIER +AS +BEGIN + SET NOCOUNT ON + + EXEC [OrganizationUserUserDetails_ReadById] @Id + + SELECT + CU.[CollectionId] Id, + CU.[ReadOnly], + CU.[HidePasswords], + CU.[Manage] + FROM + [dbo].[OrganizationUser] OU + INNER JOIN + [dbo].[CollectionUser] CU ON CU.[OrganizationUserId] = [OU].[Id] + INNER JOIN + [dbo].[Collection] C ON CU.[CollectionId] = C.[Id] + WHERE + [OrganizationUserId] = @Id + AND C.[Type] = 0 -- Only SharedCollection +END diff --git a/src/Sql/dbo/Stored Procedures/OrganizationUser_ReadByOrganizationIdWithClaimedDomains_V2.sql b/src/Sql/dbo/Stored Procedures/OrganizationUser_ReadByOrganizationIdWithClaimedDomains_V2.sql index 64f3d81e08..4f781d2cc9 100644 --- a/src/Sql/dbo/Stored Procedures/OrganizationUser_ReadByOrganizationIdWithClaimedDomains_V2.sql +++ b/src/Sql/dbo/Stored Procedures/OrganizationUser_ReadByOrganizationIdWithClaimedDomains_V2.sql @@ -8,13 +8,14 @@ BEGIN SELECT * FROM [dbo].[OrganizationUserView] WHERE [OrganizationId] = @OrganizationId + AND [Status] != 0 -- Exclude invited users ), UserDomains AS ( SELECT U.[Id], U.[EmailDomain] FROM [dbo].[UserEmailDomainView] U WHERE EXISTS ( SELECT 1 - FROM [dbo].[OrganizationDomainView] OD + FROM [dbo].[OrganizationDomainView] OD WHERE OD.[OrganizationId] = @OrganizationId AND OD.[VerifiedDate] IS NOT NULL AND OD.[DomainName] = U.[EmailDomain] diff --git a/src/Sql/dbo/Stored Procedures/Organization_ReadByClaimedUserEmailDomain.sql b/src/Sql/dbo/Stored Procedures/Organization_ReadByClaimedUserEmailDomain.sql index 583f548c8b..ee14c2c52a 100644 --- a/src/Sql/dbo/Stored Procedures/Organization_ReadByClaimedUserEmailDomain.sql +++ b/src/Sql/dbo/Stored Procedures/Organization_ReadByClaimedUserEmailDomain.sql @@ -6,7 +6,7 @@ BEGIN WITH CTE_User AS ( SELECT - U.*, + U.[Id], SUBSTRING(U.Email, CHARINDEX('@', U.Email) + 1, LEN(U.Email)) AS EmailDomain FROM dbo.[UserView] U WHERE U.[Id] = @UserId @@ -19,4 +19,5 @@ BEGIN WHERE OD.[VerifiedDate] IS NOT NULL AND CU.EmailDomain = OD.[DomainName] AND O.[Enabled] = 1 + AND OU.[Status] != 0 -- Exclude invited users END diff --git a/src/Sql/dbo/Stored Procedures/SubscriptionDiscount_Create.sql b/src/Sql/dbo/Stored Procedures/SubscriptionDiscount_Create.sql new file mode 100644 index 0000000000..ad91dea631 --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/SubscriptionDiscount_Create.sql @@ -0,0 +1,54 @@ +CREATE PROCEDURE [dbo].[SubscriptionDiscount_Create] + @Id UNIQUEIDENTIFIER OUTPUT, + @StripeCouponId VARCHAR(50), + @StripeProductIds NVARCHAR(MAX), + @PercentOff DECIMAL(5,2), + @AmountOff BIGINT, + @Currency VARCHAR(10), + @Duration VARCHAR(20), + @DurationInMonths INT, + @Name NVARCHAR(100), + @StartDate DATETIME2(7), + @EndDate DATETIME2(7), + @AudienceType INT, + @CreationDate DATETIME2(7), + @RevisionDate DATETIME2(7) +AS +BEGIN + SET NOCOUNT ON + + INSERT INTO [dbo].[SubscriptionDiscount] + ( + [Id], + [StripeCouponId], + [StripeProductIds], + [PercentOff], + [AmountOff], + [Currency], + [Duration], + [DurationInMonths], + [Name], + [StartDate], + [EndDate], + [AudienceType], + [CreationDate], + [RevisionDate] + ) + VALUES + ( + @Id, + @StripeCouponId, + @StripeProductIds, + @PercentOff, + @AmountOff, + @Currency, + @Duration, + @DurationInMonths, + @Name, + @StartDate, + @EndDate, + @AudienceType, + @CreationDate, + @RevisionDate + ) +END diff --git a/src/Sql/dbo/Stored Procedures/SubscriptionDiscount_DeleteById.sql b/src/Sql/dbo/Stored Procedures/SubscriptionDiscount_DeleteById.sql new file mode 100644 index 0000000000..8d44a4b098 --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/SubscriptionDiscount_DeleteById.sql @@ -0,0 +1,12 @@ +CREATE PROCEDURE [dbo].[SubscriptionDiscount_DeleteById] + @Id UNIQUEIDENTIFIER +AS +BEGIN + SET NOCOUNT ON + + DELETE + FROM + [dbo].[SubscriptionDiscount] + WHERE + [Id] = @Id +END diff --git a/src/Sql/dbo/Stored Procedures/SubscriptionDiscount_ReadActive.sql b/src/Sql/dbo/Stored Procedures/SubscriptionDiscount_ReadActive.sql new file mode 100644 index 0000000000..6247492789 --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/SubscriptionDiscount_ReadActive.sql @@ -0,0 +1,13 @@ +CREATE PROCEDURE [dbo].[SubscriptionDiscount_ReadActive] +AS +BEGIN + SET NOCOUNT ON + + SELECT + * + FROM + [dbo].[SubscriptionDiscountView] + WHERE + [StartDate] <= GETUTCDATE() + AND [EndDate] >= GETUTCDATE() +END diff --git a/src/Sql/dbo/Stored Procedures/SubscriptionDiscount_ReadById.sql b/src/Sql/dbo/Stored Procedures/SubscriptionDiscount_ReadById.sql new file mode 100644 index 0000000000..88943def64 --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/SubscriptionDiscount_ReadById.sql @@ -0,0 +1,13 @@ +CREATE PROCEDURE [dbo].[SubscriptionDiscount_ReadById] + @Id UNIQUEIDENTIFIER +AS +BEGIN + SET NOCOUNT ON + + SELECT + * + FROM + [dbo].[SubscriptionDiscountView] + WHERE + [Id] = @Id +END diff --git a/src/Sql/dbo/Stored Procedures/SubscriptionDiscount_ReadByStripeCouponId.sql b/src/Sql/dbo/Stored Procedures/SubscriptionDiscount_ReadByStripeCouponId.sql new file mode 100644 index 0000000000..e935d614ed --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/SubscriptionDiscount_ReadByStripeCouponId.sql @@ -0,0 +1,13 @@ +CREATE PROCEDURE [dbo].[SubscriptionDiscount_ReadByStripeCouponId] + @StripeCouponId VARCHAR(50) +AS +BEGIN + SET NOCOUNT ON + + SELECT + * + FROM + [dbo].[SubscriptionDiscountView] + WHERE + [StripeCouponId] = @StripeCouponId +END diff --git a/src/Sql/dbo/Stored Procedures/SubscriptionDiscount_Update.sql b/src/Sql/dbo/Stored Procedures/SubscriptionDiscount_Update.sql new file mode 100644 index 0000000000..906f879a59 --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/SubscriptionDiscount_Update.sql @@ -0,0 +1,38 @@ +CREATE PROCEDURE [dbo].[SubscriptionDiscount_Update] + @Id UNIQUEIDENTIFIER OUTPUT, + @StripeCouponId VARCHAR(50), + @StripeProductIds NVARCHAR(MAX), + @PercentOff DECIMAL(5,2), + @AmountOff BIGINT, + @Currency VARCHAR(10), + @Duration VARCHAR(20), + @DurationInMonths INT, + @Name NVARCHAR(100), + @StartDate DATETIME2(7), + @EndDate DATETIME2(7), + @AudienceType INT, + @CreationDate DATETIME2(7), + @RevisionDate DATETIME2(7) +AS +BEGIN + SET NOCOUNT ON + + UPDATE + [dbo].[SubscriptionDiscount] + SET + [StripeCouponId] = @StripeCouponId, + [StripeProductIds] = @StripeProductIds, + [PercentOff] = @PercentOff, + [AmountOff] = @AmountOff, + [Currency] = @Currency, + [Duration] = @Duration, + [DurationInMonths] = @DurationInMonths, + [Name] = @Name, + [StartDate] = @StartDate, + [EndDate] = @EndDate, + [AudienceType] = @AudienceType, + [CreationDate] = @CreationDate, + [RevisionDate] = @RevisionDate + WHERE + [Id] = @Id +END diff --git a/src/Sql/dbo/Tables/SubscriptionDiscount.sql b/src/Sql/dbo/Tables/SubscriptionDiscount.sql new file mode 100644 index 0000000000..36ca7c24eb --- /dev/null +++ b/src/Sql/dbo/Tables/SubscriptionDiscount.sql @@ -0,0 +1,22 @@ +CREATE TABLE [dbo].[SubscriptionDiscount] ( + [Id] UNIQUEIDENTIFIER NOT NULL, + [StripeCouponId] VARCHAR (50) NOT NULL, + [StripeProductIds] NVARCHAR (MAX) NULL, + [PercentOff] DECIMAL (5, 2) NULL, + [AmountOff] BIGINT NULL, + [Currency] VARCHAR (10) NULL, + [Duration] VARCHAR (20) NOT NULL, + [DurationInMonths] INT NULL, + [Name] NVARCHAR (100) NULL, + [StartDate] DATETIME2 (7) NOT NULL, + [EndDate] DATETIME2 (7) NOT NULL, + [AudienceType] INT NOT NULL CONSTRAINT [DF_SubscriptionDiscount_AudienceType] DEFAULT (0), + [CreationDate] DATETIME2 (7) NOT NULL, + [RevisionDate] DATETIME2 (7) NOT NULL, + CONSTRAINT [PK_SubscriptionDiscount] PRIMARY KEY CLUSTERED ([Id] ASC), + CONSTRAINT [IX_SubscriptionDiscount_StripeCouponId] UNIQUE NONCLUSTERED ([StripeCouponId] ASC) +); + +GO +CREATE NONCLUSTERED INDEX [IX_SubscriptionDiscount_DateRange] + ON [dbo].[SubscriptionDiscount]([StartDate] ASC, [EndDate] ASC); diff --git a/src/Sql/dbo/Tools/Tables/Send.sql b/src/Sql/dbo/Tools/Tables/Send.sql index 94311d6328..d7cea28383 100644 --- a/src/Sql/dbo/Tools/Tables/Send.sql +++ b/src/Sql/dbo/Tools/Tables/Send.sql @@ -1,18 +1,19 @@ -CREATE TABLE [dbo].[Send] ( +CREATE TABLE [dbo].[Send] +( [Id] UNIQUEIDENTIFIER NOT NULL, [UserId] UNIQUEIDENTIFIER NULL, [OrganizationId] UNIQUEIDENTIFIER NULL, [Type] TINYINT NOT NULL, [Data] VARCHAR(MAX) NOT NULL, - [Key] VARCHAR (MAX) NOT NULL, - [Password] NVARCHAR (300) NULL, - [Emails] NVARCHAR (4000) NULL, + [Key] VARCHAR(MAX) NOT NULL, + [Password] NVARCHAR(300) NULL, + [Emails] NVARCHAR(4000) NULL, [MaxAccessCount] INT NULL, [AccessCount] INT NOT NULL, - [CreationDate] DATETIME2 (7) NOT NULL, - [RevisionDate] DATETIME2 (7) NOT NULL, - [ExpirationDate] DATETIME2 (7) NULL, - [DeletionDate] DATETIME2 (7) NOT NULL, + [CreationDate] DATETIME2(7) NOT NULL, + [RevisionDate] DATETIME2(7) NOT NULL, + [ExpirationDate] DATETIME2(7) NULL, + [DeletionDate] DATETIME2(7) NOT NULL, [Disabled] BIT NOT NULL, [HideEmail] BIT NULL, [CipherId] UNIQUEIDENTIFIER NULL, @@ -26,9 +27,9 @@ GO CREATE NONCLUSTERED INDEX [IX_Send_UserId_OrganizationId] - ON [dbo].[Send]([UserId] ASC, [OrganizationId] ASC); + ON [dbo].[Send] ([UserId] ASC, [OrganizationId] ASC); GO CREATE NONCLUSTERED INDEX [IX_Send_DeletionDate] - ON [dbo].[Send]([DeletionDate] ASC); + ON [dbo].[Send] ([DeletionDate] ASC); diff --git a/src/Sql/dbo/Vault/Stored Procedures/Collections/Collection_ReadSharedCollectionsByOrganizationIdWithPermissions.sql b/src/Sql/dbo/Vault/Stored Procedures/Collections/Collection_ReadSharedCollectionsByOrganizationIdWithPermissions.sql new file mode 100644 index 0000000000..52120fe28a --- /dev/null +++ b/src/Sql/dbo/Vault/Stored Procedures/Collections/Collection_ReadSharedCollectionsByOrganizationIdWithPermissions.sql @@ -0,0 +1,86 @@ +CREATE PROCEDURE [dbo].[Collection_ReadSharedCollectionsByOrganizationIdWithPermissions] + @OrganizationId UNIQUEIDENTIFIER, + @UserId UNIQUEIDENTIFIER, + @IncludeAccessRelationships BIT +AS +BEGIN + SET NOCOUNT ON + + SELECT + C.*, + MIN(CASE + WHEN + COALESCE(CU.[ReadOnly], CG.[ReadOnly], 0) = 0 + THEN 0 + ELSE 1 + END) AS [ReadOnly], + MIN(CASE + WHEN + COALESCE(CU.[HidePasswords], CG.[HidePasswords], 0) = 0 + THEN 0 + ELSE 1 + END) AS [HidePasswords], + MAX(CASE + WHEN + COALESCE(CU.[Manage], CG.[Manage], 0) = 0 + THEN 0 + ELSE 1 + END) AS [Manage], + MAX(CASE + WHEN + CU.[CollectionId] IS NULL AND CG.[CollectionId] IS NULL + THEN 0 + ELSE 1 + END) AS [Assigned], + CASE + WHEN + -- No user or group has manage rights + NOT EXISTS( + SELECT 1 + FROM [dbo].[CollectionUser] CU2 + JOIN [dbo].[OrganizationUser] OU2 ON CU2.[OrganizationUserId] = OU2.[Id] + WHERE + CU2.[CollectionId] = C.[Id] AND + CU2.[Manage] = 1 + ) + AND NOT EXISTS ( + SELECT 1 + FROM [dbo].[CollectionGroup] CG2 + WHERE + CG2.[CollectionId] = C.[Id] AND + CG2.[Manage] = 1 + ) + THEN 1 + ELSE 0 + END AS [Unmanaged] + FROM + [dbo].[CollectionView] C + LEFT JOIN + [dbo].[OrganizationUser] OU ON C.[OrganizationId] = OU.[OrganizationId] AND OU.[UserId] = @UserId + LEFT JOIN + [dbo].[CollectionUser] CU ON CU.[CollectionId] = C.[Id] AND CU.[OrganizationUserId] = [OU].[Id] + LEFT JOIN + [dbo].[GroupUser] GU ON CU.[CollectionId] IS NULL AND GU.[OrganizationUserId] = OU.[Id] + LEFT JOIN + [dbo].[Group] G ON G.[Id] = GU.[GroupId] + LEFT JOIN + [dbo].[CollectionGroup] CG ON CG.[CollectionId] = C.[Id] AND CG.[GroupId] = GU.[GroupId] + WHERE + C.[OrganizationId] = @OrganizationId AND + C.[Type] = 0 -- Only SharedCollection + GROUP BY + C.[Id], + C.[OrganizationId], + C.[Name], + C.[CreationDate], + C.[RevisionDate], + C.[ExternalId], + C.[DefaultUserCollectionEmail], + C.[Type] + + IF (@IncludeAccessRelationships = 1) + BEGIN + EXEC [dbo].[CollectionGroup_ReadByOrganizationId] @OrganizationId + EXEC [dbo].[CollectionUser_ReadByOrganizationId] @OrganizationId + END +END diff --git a/src/Sql/dbo/Views/SubscriptionDiscountView.sql b/src/Sql/dbo/Views/SubscriptionDiscountView.sql new file mode 100644 index 0000000000..71687d4ae8 --- /dev/null +++ b/src/Sql/dbo/Views/SubscriptionDiscountView.sql @@ -0,0 +1,5 @@ +CREATE VIEW [dbo].[SubscriptionDiscountView] +AS +SELECT * +FROM + [dbo].[SubscriptionDiscount] diff --git a/test/Api.IntegrationTest/AdminConsole/Controllers/GroupsControllerPerformanceTests.cs b/test/Api.IntegrationTest/AdminConsole/Controllers/GroupsControllerPerformanceTests.cs index 71c6bf104c..f93f47a35a 100644 --- a/test/Api.IntegrationTest/AdminConsole/Controllers/GroupsControllerPerformanceTests.cs +++ b/test/Api.IntegrationTest/AdminConsole/Controllers/GroupsControllerPerformanceTests.cs @@ -1,11 +1,15 @@ using System.Net; using System.Text; using System.Text.Json; +using AutoMapper; using Bit.Api.AdminConsole.Models.Request; using Bit.Api.IntegrationTest.Factories; using Bit.Api.IntegrationTest.Helpers; using Bit.Api.Models.Request; +using Bit.Core.Entities; using Bit.Seeder.Recipes; +using Bit.Seeder.Services; +using Microsoft.AspNetCore.Identity; using Xunit; using Xunit.Abstractions; @@ -26,7 +30,10 @@ public class GroupsControllerPerformanceTests(ITestOutputHelper testOutputHelper var client = factory.CreateClient(); var db = factory.GetDatabaseContext(); - var orgSeeder = new OrganizationWithUsersRecipe(db); + var mapper = factory.GetService(); + var passwordHasher = factory.GetService>(); + var manglerService = new NoOpManglerService(); + var orgSeeder = new OrganizationWithUsersRecipe(db, mapper, passwordHasher, manglerService); var collectionsSeeder = new CollectionsRecipe(db); var groupsSeeder = new GroupsRecipe(db); @@ -34,8 +41,8 @@ public class GroupsControllerPerformanceTests(ITestOutputHelper testOutputHelper var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: userCount); var orgUserIds = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).ToList(); - var collectionIds = collectionsSeeder.AddToOrganization(orgId, collectionCount, orgUserIds, 0); - var groupIds = groupsSeeder.AddToOrganization(orgId, 1, orgUserIds, 0); + var collectionIds = collectionsSeeder.Seed(orgId, collectionCount, orgUserIds, 0); + var groupIds = groupsSeeder.Seed(orgId, 1, orgUserIds, 0); var groupId = groupIds.First(); diff --git a/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUsersControllerPerformanceTests.cs b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUsersControllerPerformanceTests.cs index fc64930777..f7eb584b75 100644 --- a/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUsersControllerPerformanceTests.cs +++ b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationUsersControllerPerformanceTests.cs @@ -1,13 +1,17 @@ using System.Net; using System.Text; using System.Text.Json; +using AutoMapper; using Bit.Api.AdminConsole.Models.Request.Organizations; using Bit.Api.IntegrationTest.Factories; using Bit.Api.IntegrationTest.Helpers; using Bit.Api.Models.Request; +using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Data; using Bit.Seeder.Recipes; +using Bit.Seeder.Services; +using Microsoft.AspNetCore.Identity; using Xunit; using Xunit.Abstractions; @@ -28,7 +32,10 @@ public class OrganizationUsersControllerPerformanceTests(ITestOutputHelper testO var client = factory.CreateClient(); var db = factory.GetDatabaseContext(); - var orgSeeder = new OrganizationWithUsersRecipe(db); + var mapper = factory.GetService(); + var passwordHasher = factory.GetService>(); + var manglerService = factory.GetService(); + var orgSeeder = new OrganizationWithUsersRecipe(db, mapper, passwordHasher, manglerService); var collectionsSeeder = new CollectionsRecipe(db); var groupsSeeder = new GroupsRecipe(db); @@ -37,8 +44,8 @@ public class OrganizationUsersControllerPerformanceTests(ITestOutputHelper testO var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: seats); var orgUserIds = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).ToList(); - collectionsSeeder.AddToOrganization(orgId, 10, orgUserIds); - groupsSeeder.AddToOrganization(orgId, 5, orgUserIds); + collectionsSeeder.Seed(orgId, 10, orgUserIds); + groupsSeeder.Seed(orgId, 5, orgUserIds); await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); @@ -64,7 +71,10 @@ public class OrganizationUsersControllerPerformanceTests(ITestOutputHelper testO var client = factory.CreateClient(); var db = factory.GetDatabaseContext(); - var orgSeeder = new OrganizationWithUsersRecipe(db); + var mapper = factory.GetService(); + var passwordHasher = factory.GetService>(); + var manglerService = factory.GetService(); + var orgSeeder = new OrganizationWithUsersRecipe(db, mapper, passwordHasher, manglerService); var collectionsSeeder = new CollectionsRecipe(db); var groupsSeeder = new GroupsRecipe(db); @@ -72,8 +82,8 @@ public class OrganizationUsersControllerPerformanceTests(ITestOutputHelper testO var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: seats); var orgUserIds = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).ToList(); - collectionsSeeder.AddToOrganization(orgId, 10, orgUserIds); - groupsSeeder.AddToOrganization(orgId, 5, orgUserIds); + collectionsSeeder.Seed(orgId, 10, orgUserIds); + groupsSeeder.Seed(orgId, 5, orgUserIds); await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); @@ -98,14 +108,17 @@ public class OrganizationUsersControllerPerformanceTests(ITestOutputHelper testO var client = factory.CreateClient(); var db = factory.GetDatabaseContext(); - var orgSeeder = new OrganizationWithUsersRecipe(db); + var mapper = factory.GetService(); + var passwordHasher = factory.GetService>(); + var manglerService = factory.GetService(); + var orgSeeder = new OrganizationWithUsersRecipe(db, mapper, passwordHasher, manglerService); var groupsSeeder = new GroupsRecipe(db); var domain = OrganizationTestHelpers.GenerateRandomDomain(); var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: 1); var orgUserId = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).FirstOrDefault(); - groupsSeeder.AddToOrganization(orgId, 2, [orgUserId]); + groupsSeeder.Seed(orgId, 2, [orgUserId]); await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); @@ -130,7 +143,10 @@ public class OrganizationUsersControllerPerformanceTests(ITestOutputHelper testO var client = factory.CreateClient(); var db = factory.GetDatabaseContext(); - var orgSeeder = new OrganizationWithUsersRecipe(db); + var mapper = factory.GetService(); + var passwordHasher = factory.GetService>(); + var manglerService = factory.GetService(); + var orgSeeder = new OrganizationWithUsersRecipe(db, mapper, passwordHasher, manglerService); var domain = OrganizationTestHelpers.GenerateRandomDomain(); var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: 1); @@ -163,7 +179,10 @@ public class OrganizationUsersControllerPerformanceTests(ITestOutputHelper testO var client = factory.CreateClient(); var db = factory.GetDatabaseContext(); - var orgSeeder = new OrganizationWithUsersRecipe(db); + var mapper = factory.GetService(); + var passwordHasher = factory.GetService>(); + var manglerService = factory.GetService(); + var orgSeeder = new OrganizationWithUsersRecipe(db, mapper, passwordHasher, manglerService); var domain = OrganizationTestHelpers.GenerateRandomDomain(); var orgId = orgSeeder.Seed( @@ -211,7 +230,10 @@ public class OrganizationUsersControllerPerformanceTests(ITestOutputHelper testO var client = factory.CreateClient(); var db = factory.GetDatabaseContext(); - var orgSeeder = new OrganizationWithUsersRecipe(db); + var mapper = factory.GetService(); + var passwordHasher = factory.GetService>(); + var manglerService = factory.GetService(); + var orgSeeder = new OrganizationWithUsersRecipe(db, mapper, passwordHasher, manglerService); var domain = OrganizationTestHelpers.GenerateRandomDomain(); var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: userCount); @@ -251,7 +273,10 @@ public class OrganizationUsersControllerPerformanceTests(ITestOutputHelper testO var client = factory.CreateClient(); var db = factory.GetDatabaseContext(); - var orgSeeder = new OrganizationWithUsersRecipe(db); + var mapper = factory.GetService(); + var passwordHasher = factory.GetService>(); + var manglerService = factory.GetService(); + var orgSeeder = new OrganizationWithUsersRecipe(db, mapper, passwordHasher, manglerService); var domain = OrganizationTestHelpers.GenerateRandomDomain(); var orgId = orgSeeder.Seed( @@ -295,7 +320,10 @@ public class OrganizationUsersControllerPerformanceTests(ITestOutputHelper testO var client = factory.CreateClient(); var db = factory.GetDatabaseContext(); - var orgSeeder = new OrganizationWithUsersRecipe(db); + var mapper = factory.GetService(); + var passwordHasher = factory.GetService>(); + var manglerService = factory.GetService(); + var orgSeeder = new OrganizationWithUsersRecipe(db, mapper, passwordHasher, manglerService); var domain = OrganizationTestHelpers.GenerateRandomDomain(); var orgId = orgSeeder.Seed( @@ -339,7 +367,10 @@ public class OrganizationUsersControllerPerformanceTests(ITestOutputHelper testO var client = factory.CreateClient(); var db = factory.GetDatabaseContext(); - var orgSeeder = new OrganizationWithUsersRecipe(db); + var mapper = factory.GetService(); + var passwordHasher = factory.GetService>(); + var manglerService = factory.GetService(); + var orgSeeder = new OrganizationWithUsersRecipe(db, mapper, passwordHasher, manglerService); var domainSeeder = new OrganizationDomainRecipe(db); var domain = OrganizationTestHelpers.GenerateRandomDomain(); @@ -350,7 +381,7 @@ public class OrganizationUsersControllerPerformanceTests(ITestOutputHelper testO users: userCount, usersStatus: OrganizationUserStatusType.Confirmed); - domainSeeder.AddVerifiedDomainToOrganization(orgId, domain); + domainSeeder.Seed(orgId, domain); await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); @@ -384,7 +415,10 @@ public class OrganizationUsersControllerPerformanceTests(ITestOutputHelper testO var client = factory.CreateClient(); var db = factory.GetDatabaseContext(); - var orgSeeder = new OrganizationWithUsersRecipe(db); + var mapper = factory.GetService(); + var passwordHasher = factory.GetService>(); + var manglerService = factory.GetService(); + var orgSeeder = new OrganizationWithUsersRecipe(db, mapper, passwordHasher, manglerService); var collectionsSeeder = new CollectionsRecipe(db); var groupsSeeder = new GroupsRecipe(db); @@ -392,8 +426,8 @@ public class OrganizationUsersControllerPerformanceTests(ITestOutputHelper testO var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: 1); var orgUserIds = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).ToList(); - var collectionIds = collectionsSeeder.AddToOrganization(orgId, 3, orgUserIds, 0); - var groupIds = groupsSeeder.AddToOrganization(orgId, 2, orgUserIds, 0); + var collectionIds = collectionsSeeder.Seed(orgId, 3, orgUserIds, 0); + var groupIds = groupsSeeder.Seed(orgId, 2, orgUserIds, 0); await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); @@ -434,7 +468,10 @@ public class OrganizationUsersControllerPerformanceTests(ITestOutputHelper testO var client = factory.CreateClient(); var db = factory.GetDatabaseContext(); - var orgSeeder = new OrganizationWithUsersRecipe(db); + var mapper = factory.GetService(); + var passwordHasher = factory.GetService>(); + var manglerService = factory.GetService(); + var orgSeeder = new OrganizationWithUsersRecipe(db, mapper, passwordHasher, manglerService); var domain = OrganizationTestHelpers.GenerateRandomDomain(); var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: userCount); @@ -471,7 +508,10 @@ public class OrganizationUsersControllerPerformanceTests(ITestOutputHelper testO var client = factory.CreateClient(); var db = factory.GetDatabaseContext(); - var orgSeeder = new OrganizationWithUsersRecipe(db); + var mapper = factory.GetService(); + var passwordHasher = factory.GetService>(); + var manglerService = factory.GetService(); + var orgSeeder = new OrganizationWithUsersRecipe(db, mapper, passwordHasher, manglerService); var domainSeeder = new OrganizationDomainRecipe(db); var domain = OrganizationTestHelpers.GenerateRandomDomain(); @@ -481,7 +521,7 @@ public class OrganizationUsersControllerPerformanceTests(ITestOutputHelper testO users: 2, usersStatus: OrganizationUserStatusType.Confirmed); - domainSeeder.AddVerifiedDomainToOrganization(orgId, domain); + domainSeeder.Seed(orgId, domain); await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); @@ -512,14 +552,17 @@ public class OrganizationUsersControllerPerformanceTests(ITestOutputHelper testO var client = factory.CreateClient(); var db = factory.GetDatabaseContext(); - var orgSeeder = new OrganizationWithUsersRecipe(db); + var mapper = factory.GetService(); + var passwordHasher = factory.GetService>(); + var manglerService = factory.GetService(); + var orgSeeder = new OrganizationWithUsersRecipe(db, mapper, passwordHasher, manglerService); var collectionsSeeder = new CollectionsRecipe(db); var domain = OrganizationTestHelpers.GenerateRandomDomain(); var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: 1); var orgUserIds = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).ToList(); - var collectionIds = collectionsSeeder.AddToOrganization(orgId, 2, orgUserIds, 0); + var collectionIds = collectionsSeeder.Seed(orgId, 2, orgUserIds, 0); await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); @@ -560,7 +603,10 @@ public class OrganizationUsersControllerPerformanceTests(ITestOutputHelper testO var client = factory.CreateClient(); var db = factory.GetDatabaseContext(); - var orgSeeder = new OrganizationWithUsersRecipe(db); + var mapper = factory.GetService(); + var passwordHasher = factory.GetService>(); + var manglerService = factory.GetService(); + var orgSeeder = new OrganizationWithUsersRecipe(db, mapper, passwordHasher, manglerService); var domain = OrganizationTestHelpers.GenerateRandomDomain(); var orgId = orgSeeder.Seed( diff --git a/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationsControllerPerformanceTests.cs b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationsControllerPerformanceTests.cs index 238a9a5d53..1bea3dd720 100644 --- a/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationsControllerPerformanceTests.cs +++ b/test/Api.IntegrationTest/AdminConsole/Controllers/OrganizationsControllerPerformanceTests.cs @@ -1,14 +1,18 @@ using System.Net; using System.Text; using System.Text.Json; +using AutoMapper; using Bit.Api.AdminConsole.Models.Request.Organizations; using Bit.Api.Auth.Models.Request.Accounts; using Bit.Api.IntegrationTest.Factories; using Bit.Api.IntegrationTest.Helpers; using Bit.Core.AdminConsole.Models.Business.Tokenables; using Bit.Core.Billing.Enums; +using Bit.Core.Entities; using Bit.Core.Tokens; using Bit.Seeder.Recipes; +using Bit.Seeder.Services; +using Microsoft.AspNetCore.Identity; using Xunit; using Xunit.Abstractions; @@ -29,7 +33,10 @@ public class OrganizationsControllerPerformanceTests(ITestOutputHelper testOutpu var client = factory.CreateClient(); var db = factory.GetDatabaseContext(); - var orgSeeder = new OrganizationWithUsersRecipe(db); + var mapper = factory.GetService(); + var passwordHasher = factory.GetService>(); + var manglerService = new NoOpManglerService(); + var orgSeeder = new OrganizationWithUsersRecipe(db, mapper, passwordHasher, manglerService); var collectionsSeeder = new CollectionsRecipe(db); var groupsSeeder = new GroupsRecipe(db); @@ -37,8 +44,8 @@ public class OrganizationsControllerPerformanceTests(ITestOutputHelper testOutpu var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: userCount); var orgUserIds = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).ToList(); - collectionsSeeder.AddToOrganization(orgId, collectionCount, orgUserIds, 0); - groupsSeeder.AddToOrganization(orgId, groupCount, orgUserIds, 0); + collectionsSeeder.Seed(orgId, collectionCount, orgUserIds, 0); + groupsSeeder.Seed(orgId, groupCount, orgUserIds, 0); await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); @@ -77,7 +84,10 @@ public class OrganizationsControllerPerformanceTests(ITestOutputHelper testOutpu var client = factory.CreateClient(); var db = factory.GetDatabaseContext(); - var orgSeeder = new OrganizationWithUsersRecipe(db); + var mapper = factory.GetService(); + var passwordHasher = factory.GetService>(); + var manglerService = new NoOpManglerService(); + var orgSeeder = new OrganizationWithUsersRecipe(db, mapper, passwordHasher, manglerService); var collectionsSeeder = new CollectionsRecipe(db); var groupsSeeder = new GroupsRecipe(db); @@ -85,8 +95,8 @@ public class OrganizationsControllerPerformanceTests(ITestOutputHelper testOutpu var orgId = orgSeeder.Seed(name: "Org", domain: domain, users: userCount); var orgUserIds = db.OrganizationUsers.Where(ou => ou.OrganizationId == orgId).Select(ou => ou.Id).ToList(); - collectionsSeeder.AddToOrganization(orgId, collectionCount, orgUserIds, 0); - groupsSeeder.AddToOrganization(orgId, groupCount, orgUserIds, 0); + collectionsSeeder.Seed(orgId, collectionCount, orgUserIds, 0); + groupsSeeder.Seed(orgId, groupCount, orgUserIds, 0); await PerformanceTestHelpers.AuthenticateClientAsync(factory, client, $"owner@{domain}"); diff --git a/test/Api.IntegrationTest/AdminConsole/Controllers/PoliciesControllerTests.cs b/test/Api.IntegrationTest/AdminConsole/Controllers/PoliciesControllerTests.cs index e4098ce9a9..d58538ae1c 100644 --- a/test/Api.IntegrationTest/AdminConsole/Controllers/PoliciesControllerTests.cs +++ b/test/Api.IntegrationTest/AdminConsole/Controllers/PoliciesControllerTests.cs @@ -150,8 +150,8 @@ public class PoliciesControllerTests : IClassFixture, IAs Enabled = true, Data = new Dictionary { - { "minComplexity", 10 }, - { "minLength", 12 }, + { "minComplexity", 4 }, + { "minLength", 128 }, { "requireUpper", true }, { "requireLower", false }, { "requireNumbers", true }, @@ -397,4 +397,48 @@ public class PoliciesControllerTests : IClassFixture, IAs // Assert Assert.Equal(HttpStatusCode.OK, response.StatusCode); } + + [Fact] + public async Task Put_MasterPasswordPolicy_ExcessiveMinLength_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.MasterPassword; + var request = new PolicyRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "minLength", 129 } + } + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task Put_MasterPasswordPolicy_ExcessiveMinComplexity_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.MasterPassword; + var request = new PolicyRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "minComplexity", 5 } + } + }; + + // Act + var response = await _client.PutAsync($"/organizations/{_organization.Id}/policies/{policyType}", + JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } } diff --git a/test/Api.IntegrationTest/AdminConsole/Public/Controllers/MembersControllerTests.cs b/test/Api.IntegrationTest/AdminConsole/Public/Controllers/MembersControllerTests.cs index 9f2512038e..e4bdbdb174 100644 --- a/test/Api.IntegrationTest/AdminConsole/Public/Controllers/MembersControllerTests.cs +++ b/test/Api.IntegrationTest/AdminConsole/Public/Controllers/MembersControllerTests.cs @@ -264,4 +264,138 @@ public class MembersControllerTests : IClassFixture, IAsy new Permissions { CreateNewCollections = true, ManageScim = true, ManageGroups = true, ManageUsers = true }, orgUser.GetPermissions()); } + + [Fact] + public async Task Revoke_Member_Success() + { + var (_, orgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync( + _factory, _organization.Id, OrganizationUserType.User); + + var response = await _client.PostAsync($"/public/members/{orgUser.Id}/revoke", null); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + var updatedUser = await _factory.GetService() + .GetByIdAsync(orgUser.Id); + Assert.NotNull(updatedUser); + Assert.Equal(OrganizationUserStatusType.Revoked, updatedUser.Status); + } + + [Fact] + public async Task Revoke_AlreadyRevoked_ReturnsBadRequest() + { + var (_, orgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync( + _factory, _organization.Id, OrganizationUserType.User); + + var revokeResponse = await _client.PostAsync($"/public/members/{orgUser.Id}/revoke", null); + Assert.Equal(HttpStatusCode.OK, revokeResponse.StatusCode); + + var response = await _client.PostAsync($"/public/members/{orgUser.Id}/revoke", null); + + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var error = await response.Content.ReadFromJsonAsync(); + Assert.Equal("Already revoked.", error?.Message); + } + + [Fact] + public async Task Revoke_NotFound_ReturnsNotFound() + { + var response = await _client.PostAsync($"/public/members/{Guid.NewGuid()}/revoke", null); + Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); + } + + [Fact] + public async Task Revoke_DifferentOrganization_ReturnsNotFound() + { + // Create a different organization + var ownerEmail = $"integration-test{Guid.NewGuid()}@bitwarden.com"; + await _factory.LoginWithNewAccount(ownerEmail); + var (otherOrganization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, plan: PlanType.EnterpriseAnnually, + ownerEmail: ownerEmail, passwordManagerSeats: 10, paymentMethod: PaymentMethodType.Card); + + // Create a user in the other organization + var (_, orgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync( + _factory, otherOrganization.Id, OrganizationUserType.User); + + // Re-authenticate with the original organization + await _loginHelper.LoginWithOrganizationApiKeyAsync(_organization.Id); + + // Try to revoke the user from the other organization + var response = await _client.PostAsync($"/public/members/{orgUser.Id}/revoke", null); + + Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); + } + + [Fact] + public async Task Restore_Member_Success() + { + // Invite a user to revoke + var email = $"integration-test{Guid.NewGuid()}@example.com"; + var inviteRequest = new MemberCreateRequestModel + { + Email = email, + Type = OrganizationUserType.User, + }; + + var inviteResponse = await _client.PostAsync("/public/members", JsonContent.Create(inviteRequest)); + Assert.Equal(HttpStatusCode.OK, inviteResponse.StatusCode); + var invitedMember = await inviteResponse.Content.ReadFromJsonAsync(); + Assert.NotNull(invitedMember); + + // Revoke the invited user + var revokeResponse = await _client.PostAsync($"/public/members/{invitedMember.Id}/revoke", null); + Assert.Equal(HttpStatusCode.OK, revokeResponse.StatusCode); + + // Restore the user + var response = await _client.PostAsync($"/public/members/{invitedMember.Id}/restore", null); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + // Verify user is restored to Invited state + var updatedUser = await _factory.GetService() + .GetByIdAsync(invitedMember.Id); + Assert.NotNull(updatedUser); + Assert.Equal(OrganizationUserStatusType.Invited, updatedUser.Status); + } + + [Fact] + public async Task Restore_AlreadyActive_ReturnsBadRequest() + { + var (_, orgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync( + _factory, _organization.Id, OrganizationUserType.User); + + var response = await _client.PostAsync($"/public/members/{orgUser.Id}/restore", null); + + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + var error = await response.Content.ReadFromJsonAsync(); + Assert.Equal("Already active.", error?.Message); + } + + [Fact] + public async Task Restore_NotFound_ReturnsNotFound() + { + var response = await _client.PostAsync($"/public/members/{Guid.NewGuid()}/restore", null); + Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); + } + + [Fact] + public async Task Restore_DifferentOrganization_ReturnsNotFound() + { + // Create a different organization + var ownerEmail = $"integration-test{Guid.NewGuid()}@bitwarden.com"; + await _factory.LoginWithNewAccount(ownerEmail); + var (otherOrganization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, plan: PlanType.EnterpriseAnnually, + ownerEmail: ownerEmail, passwordManagerSeats: 10, paymentMethod: PaymentMethodType.Card); + + // Create a user in the other organization + var (_, orgUser) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync( + _factory, otherOrganization.Id, OrganizationUserType.User); + + // Re-authenticate with the original organization + await _loginHelper.LoginWithOrganizationApiKeyAsync(_organization.Id); + + // Try to restore the user from the other organization + var response = await _client.PostAsync($"/public/members/{orgUser.Id}/restore", null); + + Assert.Equal(HttpStatusCode.NotFound, response.StatusCode); + } } diff --git a/test/Api.IntegrationTest/AdminConsole/Public/Controllers/PoliciesControllerTests.cs b/test/Api.IntegrationTest/AdminConsole/Public/Controllers/PoliciesControllerTests.cs index 6144d7eebb..a669bdd93c 100644 --- a/test/Api.IntegrationTest/AdminConsole/Public/Controllers/PoliciesControllerTests.cs +++ b/test/Api.IntegrationTest/AdminConsole/Public/Controllers/PoliciesControllerTests.cs @@ -61,7 +61,8 @@ public class PoliciesControllerTests : IClassFixture, IAs Enabled = true, Data = new Dictionary { - { "minComplexity", 15}, + { "minComplexity", 4}, + { "minLength", 128 }, { "requireLower", true} } }; @@ -78,7 +79,8 @@ public class PoliciesControllerTests : IClassFixture, IAs Assert.IsType(result.Id); Assert.NotEqual(default, result.Id); Assert.NotNull(result.Data); - Assert.Equal(15, ((JsonElement)result.Data["minComplexity"]).GetInt32()); + Assert.Equal(4, ((JsonElement)result.Data["minComplexity"]).GetInt32()); + Assert.Equal(128, ((JsonElement)result.Data["minLength"]).GetInt32()); Assert.True(((JsonElement)result.Data["requireLower"]).GetBoolean()); // Assert against the database values @@ -94,7 +96,7 @@ public class PoliciesControllerTests : IClassFixture, IAs Assert.NotNull(policy.Data); var data = policy.GetDataModel(); - var expectedData = new MasterPasswordPolicyData { MinComplexity = 15, RequireLower = true }; + var expectedData = new MasterPasswordPolicyData { MinComplexity = 4, MinLength = 128, RequireLower = true }; AssertHelper.AssertPropertyEqual(expectedData, data); } @@ -242,4 +244,46 @@ public class PoliciesControllerTests : IClassFixture, IAs // Assert Assert.Equal(HttpStatusCode.OK, response.StatusCode); } + + [Fact] + public async Task Put_MasterPasswordPolicy_ExcessiveMinLength_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.MasterPassword; + var request = new PolicyUpdateRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "minLength", 129 } + } + }; + + // Act + var response = await _client.PutAsync($"/public/policies/{policyType}", JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } + + [Fact] + public async Task Put_MasterPasswordPolicy_ExcessiveMinComplexity_ReturnsBadRequest() + { + // Arrange + var policyType = PolicyType.MasterPassword; + var request = new PolicyUpdateRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "minComplexity", 5 } + } + }; + + // Act + var response = await _client.PutAsync($"/public/policies/{policyType}", JsonContent.Create(request)); + + // Assert + Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); + } } diff --git a/test/Api.IntegrationTest/Controllers/AccountsControllerTest.cs b/test/Api.IntegrationTest/Controllers/AccountsControllerTest.cs index d055418f3a..9860775e31 100644 --- a/test/Api.IntegrationTest/Controllers/AccountsControllerTest.cs +++ b/test/Api.IntegrationTest/Controllers/AccountsControllerTest.cs @@ -3,7 +3,6 @@ using System.Text.Json; using Bit.Api.Auth.Models.Request.Accounts; using Bit.Api.IntegrationTest.Factories; using Bit.Api.IntegrationTest.Helpers; -using Bit.Api.KeyManagement.Models.Requests; using Bit.Api.Models.Response; using Bit.Core; using Bit.Core.Auth.Entities; @@ -12,6 +11,7 @@ using Bit.Core.Auth.Models.Data; using Bit.Core.Auth.Repositories; using Bit.Core.Entities; using Bit.Core.Enums; +using Bit.Core.KeyManagement.Models.Api.Request; using Bit.Core.KeyManagement.Repositories; using Bit.Core.Models.Data; using Bit.Core.Platform.Push; @@ -378,7 +378,7 @@ public class AccountsControllerTest : IClassFixture, IAsy Assert.Equal(HttpStatusCode.BadRequest, response.StatusCode); var content = await response.Content.ReadAsStringAsync(); - Assert.Contains("KDF settings are invalid", content); + Assert.Contains("The model state is invalid", content); } [Fact] diff --git a/test/Api.IntegrationTest/Controllers/Public/CollectionsControllerTests.cs b/test/Api.IntegrationTest/Controllers/Public/CollectionsControllerTests.cs index a729abb849..3551ed4efa 100644 --- a/test/Api.IntegrationTest/Controllers/Public/CollectionsControllerTests.cs +++ b/test/Api.IntegrationTest/Controllers/Public/CollectionsControllerTests.cs @@ -6,6 +6,7 @@ using Bit.Api.Models.Public.Response; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Enums; +using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Data; using Bit.Core.Platform.Push; @@ -114,4 +115,64 @@ public class CollectionsControllerTests : IClassFixture, Assert.NotEmpty(result.Item2.Groups); Assert.NotEmpty(result.Item2.Users); } + + [Fact] + public async Task List_ExcludesDefaultUserCollections_IncludesGroupsAndUsers() + { + // Arrange + var collectionRepository = _factory.GetService(); + var groupRepository = _factory.GetService(); + + var defaultCollection = new Collection + { + OrganizationId = _organization.Id, + Name = "My Items", + Type = CollectionType.DefaultUserCollection + }; + await collectionRepository.CreateAsync(defaultCollection, null, null); + + var group = await groupRepository.CreateAsync(new Group + { + OrganizationId = _organization.Id, + Name = "Test Group", + ExternalId = $"test-group-{Guid.NewGuid()}", + }); + + var (_, user) = await OrganizationTestHelpers.CreateNewUserWithAccountAsync( + _factory, + _organization.Id, + OrganizationUserType.User); + + var sharedCollection = await OrganizationTestHelpers.CreateCollectionAsync( + _factory, + _organization.Id, + "Shared Collection with Access", + externalId: "shared-collection-with-access", + groups: + [ + new CollectionAccessSelection { Id = group.Id, ReadOnly = false, HidePasswords = false, Manage = true } + ], + users: + [ + new CollectionAccessSelection { Id = user.Id, ReadOnly = true, HidePasswords = true, Manage = false } + ]); + + // Act + var response = await _client.GetFromJsonAsync>("public/collections"); + + // Assert + Assert.NotNull(response); + + Assert.DoesNotContain(response.Data, c => c.Id == defaultCollection.Id); + + var collectionResponse = response.Data.First(c => c.Id == sharedCollection.Id); + Assert.NotNull(collectionResponse.Groups); + Assert.Single(collectionResponse.Groups); + + var groupResponse = collectionResponse.Groups.First(); + Assert.Equal(group.Id, groupResponse.Id); + Assert.False(groupResponse.ReadOnly); + Assert.False(groupResponse.HidePasswords); + Assert.True(groupResponse.Manage); + } } diff --git a/test/Api.Test/AdminConsole/Controllers/OrganizationUsersControllerTests.cs b/test/Api.Test/AdminConsole/Controllers/OrganizationUsersControllerTests.cs index 43f0123a3f..f9b50e736d 100644 --- a/test/Api.Test/AdminConsole/Controllers/OrganizationUsersControllerTests.cs +++ b/test/Api.Test/AdminConsole/Controllers/OrganizationUsersControllerTests.cs @@ -14,7 +14,6 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; -using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Utilities.v2.Results; using Bit.Core.Auth.Entities; using Bit.Core.Auth.Repositories; @@ -30,6 +29,7 @@ using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; +using Bit.Core.Test.AdminConsole.AutoFixture; using Bit.Core.Utilities; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; @@ -121,10 +121,37 @@ public class OrganizationUsersControllerTests [Theory] [BitAutoData] - public async Task Accept_NoMasterPasswordReset(Guid orgId, Guid orgUserId, - OrganizationUserAcceptRequestModel model, User user, SutProvider sutProvider) + public async Task Accept_WhenOrganizationUserNotFound_ThrowsNotFoundException( + Guid orgId, Guid orgUserId, OrganizationUserAcceptRequestModel model, User user, + SutProvider sutProvider) { sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(user); + sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns((OrganizationUser)null); + + await Assert.ThrowsAsync(() => sutProvider.Sut.Accept(orgId, orgUserId, model)); + } + + [Theory] + [BitAutoData] + public async Task Accept_WhenOrganizationIdMismatch_ThrowsNotFoundException( + Guid orgId, Guid orgUserId, OrganizationUserAcceptRequestModel model, User user, OrganizationUser organizationUser, + SutProvider sutProvider) + { + organizationUser.OrganizationId = Guid.NewGuid(); // Different org ID + sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(user); + sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns(organizationUser); + + await Assert.ThrowsAsync(() => sutProvider.Sut.Accept(orgId, orgUserId, model)); + } + + [Theory] + [BitAutoData] + public async Task Accept_NoMasterPasswordReset(Guid orgId, Guid orgUserId, + OrganizationUserAcceptRequestModel model, User user, OrganizationUser organizationUser, SutProvider sutProvider) + { + organizationUser.OrganizationId = orgId; + sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(user); + sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns(organizationUser); await sutProvider.Sut.Accept(orgId, orgUserId, model); @@ -137,23 +164,22 @@ public class OrganizationUsersControllerTests [Theory] [BitAutoData] public async Task Accept_WhenOrganizationUsePoliciesIsEnabledAndResetPolicyIsEnabled_ShouldHandleResetPassword(Guid orgId, Guid orgUserId, - OrganizationUserAcceptRequestModel model, User user, SutProvider sutProvider) + OrganizationUserAcceptRequestModel model, User user, OrganizationUser organizationUser, + [Policy(PolicyType.ResetPassword, true)] PolicyStatus policy, + SutProvider sutProvider) { // Arrange + organizationUser.OrganizationId = orgId; var applicationCacheService = sutProvider.GetDependency(); applicationCacheService.GetOrganizationAbilityAsync(orgId).Returns(new OrganizationAbility { UsePolicies = true }); - var policy = new Policy - { - Enabled = true, - Data = CoreHelpers.ClassToJsonData(new ResetPasswordDataModel { AutoEnrollEnabled = true, }), - }; + policy.Data = CoreHelpers.ClassToJsonData(new ResetPasswordDataModel { AutoEnrollEnabled = true, }); var userService = sutProvider.GetDependency(); userService.GetUserByPrincipalAsync(default).ReturnsForAnyArgs(user); + sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns(organizationUser); - - var policyRepository = sutProvider.GetDependency(); - policyRepository.GetByOrganizationIdTypeAsync(orgId, + var policyQuery = sutProvider.GetDependency(); + policyQuery.RunAsync(orgId, PolicyType.ResetPassword).Returns(policy); // Act @@ -167,29 +193,29 @@ public class OrganizationUsersControllerTests await userService.Received(1).GetUserByPrincipalAsync(default); await applicationCacheService.Received(1).GetOrganizationAbilityAsync(orgId); - await policyRepository.Received(1).GetByOrganizationIdTypeAsync(orgId, PolicyType.ResetPassword); + await policyQuery.Received(1).RunAsync(orgId, PolicyType.ResetPassword); } [Theory] [BitAutoData] public async Task Accept_WhenOrganizationUsePoliciesIsDisabled_ShouldNotHandleResetPassword(Guid orgId, Guid orgUserId, - OrganizationUserAcceptRequestModel model, User user, SutProvider sutProvider) + OrganizationUserAcceptRequestModel model, User user, OrganizationUser organizationUser, + [Policy(PolicyType.ResetPassword, true)] PolicyStatus policy, + SutProvider sutProvider) { // Arrange + organizationUser.OrganizationId = orgId; var applicationCacheService = sutProvider.GetDependency(); applicationCacheService.GetOrganizationAbilityAsync(orgId).Returns(new OrganizationAbility { UsePolicies = false }); - var policy = new Policy - { - Enabled = true, - Data = CoreHelpers.ClassToJsonData(new ResetPasswordDataModel { AutoEnrollEnabled = true, }), - }; + policy.Data = CoreHelpers.ClassToJsonData(new ResetPasswordDataModel { AutoEnrollEnabled = true, }); var userService = sutProvider.GetDependency(); userService.GetUserByPrincipalAsync(default).ReturnsForAnyArgs(user); + sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns(organizationUser); - var policyRepository = sutProvider.GetDependency(); - policyRepository.GetByOrganizationIdTypeAsync(orgId, + var policyQuery = sutProvider.GetDependency(); + policyQuery.RunAsync(orgId, PolicyType.ResetPassword).Returns(policy); // Act @@ -202,7 +228,7 @@ public class OrganizationUsersControllerTests await sutProvider.GetDependency().Received(0) .UpdateUserResetPasswordEnrollmentAsync(orgId, user.Id, model.ResetPasswordKey, user.Id); - await policyRepository.Received(0).GetByOrganizationIdTypeAsync(orgId, PolicyType.ResetPassword); + await policyQuery.Received(0).RunAsync(orgId, PolicyType.ResetPassword); await applicationCacheService.Received(1).GetOrganizationAbilityAsync(orgId); } @@ -261,7 +287,7 @@ public class OrganizationUsersControllerTests .Returns(true); sutProvider.GetDependency() - .GetDetailsByIdWithCollectionsAsync(organizationUser.Id) + .GetDetailsByIdWithSharedCollectionsAsync(organizationUser.Id) .Returns((organizationUser, collections)); sutProvider.GetDependency() @@ -365,13 +391,15 @@ public class OrganizationUsersControllerTests [Theory] [BitAutoData] public async Task Accept_WhenOrganizationUsePoliciesIsEnabledAndResetPolicyIsEnabled_WithPolicyRequirementsEnabled_ShouldHandleResetPassword(Guid orgId, Guid orgUserId, - OrganizationUserAcceptRequestModel model, User user, SutProvider sutProvider) + OrganizationUserAcceptRequestModel model, User user, OrganizationUser organizationUser, SutProvider sutProvider) { // Arrange + organizationUser.OrganizationId = orgId; var applicationCacheService = sutProvider.GetDependency(); applicationCacheService.GetOrganizationAbilityAsync(orgId).Returns(new OrganizationAbility { UsePolicies = true }); sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true); + sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns(organizationUser); var policy = new Policy { @@ -383,7 +411,7 @@ public class OrganizationUsersControllerTests var policyRequirementQuery = sutProvider.GetDependency(); - var policyRepository = sutProvider.GetDependency(); + var policyQuery = sutProvider.GetDependency(); var policyRequirement = new ResetPasswordPolicyRequirement { AutoEnrollOrganizations = [orgId] }; @@ -400,7 +428,7 @@ public class OrganizationUsersControllerTests await userService.Received(1).GetUserByPrincipalAsync(default); await applicationCacheService.Received(0).GetOrganizationAbilityAsync(orgId); - await policyRepository.Received(0).GetByOrganizationIdTypeAsync(orgId, PolicyType.ResetPassword); + await policyQuery.Received(0).RunAsync(orgId, PolicyType.ResetPassword); await policyRequirementQuery.Received(1).GetAsync(user.Id); Assert.True(policyRequirement.AutoEnrollEnabled(orgId)); } @@ -408,14 +436,16 @@ public class OrganizationUsersControllerTests [Theory] [BitAutoData] public async Task Accept_WithInvalidModelResetPasswordKey_WithPolicyRequirementsEnabled_ThrowsBadRequestException(Guid orgId, Guid orgUserId, - OrganizationUserAcceptRequestModel model, User user, SutProvider sutProvider) + OrganizationUserAcceptRequestModel model, User user, OrganizationUser organizationUser, SutProvider sutProvider) { // Arrange model.ResetPasswordKey = " "; + organizationUser.OrganizationId = orgId; var applicationCacheService = sutProvider.GetDependency(); applicationCacheService.GetOrganizationAbilityAsync(orgId).Returns(new OrganizationAbility { UsePolicies = true }); sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true); + sutProvider.GetDependency().GetByIdAsync(orgUserId).Returns(organizationUser); var policy = new Policy { @@ -425,7 +455,7 @@ public class OrganizationUsersControllerTests var userService = sutProvider.GetDependency(); userService.GetUserByPrincipalAsync(default).ReturnsForAnyArgs(user); - var policyRepository = sutProvider.GetDependency(); + var policyQuery = sutProvider.GetDependency(); var policyRequirementQuery = sutProvider.GetDependency(); @@ -445,7 +475,7 @@ public class OrganizationUsersControllerTests await userService.Received(1).GetUserByPrincipalAsync(default); await applicationCacheService.Received(0).GetOrganizationAbilityAsync(orgId); - await policyRepository.Received(0).GetByOrganizationIdTypeAsync(orgId, PolicyType.ResetPassword); + await policyQuery.Received(0).RunAsync(orgId, PolicyType.ResetPassword); await policyRequirementQuery.Received(1).GetAsync(user.Id); Assert.Equal("Master Password reset is required, but not provided.", exception.Message); @@ -734,7 +764,7 @@ public class OrganizationUsersControllerTests [Theory] [BitAutoData] - public async Task BulkReinvite_WhenFeatureFlagEnabled_UsesBulkResendOrganizationInvitesCommand( + public async Task BulkReinvite_UsesBulkResendOrganizationInvitesCommand( Guid organizationId, OrganizationUserBulkRequestModel bulkRequestModel, List organizationUsers, @@ -744,9 +774,6 @@ public class OrganizationUsersControllerTests // Arrange sutProvider.GetDependency().ManageUsers(organizationId).Returns(true); sutProvider.GetDependency().GetProperUserId(Arg.Any()).Returns(userId); - sutProvider.GetDependency() - .IsEnabled(FeatureFlagKeys.IncreaseBulkReinviteLimitForCloud) - .Returns(true); var expectedResults = organizationUsers.Select(u => Tuple.Create(u, "")).ToList(); sutProvider.GetDependency() @@ -763,36 +790,4 @@ public class OrganizationUsersControllerTests .Received(1) .BulkResendInvitesAsync(organizationId, userId, bulkRequestModel.Ids); } - - [Theory] - [BitAutoData] - public async Task BulkReinvite_WhenFeatureFlagDisabled_UsesLegacyOrganizationService( - Guid organizationId, - OrganizationUserBulkRequestModel bulkRequestModel, - List organizationUsers, - Guid userId, - SutProvider sutProvider) - { - // Arrange - sutProvider.GetDependency().ManageUsers(organizationId).Returns(true); - sutProvider.GetDependency().GetProperUserId(Arg.Any()).Returns(userId); - sutProvider.GetDependency() - .IsEnabled(FeatureFlagKeys.IncreaseBulkReinviteLimitForCloud) - .Returns(false); - - var expectedResults = organizationUsers.Select(u => Tuple.Create(u, "")).ToList(); - sutProvider.GetDependency() - .ResendInvitesAsync(organizationId, userId, bulkRequestModel.Ids) - .Returns(expectedResults); - - // Act - var response = await sutProvider.Sut.BulkReinvite(organizationId, bulkRequestModel); - - // Assert - Assert.Equal(organizationUsers.Count, response.Data.Count()); - - await sutProvider.GetDependency() - .Received(1) - .ResendInvitesAsync(organizationId, userId, bulkRequestModel.Ids); - } } diff --git a/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs b/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs index d87f035a13..cc09e9e0a0 100644 --- a/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs +++ b/test/Api.Test/AdminConsole/Controllers/OrganizationsControllerTests.cs @@ -7,6 +7,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Enums.Provider; using Bit.Core.AdminConsole.Models.Business; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Organizations.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; @@ -25,6 +26,7 @@ using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Repositories; using Bit.Core.Services; +using Bit.Core.Test.AdminConsole.AutoFixture; using Bit.Core.Test.Billing.Mocks; using Bit.Infrastructure.EntityFramework.AdminConsole.Models.Provider; using Bit.Test.Common.AutoFixture; @@ -200,28 +202,21 @@ public class OrganizationsControllerTests SutProvider sutProvider, User user, Organization organization, - OrganizationUser organizationUser) + OrganizationUser organizationUser, + [Policy(PolicyType.ResetPassword, data: "{\"AutoEnrollEnabled\": true}")] PolicyStatus policy) { - var policy = new Policy - { - Type = PolicyType.ResetPassword, - Enabled = true, - Data = "{\"AutoEnrollEnabled\": true}", - OrganizationId = organization.Id - }; - sutProvider.GetDependency().GetUserByPrincipalAsync(Arg.Any()).Returns(user); sutProvider.GetDependency().GetByIdentifierAsync(organization.Id.ToString()).Returns(organization); sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(false); sutProvider.GetDependency().GetByOrganizationAsync(organization.Id, user.Id).Returns(organizationUser); - sutProvider.GetDependency().GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword).Returns(policy); + sutProvider.GetDependency().RunAsync(organization.Id, PolicyType.ResetPassword).Returns(policy); var result = await sutProvider.Sut.GetAutoEnrollStatus(organization.Id.ToString()); await sutProvider.GetDependency().Received(1).GetUserByPrincipalAsync(Arg.Any()); await sutProvider.GetDependency().Received(1).GetByIdentifierAsync(organization.Id.ToString()); await sutProvider.GetDependency().Received(0).GetAsync(user.Id); - await sutProvider.GetDependency().Received(1).GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword); + await sutProvider.GetDependency().Received(1).RunAsync(organization.Id, PolicyType.ResetPassword); Assert.True(result.ResetPasswordEnabled); } diff --git a/test/Api.Test/AdminConsole/Models/Response/Helpers/PolicyDetailResponsesTests.cs b/test/Api.Test/AdminConsole/Models/Response/Helpers/PolicyStatusResponsesTests.cs similarity index 62% rename from test/Api.Test/AdminConsole/Models/Response/Helpers/PolicyDetailResponsesTests.cs rename to test/Api.Test/AdminConsole/Models/Response/Helpers/PolicyStatusResponsesTests.cs index 9b863091db..46c6d64bdd 100644 --- a/test/Api.Test/AdminConsole/Models/Response/Helpers/PolicyDetailResponsesTests.cs +++ b/test/Api.Test/AdminConsole/Models/Response/Helpers/PolicyStatusResponsesTests.cs @@ -1,14 +1,13 @@ -using AutoFixture; -using Bit.Api.AdminConsole.Models.Response.Helpers; -using Bit.Core.AdminConsole.Entities; +using Bit.Api.AdminConsole.Models.Response.Helpers; using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; using NSubstitute; using Xunit; namespace Bit.Api.Test.AdminConsole.Models.Response.Helpers; -public class PolicyDetailResponsesTests +public class PolicyStatusResponsesTests { [Theory] [InlineData(true, false)] @@ -17,19 +16,13 @@ public class PolicyDetailResponsesTests bool policyEnabled, bool expectedCanToggle) { - var fixture = new Fixture(); - - var policy = fixture.Build() - .Without(p => p.Data) - .With(p => p.Type, PolicyType.SingleOrg) - .With(p => p.Enabled, policyEnabled) - .Create(); + var policy = new PolicyStatus(Guid.NewGuid(), PolicyType.SingleOrg) { Enabled = policyEnabled }; var querySub = Substitute.For(); querySub.HasVerifiedDomainsAsync(policy.OrganizationId) .Returns(true); - var result = await policy.GetSingleOrgPolicyDetailResponseAsync(querySub); + var result = await policy.GetSingleOrgPolicyStatusResponseAsync(querySub); Assert.Equal(expectedCanToggle, result.CanToggleState); } @@ -37,18 +30,13 @@ public class PolicyDetailResponsesTests [Fact] public async Task GetSingleOrgPolicyDetailResponseAsync_WhenIsNotSingleOrgType_ThenShouldThrowArgumentException() { - var fixture = new Fixture(); - - var policy = fixture.Build() - .Without(p => p.Data) - .With(p => p.Type, PolicyType.TwoFactorAuthentication) - .Create(); + var policy = new PolicyStatus(Guid.NewGuid(), PolicyType.TwoFactorAuthentication); var querySub = Substitute.For(); querySub.HasVerifiedDomainsAsync(policy.OrganizationId) .Returns(true); - var action = async () => await policy.GetSingleOrgPolicyDetailResponseAsync(querySub); + var action = async () => await policy.GetSingleOrgPolicyStatusResponseAsync(querySub); await Assert.ThrowsAsync("policy", action); } @@ -56,18 +44,13 @@ public class PolicyDetailResponsesTests [Fact] public async Task GetSingleOrgPolicyDetailResponseAsync_WhenIsSingleOrgTypeAndDoesNotHaveVerifiedDomains_ThenShouldBeAbleToToggle() { - var fixture = new Fixture(); - - var policy = fixture.Build() - .Without(p => p.Data) - .With(p => p.Type, PolicyType.SingleOrg) - .Create(); + var policy = new PolicyStatus(Guid.NewGuid(), PolicyType.SingleOrg); var querySub = Substitute.For(); querySub.HasVerifiedDomainsAsync(policy.OrganizationId) .Returns(false); - var result = await policy.GetSingleOrgPolicyDetailResponseAsync(querySub); + var result = await policy.GetSingleOrgPolicyStatusResponseAsync(querySub); Assert.True(result.CanToggleState); } diff --git a/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs b/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs index 6cddd341d5..665d1e52c1 100644 --- a/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs +++ b/test/Api.Test/Auth/Controllers/AccountsControllerTests.cs @@ -1,7 +1,6 @@ using System.Security.Claims; using Bit.Api.Auth.Controllers; using Bit.Api.Auth.Models.Request.Accounts; -using Bit.Api.KeyManagement.Models.Requests; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.Models.Api.Request.Accounts; diff --git a/test/Api.Test/Auth/Models/Request/Accounts/SetInitialPasswordRequestModelTests.cs b/test/Api.Test/Auth/Models/Request/Accounts/SetInitialPasswordRequestModelTests.cs index ce8ba1811e..97e69dacbc 100644 --- a/test/Api.Test/Auth/Models/Request/Accounts/SetInitialPasswordRequestModelTests.cs +++ b/test/Api.Test/Auth/Models/Request/Accounts/SetInitialPasswordRequestModelTests.cs @@ -1,6 +1,5 @@ using System.ComponentModel.DataAnnotations; using Bit.Api.Auth.Models.Request.Accounts; -using Bit.Api.KeyManagement.Models.Requests; using Bit.Core.Auth.Models.Api.Request.Accounts; using Bit.Core.Entities; using Bit.Core.Enums; diff --git a/test/Api.Test/Auth/Models/Response/EmergencyAccessTakeoverResponseModelTests.cs b/test/Api.Test/Auth/Models/Response/EmergencyAccessTakeoverResponseModelTests.cs new file mode 100644 index 0000000000..1a46cb1956 --- /dev/null +++ b/test/Api.Test/Auth/Models/Response/EmergencyAccessTakeoverResponseModelTests.cs @@ -0,0 +1,129 @@ +using Bit.Api.Auth.Models.Response; +using Bit.Core.Auth.Entities; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Test.Common.AutoFixture.Attributes; +using Xunit; + +namespace Bit.Api.Test.Auth.Models.Response; + +public class EmergencyAccessTakeoverResponseModelTests +{ + [Theory] + [BitAutoData] + public void Constructor_EmergencyAccessNull_ThrowsArgumentNullException(User grantor) + { + var exception = Assert.Throws( + () => new EmergencyAccessTakeoverResponseModel(null, grantor)); + Assert.Equal("emergencyAccess", exception.ParamName); + } + + [Theory] + [BitAutoData] + public void Constructor_ValidInputs_SetsAllPropertiesCorrectly( + EmergencyAccess emergencyAccess, User grantor) + { + var model = new EmergencyAccessTakeoverResponseModel(emergencyAccess, grantor); + + Assert.Equal(emergencyAccess.KeyEncrypted, model.KeyEncrypted); + Assert.Equal(grantor.Kdf, model.Kdf); + Assert.Equal(grantor.KdfIterations, model.KdfIterations); + Assert.Equal(grantor.KdfMemory, model.KdfMemory); + Assert.Equal(grantor.KdfParallelism, model.KdfParallelism); + Assert.Equal(grantor.GetMasterPasswordSalt(), model.Salt); + } + + [Theory] + [BitAutoData] + public void Constructor_Salt_EqualsGrantorEmailLowercasedAndTrimmed( + EmergencyAccess emergencyAccess, User grantor) + { + grantor.Email = " TEST@Example.COM "; + + var model = new EmergencyAccessTakeoverResponseModel(emergencyAccess, grantor); + + Assert.Equal("test@example.com", model.Salt); + } + + [Theory] + [InlineData("user@domain.com", "user@domain.com")] + [InlineData("USER@DOMAIN.COM", "user@domain.com")] + [InlineData(" user@domain.com ", "user@domain.com")] + [InlineData(" USER@DOMAIN.COM ", "user@domain.com")] + public void Constructor_SaltWithVariousEmailFormats_NormalizesCorrectly( + string email, string expectedSalt) + { + var emergencyAccess = new EmergencyAccess + { + Id = Guid.NewGuid(), + KeyEncrypted = "test-key-encrypted" + }; + var grantor = new User + { + Id = Guid.NewGuid(), + Email = email, + SecurityStamp = "security-stamp", + ApiKey = "api-key" + }; + + var model = new EmergencyAccessTakeoverResponseModel(emergencyAccess, grantor); + + Assert.Equal(expectedSalt, model.Salt); + } + + [Theory] + [BitAutoData] + public void Constructor_WithPBKDF2_SetsKdfTypeCorrectly( + EmergencyAccess emergencyAccess, User grantor) + { + grantor.Kdf = KdfType.PBKDF2_SHA256; + grantor.KdfIterations = 600000; + grantor.KdfMemory = null; + grantor.KdfParallelism = null; + + var model = new EmergencyAccessTakeoverResponseModel(emergencyAccess, grantor); + + Assert.Equal(KdfType.PBKDF2_SHA256, model.Kdf); + Assert.Equal(600000, model.KdfIterations); + Assert.Null(model.KdfMemory); + Assert.Null(model.KdfParallelism); + } + + [Theory] + [BitAutoData] + public void Constructor_WithArgon2id_SetsAllKdfPropertiesCorrectly( + EmergencyAccess emergencyAccess, User grantor) + { + grantor.Kdf = KdfType.Argon2id; + grantor.KdfIterations = 3; + grantor.KdfMemory = 64; + grantor.KdfParallelism = 4; + + var model = new EmergencyAccessTakeoverResponseModel(emergencyAccess, grantor); + + Assert.Equal(KdfType.Argon2id, model.Kdf); + Assert.Equal(3, model.KdfIterations); + Assert.Equal(64, model.KdfMemory); + Assert.Equal(4, model.KdfParallelism); + } + + [Theory] + [BitAutoData] + public void Constructor_SetsObjectTypeCorrectly( + EmergencyAccess emergencyAccess, User grantor) + { + var model = new EmergencyAccessTakeoverResponseModel(emergencyAccess, grantor); + + Assert.Equal("emergencyAccessTakeover", model.Object); + } + + [Theory] + [BitAutoData] + public void Constructor_CustomObjectName_SetsObjectTypeCorrectly( + EmergencyAccess emergencyAccess, User grantor) + { + var model = new EmergencyAccessTakeoverResponseModel(emergencyAccess, grantor, "customObject"); + + Assert.Equal("customObject", model.Object); + } +} diff --git a/test/Api.Test/Billing/Controllers/OrganizationSponsorshipsControllerTests.cs b/test/Api.Test/Billing/Controllers/OrganizationSponsorshipsControllerTests.cs index 87334dc085..a7eb4dda5e 100644 --- a/test/Api.Test/Billing/Controllers/OrganizationSponsorshipsControllerTests.cs +++ b/test/Api.Test/Billing/Controllers/OrganizationSponsorshipsControllerTests.cs @@ -1,6 +1,9 @@ using Bit.Api.Billing.Controllers; using Bit.Api.Models.Request.Organizations; using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.Billing.Enums; using Bit.Core.Context; using Bit.Core.Entities; @@ -10,6 +13,7 @@ using Bit.Core.Models.Data; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; +using Bit.Core.Test.AdminConsole.AutoFixture; using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; @@ -82,7 +86,9 @@ public class OrganizationSponsorshipsControllerTests [BitAutoData] public async Task RedeemSponsorship_NotSponsoredOrgOwner_Success(string sponsorshipToken, User user, OrganizationSponsorship sponsorship, Organization sponsoringOrganization, - OrganizationSponsorshipRedeemRequestModel model, SutProvider sutProvider) + OrganizationSponsorshipRedeemRequestModel model, + [Policy(PolicyType.FreeFamiliesSponsorshipPolicy, false)] PolicyStatus policy, + SutProvider sutProvider) { sutProvider.GetDependency().UserId.Returns(user.Id); sutProvider.GetDependency().GetUserByIdAsync(user.Id) @@ -91,6 +97,9 @@ public class OrganizationSponsorshipsControllerTests user.Email).Returns((true, sponsorship)); sutProvider.GetDependency().OrganizationOwner(model.SponsoredOrganizationId).Returns(true); sutProvider.GetDependency().GetByIdAsync(model.SponsoredOrganizationId).Returns(sponsoringOrganization); + sutProvider.GetDependency() + .RunAsync(Arg.Any(), PolicyType.FreeFamiliesSponsorshipPolicy) + .Returns(policy); await sutProvider.Sut.RedeemSponsorship(sponsorshipToken, model); @@ -101,14 +110,18 @@ public class OrganizationSponsorshipsControllerTests [Theory] [BitAutoData] public async Task PreValidateSponsorshipToken_ValidatesToken_Success(string sponsorshipToken, User user, - OrganizationSponsorship sponsorship, SutProvider sutProvider) + OrganizationSponsorship sponsorship, + [Policy(PolicyType.FreeFamiliesSponsorshipPolicy, false)] PolicyStatus policy, + SutProvider sutProvider) { sutProvider.GetDependency().UserId.Returns(user.Id); sutProvider.GetDependency().GetUserByIdAsync(user.Id) .Returns(user); sutProvider.GetDependency() .ValidateRedemptionTokenAsync(sponsorshipToken, user.Email).Returns((true, sponsorship)); - + sutProvider.GetDependency() + .RunAsync(Arg.Any(), PolicyType.FreeFamiliesSponsorshipPolicy) + .Returns(policy); await sutProvider.Sut.PreValidateSponsorshipToken(sponsorshipToken); await sutProvider.GetDependency().Received(1) diff --git a/test/Api.Test/Billing/Models/Requests/PreviewPremiumUpgradeProrationRequestTests.cs b/test/Api.Test/Billing/Models/Requests/PreviewPremiumUpgradeProrationRequestTests.cs new file mode 100644 index 0000000000..5ed4182a5d --- /dev/null +++ b/test/Api.Test/Billing/Models/Requests/PreviewPremiumUpgradeProrationRequestTests.cs @@ -0,0 +1,56 @@ +using Bit.Api.Billing.Models.Requests.Payment; +using Bit.Api.Billing.Models.Requests.PreviewInvoice; +using Bit.Core.Billing.Enums; +using Xunit; + +namespace Bit.Api.Test.Billing.Models.Requests; + +public class PreviewPremiumUpgradeProrationRequestTests +{ + [Theory] + [InlineData(ProductTierType.Families, PlanType.FamiliesAnnually)] + [InlineData(ProductTierType.Teams, PlanType.TeamsAnnually)] + [InlineData(ProductTierType.Enterprise, PlanType.EnterpriseAnnually)] + public void ToDomain_ValidTierTypes_ReturnsPlanType(ProductTierType tierType, PlanType expectedPlanType) + { + // Arrange + var sut = new PreviewPremiumUpgradeProrationRequest + { + TargetProductTierType = tierType, + BillingAddress = new MinimalBillingAddressRequest + { + Country = "US", + PostalCode = "12345" + } + }; + + // Act + var (planType, billingAddress) = sut.ToDomain(); + + // Assert + Assert.Equal(expectedPlanType, planType); + Assert.Equal("US", billingAddress.Country); + Assert.Equal("12345", billingAddress.PostalCode); + } + + [Theory] + [InlineData(ProductTierType.Free)] + [InlineData(ProductTierType.TeamsStarter)] + public void ToDomain_InvalidTierTypes_ThrowsInvalidOperationException(ProductTierType tierType) + { + // Arrange + var sut = new PreviewPremiumUpgradeProrationRequest + { + TargetProductTierType = tierType, + BillingAddress = new MinimalBillingAddressRequest + { + Country = "US", + PostalCode = "12345" + } + }; + + // Act & Assert + var exception = Assert.Throws(() => sut.ToDomain()); + Assert.Contains($"Cannot upgrade Premium subscription to {tierType} plan", exception.Message); + } +} diff --git a/test/Api.Test/Billing/Models/Requests/UpgradePremiumToOrganizationRequestTests.cs b/test/Api.Test/Billing/Models/Requests/UpgradePremiumToOrganizationRequestTests.cs new file mode 100644 index 0000000000..2d3bdb7b14 --- /dev/null +++ b/test/Api.Test/Billing/Models/Requests/UpgradePremiumToOrganizationRequestTests.cs @@ -0,0 +1,62 @@ +using Bit.Api.Billing.Models.Requests.Payment; +using Bit.Api.Billing.Models.Requests.Premium; +using Bit.Core.Billing.Enums; +using Xunit; + +namespace Bit.Api.Test.Billing.Models.Requests; + +public class UpgradePremiumToOrganizationRequestTests +{ + [Theory] + [InlineData(ProductTierType.Families, PlanType.FamiliesAnnually)] + [InlineData(ProductTierType.Teams, PlanType.TeamsAnnually)] + [InlineData(ProductTierType.Enterprise, PlanType.EnterpriseAnnually)] + public void ToDomain_ValidTierTypes_ReturnsPlanType(ProductTierType tierType, PlanType expectedPlanType) + { + // Arrange + var sut = new UpgradePremiumToOrganizationRequest + { + OrganizationName = "Test Organization", + Key = "encrypted-key", + TargetProductTierType = tierType, + BillingAddress = new MinimalBillingAddressRequest + { + Country = "US", + PostalCode = "12345" + } + }; + + // Act + var (organizationName, key, planType, billingAddress) = sut.ToDomain(); + + // Assert + Assert.Equal("Test Organization", organizationName); + Assert.Equal("encrypted-key", key); + Assert.Equal(expectedPlanType, planType); + Assert.Equal("US", billingAddress.Country); + Assert.Equal("12345", billingAddress.PostalCode); + } + + [Theory] + [InlineData(ProductTierType.Free)] + [InlineData(ProductTierType.TeamsStarter)] + public void ToDomain_InvalidTierTypes_ThrowsInvalidOperationException(ProductTierType tierType) + { + // Arrange + var sut = new UpgradePremiumToOrganizationRequest + { + OrganizationName = "Test Organization", + Key = "encrypted-key", + TargetProductTierType = tierType, + BillingAddress = new MinimalBillingAddressRequest + { + Country = "US", + PostalCode = "12345" + } + }; + + // Act & Assert + var exception = Assert.Throws(() => sut.ToDomain()); + Assert.Contains($"Cannot upgrade Premium subscription to {tierType} plan", exception.Message); + } +} diff --git a/test/Api.Test/Controllers/CollectionsControllerTests.cs b/test/Api.Test/Controllers/CollectionsControllerTests.cs index 33b7e20327..c345e3602f 100644 --- a/test/Api.Test/Controllers/CollectionsControllerTests.cs +++ b/test/Api.Test/Controllers/CollectionsControllerTests.cs @@ -107,7 +107,7 @@ public class CollectionsControllerTests await sutProvider.Sut.GetManyWithDetails(organization.Id); - await sutProvider.GetDependency().Received(1).GetManyByOrganizationIdWithPermissionsAsync(organization.Id, userId, true); + await sutProvider.GetDependency().Received(1).GetManySharedByOrganizationIdWithPermissionsAsync(organization.Id, userId, true); } [Theory, BitAutoData] @@ -143,12 +143,12 @@ public class CollectionsControllerTests .Returns(AuthorizationResult.Success()); sutProvider.GetDependency() - .GetManyByOrganizationIdWithPermissionsAsync(organization.Id, userId, true) + .GetManySharedByOrganizationIdWithPermissionsAsync(organization.Id, userId, true) .Returns(collections); var response = await sutProvider.Sut.GetManyWithDetails(organization.Id); - await sutProvider.GetDependency().Received(1).GetManyByOrganizationIdWithPermissionsAsync(organization.Id, userId, true); + await sutProvider.GetDependency().Received(1).GetManySharedByOrganizationIdWithPermissionsAsync(organization.Id, userId, true); Assert.Single(response.Data); Assert.All(response.Data, c => Assert.Equal(organization.Id, c.OrganizationId)); Assert.All(response.Data, c => Assert.Equal(managedCollection.Id, c.Id)); diff --git a/test/Api.Test/Controllers/PoliciesControllerTests.cs b/test/Api.Test/Controllers/PoliciesControllerTests.cs index efb9f7aaa9..03ab20ec28 100644 --- a/test/Api.Test/Controllers/PoliciesControllerTests.cs +++ b/test/Api.Test/Controllers/PoliciesControllerTests.cs @@ -49,7 +49,7 @@ public class PoliciesControllerTests sutProvider.GetDependency() .GetProperUserId(Arg.Any()) - .Returns((Guid?)userId); + .Returns(userId); sutProvider.GetDependency() .GetByOrganizationAsync(orgId, userId) @@ -95,7 +95,7 @@ public class PoliciesControllerTests // Arrange sutProvider.GetDependency() .GetProperUserId(Arg.Any()) - .Returns((Guid?)userId); + .Returns(userId); sutProvider.GetDependency() .GetByOrganizationAsync(orgId, userId) @@ -113,7 +113,7 @@ public class PoliciesControllerTests // Arrange sutProvider.GetDependency() .GetProperUserId(Arg.Any()) - .Returns((Guid?)userId); + .Returns(userId); sutProvider.GetDependency() .GetByOrganizationAsync(orgId, userId) @@ -135,7 +135,7 @@ public class PoliciesControllerTests // Arrange sutProvider.GetDependency() .GetProperUserId(Arg.Any()) - .Returns((Guid?)userId); + .Returns(userId); sutProvider.GetDependency() .GetByOrganizationAsync(orgId, userId) @@ -186,59 +186,35 @@ public class PoliciesControllerTests [Theory] [BitAutoData] public async Task Get_WhenUserCanManagePolicies_WithExistingType_ReturnsExistingPolicy( - SutProvider sutProvider, Guid orgId, Policy policy, int type) + SutProvider sutProvider, Guid orgId, PolicyStatus policy, PolicyType type) { // Arrange sutProvider.GetDependency() .ManagePolicies(orgId) .Returns(true); - policy.Type = (PolicyType)type; + policy.Type = type; policy.Enabled = true; policy.Data = null; - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(orgId, (PolicyType)type) + sutProvider.GetDependency() + .RunAsync(orgId, type) .Returns(policy); // Act var result = await sutProvider.Sut.Get(orgId, type); // Assert - Assert.IsType(result); - Assert.Equal(policy.Id, result.Id); + Assert.IsType(result); Assert.Equal(policy.Type, result.Type); Assert.Equal(policy.Enabled, result.Enabled); Assert.Equal(policy.OrganizationId, result.OrganizationId); } - [Theory] - [BitAutoData] - public async Task Get_WhenUserCanManagePolicies_WithNonExistingType_ReturnsDefaultPolicy( - SutProvider sutProvider, Guid orgId, int type) - { - // Arrange - sutProvider.GetDependency() - .ManagePolicies(orgId) - .Returns(true); - - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(orgId, (PolicyType)type) - .Returns((Policy)null); - - // Act - var result = await sutProvider.Sut.Get(orgId, type); - - // Assert - Assert.IsType(result); - Assert.Equal(result.Type, (PolicyType)type); - Assert.False(result.Enabled); - } - [Theory] [BitAutoData] public async Task Get_WhenUserCannotManagePolicies_ThrowsNotFoundException( - SutProvider sutProvider, Guid orgId, int type) + SutProvider sutProvider, Guid orgId, PolicyType type) { // Arrange sutProvider.GetDependency() diff --git a/test/Api.Test/Controllers/SsoCookieVendorControllerTests.cs b/test/Api.Test/Controllers/SsoCookieVendorControllerTests.cs new file mode 100644 index 0000000000..1e954e68ff --- /dev/null +++ b/test/Api.Test/Controllers/SsoCookieVendorControllerTests.cs @@ -0,0 +1,362 @@ +#nullable enable + +using Bit.Api.Controllers; +using Bit.Core.Settings; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using NSubstitute; +using Xunit; + +namespace Bit.Api.Test.Controllers; + +public class SsoCookieVendorControllerTests : IDisposable +{ + private readonly SsoCookieVendorController _sut; + private readonly GlobalSettings _globalSettings; + + public SsoCookieVendorControllerTests() + { + _globalSettings = new GlobalSettings + { + Communication = new GlobalSettings.CommunicationSettings + { + Bootstrap = "ssoCookieVendor", + SsoCookieVendor = new GlobalSettings.SsoCookieVendorSettings + { + CookieName = "test-cookie" + } + } + }; + _sut = new SsoCookieVendorController(_globalSettings); + } + + public void Dispose() + { + _sut?.Dispose(); + } + + private void MockHttpContextWithCookies(Dictionary cookies) + { + var httpContext = new DefaultHttpContext(); + var cookieCollection = Substitute.For(); + + // Mock the TryGetValue method + cookieCollection.TryGetValue(Arg.Any(), out Arg.Any()) + .Returns(callInfo => + { + var key = callInfo.ArgAt(0); + if (cookies.TryGetValue(key, out var value)) + { + callInfo[1] = value; + return true; + } + callInfo[1] = null; + return false; + }); + + // Mock the indexer if needed + cookieCollection[Arg.Any()].Returns(callInfo => + { + var key = callInfo.ArgAt(0); + return cookies.TryGetValue(key, out var value) ? value : null; + }); + + httpContext.Request.Cookies = cookieCollection; + _sut.ControllerContext = new ControllerContext { HttpContext = httpContext }; + } + + [Theory] + [InlineData(null)] + [InlineData("")] + [InlineData("none")] + public void Get_WhenBootstrapNotConfigured_Returns404(string? bootstrap) + { + // Arrange +#nullable disable + _globalSettings.Communication.Bootstrap = bootstrap; +#nullable restore + MockHttpContextWithCookies([]); + + // Act + var result = _sut.Get(); + + // Assert + Assert.IsType(result); + } + + [Fact] + public void Get_WhenCookieNameNotConfigured_Returns500() + { + // Arrange + _globalSettings.Communication.SsoCookieVendor.CookieName = string.Empty; + MockHttpContextWithCookies([]); + + // Act + var result = _sut.Get(); + + // Assert + var statusCodeResult = Assert.IsType(result); + Assert.Equal(500, statusCodeResult.StatusCode); + } + + [Fact] + public void Get_WhenCookieNameIsEmpty_Returns500() + { + // Arrange + _globalSettings.Communication.SsoCookieVendor.CookieName = ""; + MockHttpContextWithCookies([]); + + // Act + var result = _sut.Get(); + + // Assert + var statusCodeResult = Assert.IsType(result); + Assert.Equal(500, statusCodeResult.StatusCode); + } + + [Fact] + public void Get_WhenSingleCookieExists_ReturnsRedirectWithCorrectUri() + { + // Arrange + var cookies = new Dictionary + { + { "test-cookie", "my-token-value-123" } + }; + MockHttpContextWithCookies(cookies); + + // Act + var result = _sut.Get(); + + // Assert + var redirectResult = Assert.IsType(result); + Assert.Equal("bitwarden://sso_cookie_vendor?test-cookie=my-token-value-123&d=1", redirectResult.Url); + } + + [Fact] + public void Get_WhenSingleCookieHasSpecialCharacters_EncodesCorrectly() + { + // Arrange + var cookies = new Dictionary + { + { "test-cookie", "value with spaces & special=chars!" } + }; + MockHttpContextWithCookies(cookies); + + // Act + var result = _sut.Get(); + + // Assert + var redirectResult = Assert.IsType(result); + Assert.Contains("value%20with%20spaces", redirectResult.Url); + Assert.Contains("%26", redirectResult.Url); // & encoded + Assert.Contains("%3D", redirectResult.Url); // = encoded + Assert.Contains("%21", redirectResult.Url); // ! encoded + } + + [Fact] + public void Get_WhenShardedCookiesExist_ReturnsRedirectWithShardedUri() + { + // Arrange + var cookies = new Dictionary + { + { "test-cookie-0", "part1" }, + { "test-cookie-1", "part2" }, + { "test-cookie-2", "part3" } + }; + MockHttpContextWithCookies(cookies); + + // Act + var result = _sut.Get(); + + // Assert + var redirectResult = Assert.IsType(result); + Assert.StartsWith("bitwarden://sso_cookie_vendor?", redirectResult.Url); + Assert.Contains("test-cookie-0=part1", redirectResult.Url); + Assert.Contains("test-cookie-1=part2", redirectResult.Url); + Assert.Contains("test-cookie-2=part3", redirectResult.Url); + Assert.EndsWith("d=1", redirectResult.Url); + } + + [Fact] + public void Get_WhenShardedCookiesWithGap_StopsAtFirstGap() + { + // Arrange + var cookies = new Dictionary + { + { "test-cookie-0", "part0" }, + { "test-cookie-1", "part1" }, + // Missing test-cookie-2 + { "test-cookie-3", "part3" }, + { "test-cookie-4", "part4" } + }; + MockHttpContextWithCookies(cookies); + + // Act + var result = _sut.Get(); + + // Assert + var redirectResult = Assert.IsType(result); + Assert.Contains("test-cookie-0=part0", redirectResult.Url); + Assert.Contains("test-cookie-1=part1", redirectResult.Url); + Assert.DoesNotContain("test-cookie-3", redirectResult.Url); + Assert.DoesNotContain("test-cookie-4", redirectResult.Url); + Assert.EndsWith("d=1", redirectResult.Url); + } + + [Fact] + public void Get_WhenOnlyGappedShardsExist_Returns404() + { + // Arrange - only test-cookie-2 exists, not test-cookie-0 or test-cookie-1 + var cookies = new Dictionary + { + { "test-cookie-2", "part2" }, + { "test-cookie-3", "part3" } + }; + MockHttpContextWithCookies(cookies); + + // Act + var result = _sut.Get(); + + // Assert + Assert.IsType(result); + } + + [Fact] + public void Get_WhenNoCookiesFound_Returns404() + { + // Arrange + MockHttpContextWithCookies([]); + + // Act + var result = _sut.Get(); + + // Assert + Assert.IsType(result); + } + + [Fact] + public void Get_WhenUnrelatedCookiesExist_Returns404() + { + // Arrange + var cookies = new Dictionary + { + { "other-cookie", "value" }, + { "another-cookie", "value2" } + }; + MockHttpContextWithCookies(cookies); + + // Act + var result = _sut.Get(); + + // Assert + Assert.IsType(result); + } + + [Fact] + public void Get_WhenUriExceedsMaxLength_Returns400() + { + // Arrange - create a very long cookie value that will exceed 8192 characters + // URI format: "bitwarden://sso_cookie_vendor?test-cookie={value}" + // Base URI length is about 43 characters, so we need value > 8149 + var longValue = new string('a', 8200); + var cookies = new Dictionary + { + { "test-cookie", longValue } + }; + MockHttpContextWithCookies(cookies); + + // Act + var result = _sut.Get(); + + // Assert + Assert.IsType(result); + } + + [Fact] + public void Get_WhenSingleCookiePreferredOverSharded_ReturnsSingleCookie() + { + // Arrange - both single and sharded cookies exist + var cookies = new Dictionary + { + { "test-cookie", "single-value" }, + { "test-cookie-0", "shard0" }, + { "test-cookie-1", "shard1" } + }; + MockHttpContextWithCookies(cookies); + + // Act + var result = _sut.Get(); + + // Assert + var redirectResult = Assert.IsType(result); + Assert.Equal("bitwarden://sso_cookie_vendor?test-cookie=single-value&d=1", redirectResult.Url); + } + + [Fact] + public void Get_WhenEmptyCookieValue_TreatsAsNotFound() + { + // Arrange + var cookies = new Dictionary + { + { "test-cookie", "" } + }; + MockHttpContextWithCookies(cookies); + + // Act + var result = _sut.Get(); + + // Assert + Assert.IsType(result); + } + + [Fact] + public void Get_WhenShardedCookiesHaveMaxCount_ProcessesAllShards() + { + // Arrange - create 20 sharded cookies (MaxShardCount) + var cookies = new Dictionary(); + for (var i = 0; i < 20; i++) + { + cookies[$"test-cookie-{i}"] = $"part{i}"; + } + MockHttpContextWithCookies(cookies); + + // Act + var result = _sut.Get(); + + // Assert + var redirectResult = Assert.IsType(result); + for (var i = 0; i < 20; i++) + { + Assert.Contains($"test-cookie-{i}=part{i}", redirectResult.Url); + } + Assert.EndsWith("d=1", redirectResult.Url); + } + + [Fact] + public void Get_WhenShardedCookiesExceedMaxCount_OnlyProcessesFirst20() + { + // Arrange - create 25 sharded cookies (more than MaxShardCount of 20) + var cookies = new Dictionary(); + for (var i = 0; i < 25; i++) + { + cookies[$"test-cookie-{i}"] = $"part{i}"; + } + MockHttpContextWithCookies(cookies); + + // Act + var result = _sut.Get(); + + // Assert + var redirectResult = Assert.IsType(result); + // Should contain first 20 + for (var i = 0; i < 20; i++) + { + Assert.Contains($"test-cookie-{i}=part{i}", redirectResult.Url); + } + // Should NOT contain 21-25 + for (var i = 20; i < 25; i++) + { + Assert.DoesNotContain($"test-cookie-{i}=part{i}", redirectResult.Url); + } + } +} diff --git a/test/Api.Test/KeyManagement/Validators/EmergencyAccessRotationValidatorTests.cs b/test/Api.Test/KeyManagement/Validators/EmergencyAccessRotationValidatorTests.cs index e00129fd89..a69576f9dc 100644 --- a/test/Api.Test/KeyManagement/Validators/EmergencyAccessRotationValidatorTests.cs +++ b/test/Api.Test/KeyManagement/Validators/EmergencyAccessRotationValidatorTests.cs @@ -30,7 +30,7 @@ public class EmergencyAccessRotationValidatorTests KeyEncrypted = e.KeyEncrypted, Type = e.Type }).ToList(); - userEmergencyAccess.Add(new EmergencyAccessDetails { Id = Guid.NewGuid(), KeyEncrypted = "TestKey" }); + userEmergencyAccess.Add(new EmergencyAccessDetails { Id = Guid.NewGuid(), GrantorEmail = "grantor@example.com", KeyEncrypted = "TestKey" }); sutProvider.GetDependency().GetManyDetailsByGrantorIdAsync(user.Id) .Returns(userEmergencyAccess); diff --git a/test/Api.Test/KeyManagement/Validators/OrganizationUserRotationValidatorTests.cs b/test/Api.Test/KeyManagement/Validators/OrganizationUserRotationValidatorTests.cs index 964c801903..a939636fc2 100644 --- a/test/Api.Test/KeyManagement/Validators/OrganizationUserRotationValidatorTests.cs +++ b/test/Api.Test/KeyManagement/Validators/OrganizationUserRotationValidatorTests.cs @@ -69,6 +69,44 @@ public class OrganizationUserRotationValidatorTests Assert.Empty(result); } + [Theory] + [BitAutoData([null])] + [BitAutoData("")] + public async Task ValidateAsync_OrgUsersWithNullOrEmptyResetPasswordKey_FiltersOutInvalidKeys( + string? invalidResetPasswordKey, + SutProvider sutProvider, User user, + ResetPasswordWithOrgIdRequestModel validResetPasswordKey) + { + // Arrange + var existingUserResetPassword = new List + { + // Valid org user with reset password key + new OrganizationUser + { + Id = Guid.NewGuid(), + OrganizationId = validResetPasswordKey.OrganizationId, + ResetPasswordKey = validResetPasswordKey.ResetPasswordKey + }, + // Invalid org user with null or empty reset password key - should be filtered out + new OrganizationUser + { + Id = Guid.NewGuid(), + OrganizationId = Guid.NewGuid(), + ResetPasswordKey = invalidResetPasswordKey + } + }; + sutProvider.GetDependency().GetManyByUserAsync(user.Id) + .Returns(existingUserResetPassword); + + // Act + var result = await sutProvider.Sut.ValidateAsync(user, new[] { validResetPasswordKey }); + + // Assert + Assert.NotNull(result); + Assert.Single(result); + Assert.Equal(validResetPasswordKey.OrganizationId, result[0].OrganizationId); + } + [Theory] [BitAutoData] public async Task ValidateAsync_MissingResetPassword_Throws( diff --git a/test/Api.Test/Tools/Controllers/ImportCiphersControllerTests.cs b/test/Api.Test/Tools/Controllers/ImportCiphersControllerTests.cs index 9ca641a28e..a8465ed0f6 100644 --- a/test/Api.Test/Tools/Controllers/ImportCiphersControllerTests.cs +++ b/test/Api.Test/Tools/Controllers/ImportCiphersControllerTests.cs @@ -806,63 +806,6 @@ public class ImportCiphersControllerTests Arg.Any()); } - [Theory, BitAutoData] - public async Task PostImportOrganization_ThrowsException_WhenAnyCipherIsArchived( - SutProvider sutProvider, - IFixture fixture, - User user - ) - { - var orgId = Guid.NewGuid(); - - sutProvider.GetDependency() - .SelfHosted = false; - sutProvider.GetDependency() - .ImportCiphersLimitation = _organizationCiphersLimitations; - - SetupUserService(sutProvider, user); - - var ciphers = fixture.Build() - .With(_ => _.ArchivedDate, DateTime.UtcNow) - .CreateMany(2).ToArray(); - - var request = new ImportOrganizationCiphersRequestModel - { - Collections = new List().ToArray(), - Ciphers = ciphers, - CollectionRelationships = new List>().ToArray(), - }; - - sutProvider.GetDependency() - .AccessImportExport(Arg.Any()) - .Returns(false); - - sutProvider.GetDependency() - .AuthorizeAsync(Arg.Any(), - Arg.Any>(), - Arg.Is>(reqs => - reqs.Contains(BulkCollectionOperations.ImportCiphers))) - .Returns(AuthorizationResult.Failed()); - - sutProvider.GetDependency() - .AuthorizeAsync(Arg.Any(), - Arg.Any>(), - Arg.Is>(reqs => - reqs.Contains(BulkCollectionOperations.Create))) - .Returns(AuthorizationResult.Success()); - - sutProvider.GetDependency() - .GetManyByOrganizationIdAsync(orgId) - .Returns(new List()); - - var exception = await Assert.ThrowsAsync(async () => - { - await sutProvider.Sut.PostImportOrganization(orgId.ToString(), request); - }); - - Assert.Equal("You cannot import archived items into an organization.", exception.Message); - } - private static void SetupUserService(SutProvider sutProvider, User user) { // This is a workaround for the NSubstitute issue with ambiguous arguments diff --git a/test/Api.Test/Tools/Controllers/SendsControllerTests.cs b/test/Api.Test/Tools/Controllers/SendsControllerTests.cs index e3a9ba4435..3d77ac2343 100644 --- a/test/Api.Test/Tools/Controllers/SendsControllerTests.cs +++ b/test/Api.Test/Tools/Controllers/SendsControllerTests.cs @@ -903,6 +903,106 @@ public class SendsControllerTests : IDisposable Assert.Equal(creator.Email, response.CreatorIdentifier); } + [Theory, AutoData] + public async Task AccessUsingAuth_WithDisabledSend_ThrowsNotFoundException(Guid sendId) + { + var send = new Send + { + Id = sendId, + Type = SendType.Text, + Data = JsonSerializer.Serialize(new SendTextData("Test", "Notes", "Text", false)), + DeletionDate = DateTime.UtcNow.AddDays(7), + ExpirationDate = null, + Disabled = true, + AccessCount = 0, + MaxAccessCount = null + }; + var user = CreateUserWithSendIdClaim(sendId); + _sut.ControllerContext = CreateControllerContextWithUser(user); + _sendRepository.GetByIdAsync(sendId).Returns(send); + + await Assert.ThrowsAsync(() => _sut.AccessUsingAuth()); + + await _sendRepository.Received(1).GetByIdAsync(sendId); + await _userService.DidNotReceive().GetUserByIdAsync(Arg.Any()); + await _sendRepository.DidNotReceive().ReplaceAsync(Arg.Any()); + } + + [Theory, AutoData] + public async Task AccessUsingAuth_WithMaxAccessCountReached_ThrowsNotFoundException(Guid sendId) + { + var send = new Send + { + Id = sendId, + Type = SendType.Text, + Data = JsonSerializer.Serialize(new SendTextData("Test", "Notes", "Text", false)), + DeletionDate = DateTime.UtcNow.AddDays(7), + ExpirationDate = null, + Disabled = false, + AccessCount = 10, + MaxAccessCount = 10 + }; + var user = CreateUserWithSendIdClaim(sendId); + _sut.ControllerContext = CreateControllerContextWithUser(user); + _sendRepository.GetByIdAsync(sendId).Returns(send); + + await Assert.ThrowsAsync(() => _sut.AccessUsingAuth()); + + await _sendRepository.Received(1).GetByIdAsync(sendId); + await _userService.DidNotReceive().GetUserByIdAsync(Arg.Any()); + await _sendRepository.DidNotReceive().ReplaceAsync(Arg.Any()); + } + + [Theory, AutoData] + public async Task AccessUsingAuth_WithExpiredSend_ThrowsNotFoundException(Guid sendId) + { + var send = new Send + { + Id = sendId, + Type = SendType.Text, + Data = JsonSerializer.Serialize(new SendTextData("Test", "Notes", "Text", false)), + DeletionDate = DateTime.UtcNow.AddDays(7), + ExpirationDate = DateTime.UtcNow.AddDays(-1), // Expired yesterday + Disabled = false, + AccessCount = 0, + MaxAccessCount = null + }; + var user = CreateUserWithSendIdClaim(sendId); + _sut.ControllerContext = CreateControllerContextWithUser(user); + _sendRepository.GetByIdAsync(sendId).Returns(send); + + await Assert.ThrowsAsync(() => _sut.AccessUsingAuth()); + + await _sendRepository.Received(1).GetByIdAsync(sendId); + await _userService.DidNotReceive().GetUserByIdAsync(Arg.Any()); + await _sendRepository.DidNotReceive().ReplaceAsync(Arg.Any()); + } + + [Theory, AutoData] + public async Task AccessUsingAuth_WithDeletionDatePassed_ThrowsNotFoundException(Guid sendId) + { + var send = new Send + { + Id = sendId, + Type = SendType.Text, + Data = JsonSerializer.Serialize(new SendTextData("Test", "Notes", "Text", false)), + DeletionDate = DateTime.UtcNow.AddDays(-1), // Deletion date has passed + ExpirationDate = null, + Disabled = false, + AccessCount = 0, + MaxAccessCount = null + }; + var user = CreateUserWithSendIdClaim(sendId); + _sut.ControllerContext = CreateControllerContextWithUser(user); + _sendRepository.GetByIdAsync(sendId).Returns(send); + + await Assert.ThrowsAsync(() => _sut.AccessUsingAuth()); + + await _sendRepository.Received(1).GetByIdAsync(sendId); + await _userService.DidNotReceive().GetUserByIdAsync(Arg.Any()); + await _sendRepository.DidNotReceive().ReplaceAsync(Arg.Any()); + } + [Theory, AutoData] public async Task GetSendFileDownloadDataUsingAuth_WithValidFileId_ReturnsDownloadUrl( Guid sendId, string fileId, string expectedUrl) @@ -922,7 +1022,8 @@ public class SendsControllerTests : IDisposable var user = CreateUserWithSendIdClaim(sendId); _sut.ControllerContext = CreateControllerContextWithUser(user); _sendRepository.GetByIdAsync(sendId).Returns(send); - _sendFileStorageService.GetSendFileDownloadUrlAsync(send, fileId).Returns(expectedUrl); + _nonAnonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId) + .Returns((expectedUrl, SendAccessResult.Granted)); var result = await _sut.GetSendFileDownloadDataUsingAuth(fileId); @@ -932,7 +1033,7 @@ public class SendsControllerTests : IDisposable Assert.Equal(fileId, response.Id); Assert.Equal(expectedUrl, response.Url); await _sendRepository.Received(1).GetByIdAsync(sendId); - await _sendFileStorageService.Received(1).GetSendFileDownloadUrlAsync(send, fileId); + await _nonAnonymousSendCommand.Received(1).GetSendFileDownloadUrlAsync(send, fileId); } [Theory, AutoData] @@ -948,175 +1049,20 @@ public class SendsControllerTests : IDisposable Assert.Equal("Could not locate send", exception.Message); await _sendRepository.Received(1).GetByIdAsync(sendId); - await _sendFileStorageService.DidNotReceive() + await _nonAnonymousSendCommand.DidNotReceive() .GetSendFileDownloadUrlAsync(Arg.Any(), Arg.Any()); } [Theory, AutoData] - public async Task GetSendFileDownloadDataUsingAuth_WithTextSend_StillReturnsResponse( - Guid sendId, string fileId, string expectedUrl) - { - var send = new Send - { - Id = sendId, - Type = SendType.Text, - Data = JsonSerializer.Serialize(new SendTextData("Test", "Notes", "Text", false)), - DeletionDate = DateTime.UtcNow.AddDays(7), - ExpirationDate = null, - Disabled = false, - AccessCount = 0, - MaxAccessCount = null - }; - var user = CreateUserWithSendIdClaim(sendId); - _sut.ControllerContext = CreateControllerContextWithUser(user); - _sendRepository.GetByIdAsync(sendId).Returns(send); - _sendFileStorageService.GetSendFileDownloadUrlAsync(send, fileId).Returns(expectedUrl); - - var result = await _sut.GetSendFileDownloadDataUsingAuth(fileId); - - Assert.NotNull(result); - var objectResult = Assert.IsType(result); - var response = Assert.IsType(objectResult.Value); - Assert.Equal(fileId, response.Id); - Assert.Equal(expectedUrl, response.Url); - } - - #region AccessUsingAuth Validation Tests - - [Theory, AutoData] - public async Task AccessUsingAuth_WithExpiredSend_ThrowsNotFoundException(Guid sendId) - { - var send = new Send - { - Id = sendId, - UserId = Guid.NewGuid(), - Type = SendType.Text, - Data = JsonSerializer.Serialize(new SendTextData("Test", "Notes", "Text", false)), - DeletionDate = DateTime.UtcNow.AddDays(7), - ExpirationDate = DateTime.UtcNow.AddDays(-1), // Expired yesterday - Disabled = false, - AccessCount = 0, - MaxAccessCount = null - }; - var user = CreateUserWithSendIdClaim(sendId); - _sut.ControllerContext = CreateControllerContextWithUser(user); - _sendRepository.GetByIdAsync(sendId).Returns(send); - - await Assert.ThrowsAsync(() => _sut.AccessUsingAuth()); - - await _sendRepository.Received(1).GetByIdAsync(sendId); - } - - [Theory, AutoData] - public async Task AccessUsingAuth_WithDeletedSend_ThrowsNotFoundException(Guid sendId) - { - var send = new Send - { - Id = sendId, - UserId = Guid.NewGuid(), - Type = SendType.Text, - Data = JsonSerializer.Serialize(new SendTextData("Test", "Notes", "Text", false)), - DeletionDate = DateTime.UtcNow.AddDays(-1), // Should have been deleted yesterday - ExpirationDate = null, - Disabled = false, - AccessCount = 0, - MaxAccessCount = null - }; - var user = CreateUserWithSendIdClaim(sendId); - _sut.ControllerContext = CreateControllerContextWithUser(user); - _sendRepository.GetByIdAsync(sendId).Returns(send); - - await Assert.ThrowsAsync(() => _sut.AccessUsingAuth()); - - await _sendRepository.Received(1).GetByIdAsync(sendId); - } - - [Theory, AutoData] - public async Task AccessUsingAuth_WithDisabledSend_ThrowsNotFoundException(Guid sendId) - { - var send = new Send - { - Id = sendId, - UserId = Guid.NewGuid(), - Type = SendType.Text, - Data = JsonSerializer.Serialize(new SendTextData("Test", "Notes", "Text", false)), - DeletionDate = DateTime.UtcNow.AddDays(7), - ExpirationDate = null, - Disabled = true, // Disabled - AccessCount = 0, - MaxAccessCount = null - }; - var user = CreateUserWithSendIdClaim(sendId); - _sut.ControllerContext = CreateControllerContextWithUser(user); - _sendRepository.GetByIdAsync(sendId).Returns(send); - - await Assert.ThrowsAsync(() => _sut.AccessUsingAuth()); - - await _sendRepository.Received(1).GetByIdAsync(sendId); - } - - [Theory, AutoData] - public async Task AccessUsingAuth_WithAccessCountExceeded_ThrowsNotFoundException(Guid sendId) - { - var send = new Send - { - Id = sendId, - UserId = Guid.NewGuid(), - Type = SendType.Text, - Data = JsonSerializer.Serialize(new SendTextData("Test", "Notes", "Text", false)), - DeletionDate = DateTime.UtcNow.AddDays(7), - ExpirationDate = null, - Disabled = false, - AccessCount = 5, - MaxAccessCount = 5 // Limit reached - }; - var user = CreateUserWithSendIdClaim(sendId); - _sut.ControllerContext = CreateControllerContextWithUser(user); - _sendRepository.GetByIdAsync(sendId).Returns(send); - - await Assert.ThrowsAsync(() => _sut.AccessUsingAuth()); - - await _sendRepository.Received(1).GetByIdAsync(sendId); - } - - #endregion - - #region GetSendFileDownloadDataUsingAuth Validation Tests - - [Theory, AutoData] - public async Task GetSendFileDownloadDataUsingAuth_WithExpiredSend_ThrowsNotFoundException( + public async Task GetSendFileDownloadDataUsingAuth_WithTextSend_ThrowsBadRequestException( Guid sendId, string fileId) { var send = new Send { Id = sendId, - Type = SendType.File, - Data = JsonSerializer.Serialize(new SendFileData("Test", "Notes", "file.pdf")), + Type = SendType.Text, + Data = JsonSerializer.Serialize(new SendTextData("Test", "Notes", "Text", false)), DeletionDate = DateTime.UtcNow.AddDays(7), - ExpirationDate = DateTime.UtcNow.AddDays(-1), // Expired - Disabled = false, - AccessCount = 0, - MaxAccessCount = null - }; - var user = CreateUserWithSendIdClaim(sendId); - _sut.ControllerContext = CreateControllerContextWithUser(user); - _sendRepository.GetByIdAsync(sendId).Returns(send); - - await Assert.ThrowsAsync(() => _sut.GetSendFileDownloadDataUsingAuth(fileId)); - - await _sendRepository.Received(1).GetByIdAsync(sendId); - } - - [Theory, AutoData] - public async Task GetSendFileDownloadDataUsingAuth_WithDeletedSend_ThrowsNotFoundException( - Guid sendId, string fileId) - { - var send = new Send - { - Id = sendId, - Type = SendType.File, - Data = JsonSerializer.Serialize(new SendFileData("Test", "Notes", "file.pdf")), - DeletionDate = DateTime.UtcNow.AddDays(-1), // Deleted ExpirationDate = null, Disabled = false, AccessCount = 0, @@ -1125,61 +1071,46 @@ public class SendsControllerTests : IDisposable var user = CreateUserWithSendIdClaim(sendId); _sut.ControllerContext = CreateControllerContextWithUser(user); _sendRepository.GetByIdAsync(sendId).Returns(send); + _nonAnonymousSendCommand + .When(x => x.GetSendFileDownloadUrlAsync(send, fileId)) + .Do(x => throw new BadRequestException("Can only get a download URL for a file type of Send")); - await Assert.ThrowsAsync(() => _sut.GetSendFileDownloadDataUsingAuth(fileId)); + var exception = await Assert.ThrowsAsync( + () => _sut.GetSendFileDownloadDataUsingAuth(fileId)); + Assert.Equal("Can only get a download URL for a file type of Send", exception.Message); await _sendRepository.Received(1).GetByIdAsync(sendId); + await _nonAnonymousSendCommand.Received(1).GetSendFileDownloadUrlAsync(send, fileId); } [Theory, AutoData] - public async Task GetSendFileDownloadDataUsingAuth_WithDisabledSend_ThrowsNotFoundException( + public async Task GetSendFileDownloadDataUsingAuth_WithAccessDenied_ThrowsNotFoundException( Guid sendId, string fileId) { + var fileData = new SendFileData("Test File", "Notes", "document.pdf") { Id = fileId, Size = 2048 }; var send = new Send { Id = sendId, Type = SendType.File, - Data = JsonSerializer.Serialize(new SendFileData("Test", "Notes", "file.pdf")), + Data = JsonSerializer.Serialize(fileData), DeletionDate = DateTime.UtcNow.AddDays(7), ExpirationDate = null, - Disabled = true, // Disabled + Disabled = false, AccessCount = 0, MaxAccessCount = null }; var user = CreateUserWithSendIdClaim(sendId); _sut.ControllerContext = CreateControllerContextWithUser(user); _sendRepository.GetByIdAsync(sendId).Returns(send); + _nonAnonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId) + .Returns((null, SendAccessResult.Denied)); await Assert.ThrowsAsync(() => _sut.GetSendFileDownloadDataUsingAuth(fileId)); await _sendRepository.Received(1).GetByIdAsync(sendId); + await _nonAnonymousSendCommand.Received(1).GetSendFileDownloadUrlAsync(send, fileId); } - [Theory, AutoData] - public async Task GetSendFileDownloadDataUsingAuth_WithAccessCountExceeded_ThrowsNotFoundException( - Guid sendId, string fileId) - { - var send = new Send - { - Id = sendId, - Type = SendType.File, - Data = JsonSerializer.Serialize(new SendFileData("Test", "Notes", "file.pdf")), - DeletionDate = DateTime.UtcNow.AddDays(7), - ExpirationDate = null, - Disabled = false, - AccessCount = 10, - MaxAccessCount = 10 // Limit reached - }; - var user = CreateUserWithSendIdClaim(sendId); - _sut.ControllerContext = CreateControllerContextWithUser(user); - _sendRepository.GetByIdAsync(sendId).Returns(send); - - await Assert.ThrowsAsync(() => _sut.GetSendFileDownloadDataUsingAuth(fileId)); - - await _sendRepository.Received(1).GetByIdAsync(sendId); - } - - #endregion #endregion diff --git a/test/Api.Test/Vault/Controllers/CiphersControllerTests.cs b/test/Api.Test/Vault/Controllers/CiphersControllerTests.cs index 238053464c..6fba9730a7 100644 --- a/test/Api.Test/Vault/Controllers/CiphersControllerTests.cs +++ b/test/Api.Test/Vault/Controllers/CiphersControllerTests.cs @@ -158,9 +158,9 @@ public class CiphersControllerTests [BitAutoData(OrganizationUserType.Custom, false, false)] public async Task CanEditCiphersAsAdminAsync_FlexibleCollections_Success( OrganizationUserType userType, bool allowAdminsAccessToAllItems, bool shouldSucceed, - CurrentContextOrganization organization, Guid userId, CipherDetails cipherDetails, SutProvider sutProvider) + CurrentContextOrganization organization, Guid userId, CipherOrganizationDetails cipherOrgDetails, SutProvider sutProvider) { - cipherDetails.OrganizationId = organization.Id; + cipherOrgDetails.OrganizationId = organization.Id; organization.Type = userType; if (userType == OrganizationUserType.Custom) { @@ -171,9 +171,9 @@ public class CiphersControllerTests sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(new User { Id = userId }); - sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); + sutProvider.GetDependency().GetOrganizationDetailsByIdAsync(cipherOrgDetails.Id).Returns(cipherOrgDetails); - sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipherDetails }); + sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipherOrgDetails }); sutProvider.GetDependency().GetOrganizationAbilityAsync(organization.Id).Returns(new OrganizationAbility { @@ -183,13 +183,13 @@ public class CiphersControllerTests if (shouldSucceed) { - await sutProvider.Sut.DeleteAdmin(cipherDetails.Id); + await sutProvider.Sut.DeleteAdmin(cipherOrgDetails.Id); await sutProvider.GetDependency().ReceivedWithAnyArgs() .DeleteAsync(default, default); } else { - await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteAdmin(cipherDetails.Id)); + await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteAdmin(cipherOrgDetails.Id)); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() .DeleteAsync(default, default); } @@ -199,25 +199,23 @@ public class CiphersControllerTests [BitAutoData(OrganizationUserType.Owner)] [BitAutoData(OrganizationUserType.Admin)] public async Task DeleteAdmin_WithOwnerOrAdmin_WithManagePermission_DeletesCipher( - OrganizationUserType organizationUserType, CipherDetails cipherDetails, Guid userId, + OrganizationUserType organizationUserType, CipherOrganizationDetails cipherOrgDetails, Guid userId, CurrentContextOrganization organization, SutProvider sutProvider) { - cipherDetails.UserId = null; - cipherDetails.OrganizationId = organization.Id; - cipherDetails.Edit = true; - cipherDetails.Manage = true; + cipherOrgDetails.UserId = null; + cipherOrgDetails.OrganizationId = organization.Id; organization.Type = organizationUserType; sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(new User { Id = userId }); sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); - sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); + sutProvider.GetDependency().GetOrganizationDetailsByIdAsync(cipherOrgDetails.Id).Returns(cipherOrgDetails); sutProvider.GetDependency() .GetManyByUserIdAsync(userId) .Returns(new List { - cipherDetails + new CipherDetails(cipherOrgDetails) { Edit = true, Manage = true } }); sutProvider.GetDependency() .GetOrganizationAbilityAsync(organization.Id) @@ -227,34 +225,35 @@ public class CiphersControllerTests LimitItemDeletion = true }); - await sutProvider.Sut.DeleteAdmin(cipherDetails.Id); + await sutProvider.Sut.DeleteAdmin(cipherOrgDetails.Id); - await sutProvider.GetDependency().Received(1).DeleteAsync(cipherDetails, userId, true); + await sutProvider.GetDependency().Received(1).DeleteAsync( + Arg.Is(c => c.Id == cipherOrgDetails.Id && c.OrganizationId == cipherOrgDetails.OrganizationId), + userId, + true); } [Theory] [BitAutoData(OrganizationUserType.Owner)] [BitAutoData(OrganizationUserType.Admin)] public async Task DeleteAdmin_WithOwnerOrAdmin_WithoutManagePermission_ThrowsNotFoundException( - OrganizationUserType organizationUserType, CipherDetails cipherDetails, Guid userId, + OrganizationUserType organizationUserType, CipherOrganizationDetails cipherOrgDetails, Guid userId, CurrentContextOrganization organization, SutProvider sutProvider) { - cipherDetails.UserId = null; - cipherDetails.OrganizationId = organization.Id; - cipherDetails.Edit = true; - cipherDetails.Manage = false; + cipherOrgDetails.UserId = null; + cipherOrgDetails.OrganizationId = organization.Id; organization.Type = organizationUserType; sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(new User { Id = userId }); sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); - sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); + sutProvider.GetDependency().GetOrganizationDetailsByIdAsync(cipherOrgDetails.Id).Returns(cipherOrgDetails); sutProvider.GetDependency() .GetManyByUserIdAsync(userId) .Returns(new List { - cipherDetails + new CipherDetails(cipherOrgDetails) { Edit = true, Manage = false } }); sutProvider.GetDependency() .GetOrganizationAbilityAsync(organization.Id) @@ -264,7 +263,7 @@ public class CiphersControllerTests LimitItemDeletion = true }); - await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteAdmin(cipherDetails.Id)); + await Assert.ThrowsAsync(() => sutProvider.Sut.DeleteAdmin(cipherOrgDetails.Id)); await sutProvider.GetDependency().DidNotReceive().DeleteAsync(Arg.Any(), Arg.Any(), Arg.Any()); } @@ -273,21 +272,21 @@ public class CiphersControllerTests [BitAutoData(OrganizationUserType.Owner)] [BitAutoData(OrganizationUserType.Admin)] public async Task DeleteAdmin_WithOwnerOrAdmin_WithAccessToUnassignedCipher_DeletesCipher( - OrganizationUserType organizationUserType, CipherDetails cipherDetails, Guid userId, + OrganizationUserType organizationUserType, CipherOrganizationDetails cipherOrgDetails, Guid userId, CurrentContextOrganization organization, SutProvider sutProvider) { - cipherDetails.OrganizationId = organization.Id; + cipherOrgDetails.OrganizationId = organization.Id; organization.Type = organizationUserType; sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); sutProvider.GetDependency().GetUserByPrincipalAsync(default).ReturnsForAnyArgs(new User { Id = userId }); sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); - sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); + sutProvider.GetDependency().GetOrganizationDetailsByIdAsync(cipherOrgDetails.Id).Returns(cipherOrgDetails); sutProvider.GetDependency() .GetManyUnassignedOrganizationDetailsByOrganizationIdAsync(organization.Id) .Returns(new List { - new() { Id = cipherDetails.Id, OrganizationId = cipherDetails.OrganizationId } + new() { Id = cipherOrgDetails.Id, OrganizationId = cipherOrgDetails.OrganizationId } }); sutProvider.GetDependency() .GetOrganizationAbilityAsync(organization.Id) @@ -297,54 +296,65 @@ public class CiphersControllerTests LimitItemDeletion = true }); - await sutProvider.Sut.DeleteAdmin(cipherDetails.Id); + await sutProvider.Sut.DeleteAdmin(cipherOrgDetails.Id); - await sutProvider.GetDependency().Received(1).DeleteAsync(cipherDetails, userId, true); + await sutProvider.GetDependency().Received(1).DeleteAsync(Arg.Is(c => c.Id == cipherOrgDetails.Id && c.OrganizationId == cipherOrgDetails.OrganizationId), + userId, + true); } [Theory] [BitAutoData(OrganizationUserType.Owner)] [BitAutoData(OrganizationUserType.Admin)] public async Task DeleteAdmin_WithAdminOrOwner_WithAccessToAllCollectionItems_DeletesCipher( - OrganizationUserType organizationUserType, CipherDetails cipherDetails, Guid userId, + OrganizationUserType organizationUserType, CipherOrganizationDetails cipherOrgDetails, Guid userId, CurrentContextOrganization organization, SutProvider sutProvider) { - cipherDetails.OrganizationId = organization.Id; + + organization.Type = organizationUserType; + + cipherOrgDetails.OrganizationId = organization.Id; organization.Type = organizationUserType; sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); - sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); - sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipherDetails }); + sutProvider.GetDependency().GetOrganizationDetailsByIdAsync(cipherOrgDetails.Id).Returns(cipherOrgDetails); + sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipherOrgDetails }); sutProvider.GetDependency().GetOrganizationAbilityAsync(organization.Id).Returns(new OrganizationAbility { Id = organization.Id, AllowAdminAccessToAllCollectionItems = true }); - await sutProvider.Sut.DeleteAdmin(cipherDetails.Id); + await sutProvider.Sut.DeleteAdmin(cipherOrgDetails.Id); - await sutProvider.GetDependency().Received(1).DeleteAsync(cipherDetails, userId, true); + await sutProvider.GetDependency().Received(1).DeleteAsync( + Arg.Is(c => c.Id == cipherOrgDetails.Id && c.OrganizationId == cipherOrgDetails.OrganizationId), + userId, + true); } [Theory] [BitAutoData] public async Task DeleteAdmin_WithCustomUser_WithEditAnyCollectionTrue_DeletesCipher( - CipherDetails cipherDetails, Guid userId, + CipherOrganizationDetails cipherOrgDetails, Guid userId, CurrentContextOrganization organization, SutProvider sutProvider) { - cipherDetails.OrganizationId = organization.Id; + cipherOrgDetails.OrganizationId = organization.Id; organization.Type = OrganizationUserType.Custom; organization.Permissions.EditAnyCollection = true; sutProvider.GetDependency().GetProperUserId(default).ReturnsForAnyArgs(userId); sutProvider.GetDependency().GetOrganization(organization.Id).Returns(organization); - sutProvider.GetDependency().GetByIdAsync(cipherDetails.Id, userId).Returns(cipherDetails); - sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipherDetails }); + sutProvider.GetDependency().GetOrganizationDetailsByIdAsync(cipherOrgDetails.Id).Returns(cipherOrgDetails); + sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).Returns(new List { cipherOrgDetails }); - await sutProvider.Sut.DeleteAdmin(cipherDetails.Id); + await sutProvider.Sut.DeleteAdmin(cipherOrgDetails.Id); - await sutProvider.GetDependency().Received(1).DeleteAsync(cipherDetails, userId, true); + await sutProvider.GetDependency().Received(1).DeleteAsync( + Arg.Is(c => c.Id == cipherOrgDetails.Id && c.OrganizationId == cipherOrgDetails.OrganizationId), + userId, + true); } [Theory] diff --git a/test/Billing.Test/Services/StripeEventServiceTests.cs b/test/Billing.Test/Services/StripeEventServiceTests.cs index 68aeab2f44..c438ef663c 100644 --- a/test/Billing.Test/Services/StripeEventServiceTests.cs +++ b/test/Billing.Test/Services/StripeEventServiceTests.cs @@ -4,6 +4,7 @@ using Bit.Core.AdminConsole.Repositories; using Bit.Core.Billing.Caches; using Bit.Core.Repositories; using Bit.Core.Settings; +using Microsoft.Extensions.Logging; using NSubstitute; using Stripe; using Xunit; @@ -28,7 +29,13 @@ public class StripeEventServiceTests _providerRepository = Substitute.For(); _setupIntentCache = Substitute.For(); _stripeFacade = Substitute.For(); - _stripeEventService = new StripeEventService(globalSettings, _organizationRepository, _providerRepository, _setupIntentCache, _stripeFacade); + _stripeEventService = new StripeEventService( + globalSettings, + Substitute.For>(), + _organizationRepository, + _providerRepository, + _setupIntentCache, + _stripeFacade); } #region GetCharge diff --git a/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs b/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs index 182f09e163..1807050b31 100644 --- a/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs +++ b/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs @@ -1,5 +1,4 @@ -using Bit.Billing.Constants; -using Bit.Billing.Services; +using Bit.Billing.Services; using Bit.Billing.Services.Implementations; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; @@ -11,17 +10,15 @@ using Bit.Core.Billing.Pricing; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; +using Bit.Core.Test.Billing.Mocks; using Bit.Core.Test.Billing.Mocks.Plans; -using Microsoft.Extensions.Logging; using Newtonsoft.Json.Linq; using NSubstitute; using NSubstitute.ReturnsExtensions; -using Quartz; using Stripe; using Xunit; +using static Bit.Core.Billing.Constants.StripeConstants; using Event = Stripe.Event; -using PremiumPlan = Bit.Core.Billing.Pricing.Premium.Plan; -using PremiumPurchasable = Bit.Core.Billing.Pricing.Premium.Purchasable; namespace Bit.Billing.Test.Services; @@ -37,10 +34,8 @@ public class SubscriptionUpdatedHandlerTests private readonly IOrganizationEnableCommand _organizationEnableCommand; private readonly IOrganizationDisableCommand _organizationDisableCommand; private readonly IPricingClient _pricingClient; - private readonly IFeatureService _featureService; private readonly IProviderRepository _providerRepository; private readonly IProviderService _providerService; - private readonly IScheduler _scheduler; private readonly IPushNotificationAdapter _pushNotificationAdapter; private readonly SubscriptionUpdatedHandler _sut; @@ -54,19 +49,13 @@ public class SubscriptionUpdatedHandlerTests _userService = Substitute.For(); _providerService = Substitute.For(); _organizationRepository = Substitute.For(); - var schedulerFactory = Substitute.For(); _organizationEnableCommand = Substitute.For(); _organizationDisableCommand = Substitute.For(); _pricingClient = Substitute.For(); - _featureService = Substitute.For(); _providerRepository = Substitute.For(); _providerService = Substitute.For(); - var logger = Substitute.For>(); - _scheduler = Substitute.For(); _pushNotificationAdapter = Substitute.For(); - schedulerFactory.GetScheduler().Returns(_scheduler); - _sut = new SubscriptionUpdatedHandler( _stripeEventService, _stripeEventUtilityService, @@ -75,46 +64,66 @@ public class SubscriptionUpdatedHandlerTests _organizationSponsorshipRenewCommand, _userService, _organizationRepository, - schedulerFactory, _organizationEnableCommand, _organizationDisableCommand, _pricingClient, - _featureService, _providerRepository, _providerService, - logger, _pushNotificationAdapter); } [Fact] - public async Task HandleAsync_UnpaidOrganizationSubscription_DisablesOrganizationAndSchedulesCancellation() + public async Task HandleAsync_UnpaidOrganizationSubscription_DisablesOrganizationAndSetsCancellation() { // Arrange var organizationId = Guid.NewGuid(); var subscriptionId = "sub_123"; var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + + var previousSubscription = new Subscription + { + Id = subscriptionId, + Status = SubscriptionStatus.Active + }; + var subscription = new Subscription { Id = subscriptionId, - Status = StripeSubscriptionStatus.Unpaid, + Status = SubscriptionStatus.Unpaid, Items = new StripeList { Data = [ - new SubscriptionItem { CurrentPeriodEnd = currentPeriodEnd } + new SubscriptionItem + { + CurrentPeriodEnd = currentPeriodEnd, + Plan = new Plan { Id = "2023-enterprise-org-seat-annually" } + } ] }, Metadata = new Dictionary { { "organizationId", organizationId.ToString() } }, - LatestInvoice = new Invoice { BillingReason = "subscription_cycle" } + LatestInvoice = new Invoice { BillingReason = BillingReasons.SubscriptionCycle } }; - var parsedEvent = new Event { Data = new EventData() }; + var organization = new Organization { Id = organizationId, PlanType = PlanType.EnterpriseAnnually2023 }; + + var parsedEvent = new Event + { + Data = new EventData + { + Object = subscription, + PreviousAttributes = JObject.FromObject(previousSubscription) + } + }; _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(subscription); - _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) - .Returns(Tuple.Create(organizationId, null, null)); + _organizationRepository.GetByIdAsync(organizationId).Returns(organization); + + var plan = new Enterprise2023Plan(true); + _pricingClient.GetPlanOrThrow(organization.PlanType).Returns(plan); + _pricingClient.ListPlans().Returns(MockPlans.Plans); // Act await _sut.HandleAsync(parsedEvent); @@ -122,14 +131,21 @@ public class SubscriptionUpdatedHandlerTests // Assert await _organizationDisableCommand.Received(1) .DisableAsync(organizationId, currentPeriodEnd); - await _scheduler.Received(1).ScheduleJob( - Arg.Is(j => j.Key.Name == $"cancel-sub-{subscriptionId}"), - Arg.Is(t => t.Key.Name == $"cancel-trigger-{subscriptionId}")); + await _pushNotificationAdapter.Received(1) + .NotifyEnabledChangedAsync(organization); + await _stripeFacade.Received(1).UpdateSubscription( + subscriptionId, + Arg.Is(options => + options.CancelAt.HasValue && + options.CancelAt.Value <= DateTime.UtcNow.AddDays(7).AddMinutes(1) && + options.ProrationBehavior == ProrationBehavior.None && + options.CancellationDetails != null && + options.CancellationDetails.Comment != null)); } [Fact] public async Task - HandleAsync_UnpaidProviderSubscription_WithValidTransition_DisablesProviderAndSchedulesCancellation() + HandleAsync_UnpaidProviderSubscription_WithValidTransition_DisablesProviderAndSetsCancellation() { // Arrange var providerId = Guid.NewGuid(); @@ -138,14 +154,13 @@ public class SubscriptionUpdatedHandlerTests var previousSubscription = new Subscription { Id = subscriptionId, - Status = StripeSubscriptionStatus.Active, - Metadata = new Dictionary { ["providerId"] = providerId.ToString() } + Status = SubscriptionStatus.Active }; var currentSubscription = new Subscription { Id = subscriptionId, - Status = StripeSubscriptionStatus.Unpaid, + Status = SubscriptionStatus.Unpaid, Items = new StripeList { Data = @@ -154,14 +169,12 @@ public class SubscriptionUpdatedHandlerTests ] }, Metadata = new Dictionary { ["providerId"] = providerId.ToString() }, - LatestInvoice = new Invoice { BillingReason = "subscription_cycle" }, + LatestInvoice = new Invoice { BillingReason = BillingReasons.SubscriptionCycle }, TestClock = null }; var parsedEvent = new Event { - Id = "evt_test123", - Type = HandledStripeWebhook.SubscriptionUpdated, Data = new EventData { Object = currentSubscription, @@ -172,8 +185,6 @@ public class SubscriptionUpdatedHandlerTests var provider = new Provider { Id = providerId, Enabled = true }; _stripeEventService.GetSubscription(parsedEvent, true, Arg.Any>()).Returns(currentSubscription); - _stripeEventUtilityService.GetIdsFromMetadata(currentSubscription.Metadata) - .Returns(Tuple.Create(null, null, providerId)); _providerRepository.GetByIdAsync(providerId).Returns(provider); // Act @@ -188,16 +199,25 @@ public class SubscriptionUpdatedHandlerTests subscriptionId, Arg.Is(options => options.CancelAt.HasValue && - options.CancelAt.Value <= DateTime.UtcNow.AddDays(7).AddMinutes(1))); + options.CancelAt.Value <= DateTime.UtcNow.AddDays(7).AddMinutes(1) && + options.ProrationBehavior == ProrationBehavior.None && + options.CancellationDetails != null && + options.CancellationDetails.Comment != null)); } [Fact] - public async Task HandleAsync_UnpaidProviderSubscription_WithoutValidTransition_DisablesProviderOnly() + public async Task HandleAsync_UnpaidProviderSubscription_WithoutValidTransition_DoesNotDisableProvider() { // Arrange var providerId = Guid.NewGuid(); const string subscriptionId = "sub_123"; + var previousSubscription = new Subscription + { + Id = subscriptionId, + Status = SubscriptionStatus.Unpaid // No valid transition (already unpaid) + }; + var subscription = new Subscription { Id = subscriptionId, @@ -208,9 +228,9 @@ public class SubscriptionUpdatedHandlerTests new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) } ] }, - Status = StripeSubscriptionStatus.Unpaid, + Status = SubscriptionStatus.Unpaid, Metadata = new Dictionary { { "providerId", providerId.ToString() } }, - LatestInvoice = new Invoice { BillingReason = "subscription_cycle" } + LatestInvoice = new Invoice { BillingReason = BillingReasons.SubscriptionCycle } }; var provider = new Provider { Id = providerId, Name = "Test Provider", Enabled = true }; @@ -219,38 +239,40 @@ public class SubscriptionUpdatedHandlerTests { Data = new EventData { - PreviousAttributes = JObject.FromObject(new - { - status = "unpaid" // No valid transition - }) + Object = subscription, + PreviousAttributes = JObject.FromObject(previousSubscription) } }; _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(subscription); - _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) - .Returns(Tuple.Create(null, null, providerId)); - _providerRepository.GetByIdAsync(providerId) .Returns(provider); // Act await _sut.HandleAsync(parsedEvent); - // Assert - Assert.False(provider.Enabled); - await _providerService.Received(1).UpdateAsync(provider); + // Assert - No disable or cancellation since there was no valid status transition + Assert.True(provider.Enabled); + await _providerService.DidNotReceive().UpdateAsync(Arg.Any()); await _stripeFacade.DidNotReceive().UpdateSubscription(Arg.Any(), Arg.Any()); } [Fact] - public async Task HandleAsync_UnpaidProviderSubscription_WithNoPreviousAttributes_DisablesProviderOnly() + public async Task HandleAsync_UnpaidProviderSubscription_WithNonMatchingPreviousStatus_DoesNotDisableProvider() { // Arrange var providerId = Guid.NewGuid(); const string subscriptionId = "sub_123"; + // Previous status is Canceled, which is not a valid transition source (Trialing/Active/PastDue) + var previousSubscription = new Subscription + { + Id = subscriptionId, + Status = SubscriptionStatus.Canceled + }; + var subscription = new Subscription { Id = subscriptionId, @@ -261,45 +283,56 @@ public class SubscriptionUpdatedHandlerTests new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) } ] }, - Status = StripeSubscriptionStatus.Unpaid, + Status = SubscriptionStatus.Unpaid, Metadata = new Dictionary { { "providerId", providerId.ToString() } }, - LatestInvoice = new Invoice { BillingReason = "subscription_cycle" } + LatestInvoice = new Invoice { BillingReason = BillingReasons.SubscriptionCycle } }; var provider = new Provider { Id = providerId, Name = "Test Provider", Enabled = true }; - var parsedEvent = new Event { Data = new EventData { PreviousAttributes = null } }; + var parsedEvent = new Event + { + Data = new EventData + { + Object = subscription, + PreviousAttributes = JObject.FromObject(previousSubscription) + } + }; _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(subscription); - _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) - .Returns(Tuple.Create(null, null, providerId)); - _providerRepository.GetByIdAsync(providerId) .Returns(provider); // Act await _sut.HandleAsync(parsedEvent); - // Assert - Assert.False(provider.Enabled); - await _providerService.Received(1).UpdateAsync(provider); + // Assert - No disable or cancellation since the previous status (Canceled) is not a valid transition source + Assert.True(provider.Enabled); + await _providerService.DidNotReceive().UpdateAsync(Arg.Any()); await _stripeFacade.DidNotReceive().UpdateSubscription(Arg.Any(), Arg.Any()); } [Fact] - public async Task HandleAsync_UnpaidProviderSubscription_WithIncompleteExpiredStatus_DisablesProvider() + public async Task HandleAsync_ProviderSubscription_WithIncompleteExpiredStatus_DoesNotDisableProvider() { // Arrange var providerId = Guid.NewGuid(); var subscriptionId = "sub_123"; var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + // Previous status that doesn't trigger enable/disable logic + var previousSubscription = new Subscription + { + Id = subscriptionId, + Status = SubscriptionStatus.Incomplete + }; + var subscription = new Subscription { Id = subscriptionId, - Status = StripeSubscriptionStatus.IncompleteExpired, + Status = SubscriptionStatus.IncompleteExpired, Items = new StripeList { Data = @@ -313,38 +346,48 @@ public class SubscriptionUpdatedHandlerTests var provider = new Provider { Id = providerId, Name = "Test Provider", Enabled = true }; - var parsedEvent = new Event { Data = new EventData() }; + var parsedEvent = new Event + { + Data = new EventData + { + Object = subscription, + PreviousAttributes = JObject.FromObject(previousSubscription) + } + }; _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(subscription); - _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) - .Returns(Tuple.Create(null, null, providerId)); - _providerRepository.GetByIdAsync(providerId) .Returns(provider); // Act await _sut.HandleAsync(parsedEvent); - // Assert - Assert.False(provider.Enabled); - await _providerService.Received(1).UpdateAsync(provider); + // Assert - IncompleteExpired status is not handled by the new logic + Assert.True(provider.Enabled); + await _providerService.DidNotReceive().UpdateAsync(Arg.Any()); await _stripeFacade.DidNotReceive().UpdateSubscription(Arg.Any(), Arg.Any()); } [Fact] - public async Task HandleAsync_UnpaidProviderSubscription_WhenProviderNotFound_DoesNothing() + public async Task HandleAsync_UnpaidProviderSubscription_WhenProviderNotFound_StillSetsCancellation() { // Arrange var providerId = Guid.NewGuid(); var subscriptionId = "sub_123"; var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + var previousSubscription = new Subscription + { + Id = subscriptionId, + Status = SubscriptionStatus.Active + }; + var subscription = new Subscription { Id = subscriptionId, - Status = StripeSubscriptionStatus.Unpaid, + Status = SubscriptionStatus.Unpaid, Items = new StripeList { Data = @@ -353,194 +396,216 @@ public class SubscriptionUpdatedHandlerTests ] }, Metadata = new Dictionary { { "providerId", providerId.ToString() } }, - LatestInvoice = new Invoice { BillingReason = "subscription_cycle" } + LatestInvoice = new Invoice { BillingReason = BillingReasons.SubscriptionCycle } }; - var parsedEvent = new Event { Data = new EventData() }; + var parsedEvent = new Event + { + Data = new EventData + { + Object = subscription, + PreviousAttributes = JObject.FromObject(previousSubscription) + } + }; _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(subscription); - _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) - .Returns(Tuple.Create(null, null, providerId)); - _providerRepository.GetByIdAsync(providerId) .Returns((Provider)null); // Act await _sut.HandleAsync(parsedEvent); - // Assert + // Assert - Provider not updated (since not found), but cancellation is still set await _providerService.DidNotReceive().UpdateAsync(Arg.Any()); + await _stripeFacade.Received(1).UpdateSubscription( + subscriptionId, + Arg.Is(options => + options.CancelAt.HasValue && + options.CancelAt.Value <= DateTime.UtcNow.AddDays(7).AddMinutes(1) && + options.ProrationBehavior == ProrationBehavior.None && + options.CancellationDetails != null && + options.CancellationDetails.Comment != null)); + } + + [Fact] + public async Task HandleAsync_UnpaidUserSubscription_DisablesPremiumAndSetsCancellation() + { + // Arrange + var userId = Guid.NewGuid(); + var subscriptionId = "sub_123"; + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + + var previousSubscription = new Subscription + { + Id = subscriptionId, + Status = SubscriptionStatus.Active + }; + + var subscription = new Subscription + { + Id = subscriptionId, + Status = SubscriptionStatus.Unpaid, + Metadata = new Dictionary { { "userId", userId.ToString() } }, + Items = new StripeList + { + Data = + [ + new SubscriptionItem { CurrentPeriodEnd = currentPeriodEnd } + ] + }, + LatestInvoice = new Invoice { BillingReason = BillingReasons.SubscriptionCycle } + }; + + var parsedEvent = new Event + { + Data = new EventData + { + Object = subscription, + PreviousAttributes = JObject.FromObject(previousSubscription) + } + }; + + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) + .Returns(subscription); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _userService.Received(1) + .DisablePremiumAsync(userId, currentPeriodEnd); + await _stripeFacade.Received(1).UpdateSubscription( + subscriptionId, + Arg.Is(options => + options.CancelAt.HasValue && + options.CancelAt.Value <= DateTime.UtcNow.AddDays(7).AddMinutes(1) && + options.ProrationBehavior == ProrationBehavior.None && + options.CancellationDetails != null && + options.CancellationDetails.Comment != null)); + } + + [Fact] + public async Task HandleAsync_IncompleteExpiredUserSubscription_OnlyUpdatesExpiration() + { + // Arrange + var userId = Guid.NewGuid(); + var subscriptionId = "sub_123"; + var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + + // Previous status that doesn't trigger enable/disable logic + var previousSubscription = new Subscription + { + Id = subscriptionId, + Status = SubscriptionStatus.Incomplete + }; + + var subscription = new Subscription + { + Id = subscriptionId, + Status = SubscriptionStatus.IncompleteExpired, + Metadata = new Dictionary { { "userId", userId.ToString() } }, + Items = new StripeList + { + Data = + [ + new SubscriptionItem { CurrentPeriodEnd = currentPeriodEnd } + ] + } + }; + + var parsedEvent = new Event + { + Data = new EventData + { + Object = subscription, + PreviousAttributes = JObject.FromObject(previousSubscription) + } + }; + + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) + .Returns(subscription); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert - IncompleteExpired is no longer handled specially, only expiration is updated + await _userService.DidNotReceive().DisablePremiumAsync(Arg.Any(), Arg.Any()); + await _userService.Received(1).UpdatePremiumExpirationAsync(userId, currentPeriodEnd); await _stripeFacade.DidNotReceive().UpdateSubscription(Arg.Any(), Arg.Any()); } - [Fact] - public async Task HandleAsync_UnpaidUserSubscription_DisablesPremiumAndCancelsSubscription() - { - // Arrange - var userId = Guid.NewGuid(); - var subscriptionId = "sub_123"; - var currentPeriodEnd = DateTime.UtcNow.AddDays(30); - var subscription = new Subscription - { - Id = subscriptionId, - Status = StripeSubscriptionStatus.Unpaid, - Metadata = new Dictionary { { "userId", userId.ToString() } }, - Items = new StripeList - { - Data = - [ - new SubscriptionItem - { - CurrentPeriodEnd = currentPeriodEnd, - Price = new Price { Id = IStripeEventUtilityService.PremiumPlanId } - } - ] - } - }; - - var parsedEvent = new Event { Data = new EventData() }; - - var premiumPlan = new PremiumPlan - { - Name = "Premium", - Available = true, - LegacyYear = null, - Seat = new PremiumPurchasable { Price = 10M, StripePriceId = IStripeEventUtilityService.PremiumPlanId }, - Storage = new PremiumPurchasable { Price = 4M, StripePriceId = "storage-plan-personal" } - }; - _pricingClient.ListPremiumPlans().Returns(new List { premiumPlan }); - - _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) - .Returns(subscription); - - _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) - .Returns(Tuple.Create(null, userId, null)); - - _stripeFacade.ListInvoices(Arg.Any()) - .Returns(new StripeList { Data = new List() }); - - // Act - await _sut.HandleAsync(parsedEvent); - - // Assert - await _userService.Received(1) - .DisablePremiumAsync(userId, currentPeriodEnd); - await _stripeFacade.Received(1) - .CancelSubscription(subscriptionId, Arg.Any()); - await _stripeFacade.Received(1) - .ListInvoices(Arg.Is(o => - o.Status == StripeInvoiceStatus.Open && o.Subscription == subscriptionId)); - } - - [Fact] - public async Task HandleAsync_IncompleteExpiredUserSubscription_DisablesPremiumAndCancelsSubscription() - { - // Arrange - var userId = Guid.NewGuid(); - var subscriptionId = "sub_123"; - var currentPeriodEnd = DateTime.UtcNow.AddDays(30); - var subscription = new Subscription - { - Id = subscriptionId, - Status = StripeSubscriptionStatus.IncompleteExpired, - Metadata = new Dictionary { { "userId", userId.ToString() } }, - Items = new StripeList - { - Data = - [ - new SubscriptionItem - { - CurrentPeriodEnd = currentPeriodEnd, - Price = new Price { Id = IStripeEventUtilityService.PremiumPlanId } - } - ] - } - }; - - var parsedEvent = new Event { Data = new EventData() }; - - var premiumPlan = new PremiumPlan - { - Name = "Premium", - Available = true, - LegacyYear = null, - Seat = new PremiumPurchasable { Price = 10M, StripePriceId = IStripeEventUtilityService.PremiumPlanId }, - Storage = new PremiumPurchasable { Price = 4M, StripePriceId = "storage-plan-personal" } - }; - _pricingClient.ListPremiumPlans().Returns(new List { premiumPlan }); - - _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) - .Returns(subscription); - - _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) - .Returns(Tuple.Create(null, userId, null)); - - _stripeFacade.ListInvoices(Arg.Any()) - .Returns(new StripeList { Data = new List() }); - - // Act - await _sut.HandleAsync(parsedEvent); - - // Assert - await _userService.Received(1) - .DisablePremiumAsync(userId, currentPeriodEnd); - await _stripeFacade.Received(1) - .CancelSubscription(subscriptionId, Arg.Any()); - await _stripeFacade.Received(1) - .ListInvoices(Arg.Is(o => - o.Status == StripeInvoiceStatus.Open && o.Subscription == subscriptionId)); - } - [Fact] public async Task HandleAsync_ActiveOrganizationSubscription_EnablesOrganizationAndUpdatesExpiration() { // Arrange var organizationId = Guid.NewGuid(); + var subscriptionId = "sub_123"; var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + + var previousSubscription = new Subscription + { + Id = subscriptionId, + Status = SubscriptionStatus.Unpaid + }; + var subscription = new Subscription { - Status = StripeSubscriptionStatus.Active, + Id = subscriptionId, + Status = SubscriptionStatus.Active, Items = new StripeList { Data = [ - new SubscriptionItem { CurrentPeriodEnd = currentPeriodEnd } + new SubscriptionItem + { + CurrentPeriodEnd = currentPeriodEnd, + Plan = new Plan { Id = "2023-enterprise-org-seat-annually" } + } ] }, - Metadata = new Dictionary { { "organizationId", organizationId.ToString() } } + Metadata = new Dictionary { { "organizationId", organizationId.ToString() } }, + LatestInvoice = new Invoice { BillingReason = BillingReasons.SubscriptionCycle } }; var organization = new Organization { Id = organizationId, PlanType = PlanType.EnterpriseAnnually2023 }; - var parsedEvent = new Event { Data = new EventData() }; + var parsedEvent = new Event + { + Data = new EventData + { + Object = subscription, + PreviousAttributes = JObject.FromObject(previousSubscription) + } + }; _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(subscription); - _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) - .Returns(Tuple.Create(organizationId, null, null)); - _organizationRepository.GetByIdAsync(organizationId) .Returns(organization); - _stripeFacade.ListInvoices(Arg.Any()) - .Returns(new StripeList { Data = [new Invoice { Id = "inv_123" }] }); - var plan = new Enterprise2023Plan(true); _pricingClient.GetPlanOrThrow(organization.PlanType) .Returns(plan); + _pricingClient.ListPlans() + .Returns(MockPlans.Plans); // Act await _sut.HandleAsync(parsedEvent); // Assert await _organizationEnableCommand.Received(1) - .EnableAsync(organizationId); + .EnableAsync(organizationId, currentPeriodEnd); await _organizationService.Received(1) .UpdateExpirationDateAsync(organizationId, currentPeriodEnd); await _pushNotificationAdapter.Received(1) .NotifyEnabledChangedAsync(organization); + await _stripeFacade.Received(1).UpdateSubscription( + subscriptionId, + Arg.Is(options => + options.CancelAtPeriodEnd == false && + options.ProrationBehavior == ProrationBehavior.None)); } [Fact] @@ -548,10 +613,19 @@ public class SubscriptionUpdatedHandlerTests { // Arrange var userId = Guid.NewGuid(); + var subscriptionId = "sub_123"; var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + + var previousSubscription = new Subscription + { + Id = subscriptionId, + Status = SubscriptionStatus.Unpaid + }; + var subscription = new Subscription { - Status = StripeSubscriptionStatus.Active, + Id = subscriptionId, + Status = SubscriptionStatus.Active, Items = new StripeList { Data = @@ -559,17 +633,22 @@ public class SubscriptionUpdatedHandlerTests new SubscriptionItem { CurrentPeriodEnd = currentPeriodEnd } ] }, - Metadata = new Dictionary { { "userId", userId.ToString() } } + Metadata = new Dictionary { { "userId", userId.ToString() } }, + LatestInvoice = new Invoice { BillingReason = BillingReasons.SubscriptionCycle } }; - var parsedEvent = new Event { Data = new EventData() }; + var parsedEvent = new Event + { + Data = new EventData + { + Object = subscription, + PreviousAttributes = JObject.FromObject(previousSubscription) + } + }; _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(subscription); - _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) - .Returns(Tuple.Create(null, userId, null)); - // Act await _sut.HandleAsync(parsedEvent); @@ -578,6 +657,11 @@ public class SubscriptionUpdatedHandlerTests .EnablePremiumAsync(userId, currentPeriodEnd); await _userService.Received(1) .UpdatePremiumExpirationAsync(userId, currentPeriodEnd); + await _stripeFacade.Received(1).UpdateSubscription( + subscriptionId, + Arg.Is(options => + options.CancelAtPeriodEnd == false && + options.ProrationBehavior == ProrationBehavior.None)); } [Fact] @@ -585,10 +669,20 @@ public class SubscriptionUpdatedHandlerTests { // Arrange var organizationId = Guid.NewGuid(); + var subscriptionId = "sub_123"; var currentPeriodEnd = DateTime.UtcNow.AddDays(30); + + // Use a previous status that won't trigger enable/disable logic + var previousSubscription = new Subscription + { + Id = subscriptionId, + Status = SubscriptionStatus.Active + }; + var subscription = new Subscription { - Status = StripeSubscriptionStatus.Active, + Id = subscriptionId, + Status = SubscriptionStatus.Active, Items = new StripeList { Data = @@ -599,14 +693,18 @@ public class SubscriptionUpdatedHandlerTests Metadata = new Dictionary { { "organizationId", organizationId.ToString() } } }; - var parsedEvent = new Event { Data = new EventData() }; + var parsedEvent = new Event + { + Data = new EventData + { + Object = subscription, + PreviousAttributes = JObject.FromObject(previousSubscription) + } + }; _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(subscription); - _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) - .Returns(Tuple.Create(organizationId, null, null)); - _stripeEventUtilityService.IsSponsoredSubscription(subscription) .Returns(true); @@ -627,7 +725,7 @@ public class SubscriptionUpdatedHandlerTests var subscription = new Subscription { Id = "sub_123", - Status = StripeSubscriptionStatus.Active, + Status = SubscriptionStatus.Active, CustomerId = "cus_123", Items = new StripeList { @@ -636,7 +734,7 @@ public class SubscriptionUpdatedHandlerTests new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(10), - Plan = new Stripe.Plan { Id = "2023-enterprise-org-seat-annually" } + Plan = new Plan { Id = "2023-enterprise-org-seat-annually" } } ] }, @@ -654,6 +752,8 @@ public class SubscriptionUpdatedHandlerTests var plan = new Enterprise2023Plan(true); _pricingClient.GetPlanOrThrow(organization.PlanType) .Returns(plan); + _pricingClient.ListPlans() + .Returns(MockPlans.Plans); var parsedEvent = new Event { @@ -670,7 +770,7 @@ public class SubscriptionUpdatedHandlerTests { Data = [ - new SubscriptionItem { Plan = new Stripe.Plan { Id = "secrets-manager-enterprise-seat-annually" } } + new SubscriptionItem { Plan = new Plan { Id = "secrets-manager-enterprise-seat-annually" } } ] } }) @@ -680,9 +780,6 @@ public class SubscriptionUpdatedHandlerTests _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(subscription); - _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) - .Returns(Tuple.Create(organizationId, null, null)); - _organizationRepository.GetByIdAsync(organizationId) .Returns(organization); @@ -693,11 +790,94 @@ public class SubscriptionUpdatedHandlerTests await _stripeFacade.Received(1).DeleteCustomerDiscount(subscription.CustomerId); await _stripeFacade.Received(1).DeleteSubscriptionDiscount(subscription.Id); } + [Fact] + public async Task + HandleAsync_WhenUpgradingPlan_AndPreviousPlanHasSecretsManagerTrial_AndCurrentPlanHasSecretsManagerTrial_DoesNotRemovePasswordManagerCoupon() + { + // Arrange + var organizationId = Guid.NewGuid(); + var subscription = new Subscription + { + Id = "sub_123", + Status = SubscriptionStatus.Active, + CustomerId = "cus_123", + Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = DateTime.UtcNow.AddDays(10), + Plan = new Plan { Id = "2023-enterprise-org-seat-annually" } + }, + new SubscriptionItem + { + CurrentPeriodEnd = DateTime.UtcNow.AddDays(10), + Plan = new Plan { Id = "secrets-manager-enterprise-seat-annually" } + } + ] + }, + Customer = new Customer + { + Balance = 0, + Discount = new Discount { Coupon = new Coupon { Id = "sm-standalone" } } + }, + Discounts = [new Discount { Coupon = new Coupon { Id = "sm-standalone" } }], + Metadata = new Dictionary { { "organizationId", organizationId.ToString() } } + }; + + // Note: The organization plan is still the previous plan because the subscription is updated before the organization is updated + var organization = new Organization { Id = organizationId, PlanType = PlanType.TeamsAnnually2023 }; + + var plan = new Teams2023Plan(true); + _pricingClient.GetPlanOrThrow(organization.PlanType) + .Returns(plan); + _pricingClient.ListPlans() + .Returns(MockPlans.Plans); + + var parsedEvent = new Event + { + Data = new EventData + { + Object = subscription, + PreviousAttributes = JObject.FromObject(new + { + items = new + { + data = new[] + { + new { plan = new { id = "secrets-manager-teams-seat-annually" } }, + } + }, + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Plan = new Stripe.Plan { Id = "secrets-manager-teams-seat-annually" } }, + ] + } + }) + } + }; + + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) + .Returns(subscription); + + _organizationRepository.GetByIdAsync(organizationId) + .Returns(organization); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _stripeFacade.DidNotReceive().DeleteCustomerDiscount(subscription.CustomerId); + await _stripeFacade.DidNotReceive().DeleteSubscriptionDiscount(subscription.Id); + } [Theory] - [MemberData(nameof(GetNonActiveSubscriptions))] + [MemberData(nameof(GetValidTransitionToActiveSubscriptions))] public async Task - HandleAsync_ActiveProviderSubscriptionEvent_AndPreviousSubscriptionStatusWasNonActive_EnableProviderAndUpdateSubscription( + HandleAsync_ActiveProviderSubscriptionEvent_AndPreviousSubscriptionStatusWasIncompleteOrUnpaid_EnableProviderAndUpdateSubscription( Subscription previousSubscription) { // Arrange @@ -708,10 +888,6 @@ public class SubscriptionUpdatedHandlerTests .GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(newSubscription); - _stripeEventUtilityService - .GetIdsFromMetadata(Arg.Any>()) - .Returns(Tuple.Create(null, null, providerId)); - _providerRepository .GetByIdAsync(Arg.Any()) .Returns(provider); @@ -726,9 +902,6 @@ public class SubscriptionUpdatedHandlerTests await _stripeEventService .Received(1) .GetSubscription(parsedEvent, true, Arg.Any>()); - _stripeEventUtilityService - .Received(1) - .GetIdsFromMetadata(newSubscription.Metadata); await _providerRepository .Received(1) .GetByIdAsync(providerId); @@ -738,24 +911,23 @@ public class SubscriptionUpdatedHandlerTests await _stripeFacade .Received(1) .UpdateSubscription(newSubscription.Id, - Arg.Is(options => options.CancelAtPeriodEnd == false)); + Arg.Is(options => + options.CancelAtPeriodEnd == false && + options.ProrationBehavior == ProrationBehavior.None)); } [Fact] public async Task - HandleAsync_ActiveProviderSubscriptionEvent_AndPreviousSubscriptionStatusWasCanceled_EnableProvider() + HandleAsync_ActiveProviderSubscriptionEvent_AndPreviousSubscriptionStatusWasCanceled_DoesNotEnableProvider() { // Arrange - var previousSubscription = new Subscription { Id = "sub_123", Status = StripeSubscriptionStatus.Canceled }; + var previousSubscription = new Subscription { Id = "sub_123", Status = SubscriptionStatus.Canceled }; var (providerId, newSubscription, provider, parsedEvent) = CreateProviderTestInputsForUpdatedActiveSubscriptionStatus(previousSubscription); _stripeEventService .GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(newSubscription); - _stripeEventUtilityService - .GetIdsFromMetadata(Arg.Any>()) - .Returns(Tuple.Create(null, null, providerId)); _providerRepository .GetByIdAsync(Arg.Any()) .Returns(provider); @@ -763,17 +935,14 @@ public class SubscriptionUpdatedHandlerTests // Act await _sut.HandleAsync(parsedEvent); - // Assert + // Assert - Canceled is not a valid transition source for SubscriptionBecameActive await _stripeEventService .Received(1) .GetSubscription(parsedEvent, true, Arg.Any>()); - _stripeEventUtilityService - .Received(1) - .GetIdsFromMetadata(newSubscription.Metadata); - await _providerRepository.Received(1).GetByIdAsync(providerId); + await _providerRepository.DidNotReceive().GetByIdAsync(Arg.Any()); await _providerService - .Received(1) - .UpdateAsync(Arg.Is(p => p.Id == providerId && p.Enabled == true)); + .DidNotReceive() + .UpdateAsync(Arg.Any()); await _stripeFacade .DidNotReceiveWithAnyArgs() .UpdateSubscription(Arg.Any()); @@ -781,19 +950,16 @@ public class SubscriptionUpdatedHandlerTests [Fact] public async Task - HandleAsync_ActiveProviderSubscriptionEvent_AndPreviousSubscriptionStatusWasAlreadyActive_EnableProvider() + HandleAsync_ActiveProviderSubscriptionEvent_AndPreviousSubscriptionStatusWasAlreadyActive_DoesNotEnableProvider() { // Arrange - var previousSubscription = new Subscription { Id = "sub_123", Status = StripeSubscriptionStatus.Active }; + var previousSubscription = new Subscription { Id = "sub_123", Status = SubscriptionStatus.Active }; var (providerId, newSubscription, provider, parsedEvent) = CreateProviderTestInputsForUpdatedActiveSubscriptionStatus(previousSubscription); _stripeEventService .GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(newSubscription); - _stripeEventUtilityService - .GetIdsFromMetadata(Arg.Any>()) - .Returns(Tuple.Create(null, null, providerId)); _providerRepository .GetByIdAsync(Arg.Any()) .Returns(provider); @@ -801,17 +967,14 @@ public class SubscriptionUpdatedHandlerTests // Act await _sut.HandleAsync(parsedEvent); - // Assert + // Assert - Already Active is not a valid transition for SubscriptionBecameActive await _stripeEventService .Received(1) .GetSubscription(parsedEvent, true, Arg.Any>()); - _stripeEventUtilityService - .Received(1) - .GetIdsFromMetadata(newSubscription.Metadata); - await _providerRepository.Received(1).GetByIdAsync(providerId); + await _providerRepository.DidNotReceive().GetByIdAsync(Arg.Any()); await _providerService - .Received(1) - .UpdateAsync(Arg.Is(p => p.Id == providerId && p.Enabled == true)); + .DidNotReceive() + .UpdateAsync(Arg.Any()); await _stripeFacade .DidNotReceiveWithAnyArgs() .UpdateSubscription(Arg.Any()); @@ -819,19 +982,16 @@ public class SubscriptionUpdatedHandlerTests [Fact] public async Task - HandleAsync_ActiveProviderSubscriptionEvent_AndPreviousSubscriptionStatusWasTrailing_EnableProvider() + HandleAsync_ActiveProviderSubscriptionEvent_AndPreviousSubscriptionStatusWasTrialing_DoesNotEnableProvider() { // Arrange - var previousSubscription = new Subscription { Id = "sub_123", Status = StripeSubscriptionStatus.Trialing }; + var previousSubscription = new Subscription { Id = "sub_123", Status = SubscriptionStatus.Trialing }; var (providerId, newSubscription, provider, parsedEvent) = CreateProviderTestInputsForUpdatedActiveSubscriptionStatus(previousSubscription); _stripeEventService .GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(newSubscription); - _stripeEventUtilityService - .GetIdsFromMetadata(Arg.Any>()) - .Returns(Tuple.Create(null, null, providerId)); _providerRepository .GetByIdAsync(Arg.Any()) .Returns(provider); @@ -839,17 +999,14 @@ public class SubscriptionUpdatedHandlerTests // Act await _sut.HandleAsync(parsedEvent); - // Assert + // Assert - Trialing is not a valid transition source for SubscriptionBecameActive await _stripeEventService .Received(1) .GetSubscription(parsedEvent, true, Arg.Any>()); - _stripeEventUtilityService - .Received(1) - .GetIdsFromMetadata(newSubscription.Metadata); - await _providerRepository.Received(1).GetByIdAsync(providerId); + await _providerRepository.DidNotReceive().GetByIdAsync(Arg.Any()); await _providerService - .Received(1) - .UpdateAsync(Arg.Is(p => p.Id == providerId && p.Enabled == true)); + .DidNotReceive() + .UpdateAsync(Arg.Any()); await _stripeFacade .DidNotReceiveWithAnyArgs() .UpdateSubscription(Arg.Any()); @@ -857,20 +1014,16 @@ public class SubscriptionUpdatedHandlerTests [Fact] public async Task - HandleAsync_ActiveProviderSubscriptionEvent_AndPreviousSubscriptionStatusWasPastDue_EnableProvider() + HandleAsync_ActiveProviderSubscriptionEvent_AndPreviousSubscriptionStatusWasPastDue_DoesNotEnableProvider() { // Arrange - var previousSubscription = new Subscription { Id = "sub_123", Status = StripeSubscriptionStatus.PastDue }; + var previousSubscription = new Subscription { Id = "sub_123", Status = SubscriptionStatus.PastDue }; var (providerId, newSubscription, provider, parsedEvent) = CreateProviderTestInputsForUpdatedActiveSubscriptionStatus(previousSubscription); - _stripeEventService .GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(newSubscription); - _stripeEventUtilityService - .GetIdsFromMetadata(Arg.Any>()) - .Returns(Tuple.Create(null, null, providerId)); _providerRepository .GetByIdAsync(Arg.Any()) .Returns(provider); @@ -878,19 +1031,14 @@ public class SubscriptionUpdatedHandlerTests // Act await _sut.HandleAsync(parsedEvent); - // Assert + // Assert - PastDue is not a valid transition source for SubscriptionBecameActive await _stripeEventService .Received(1) .GetSubscription(parsedEvent, true, Arg.Any>()); - _stripeEventUtilityService - .Received(1) - .GetIdsFromMetadata(newSubscription.Metadata); - await _providerRepository - .Received(1) - .GetByIdAsync(Arg.Any()); + await _providerRepository.DidNotReceive().GetByIdAsync(Arg.Any()); await _providerService - .Received(1) - .UpdateAsync(Arg.Is(p => p.Id == providerId && p.Enabled == true)); + .DidNotReceive() + .UpdateAsync(Arg.Any()); await _stripeFacade .DidNotReceiveWithAnyArgs() .UpdateSubscription(Arg.Any()); @@ -900,16 +1048,13 @@ public class SubscriptionUpdatedHandlerTests public async Task HandleAsync_ActiveProviderSubscriptionEvent_AndProviderDoesNotExist_NoChanges() { // Arrange - var previousSubscription = new Subscription { Id = "sub_123", Status = StripeSubscriptionStatus.Unpaid }; + var previousSubscription = new Subscription { Id = "sub_123", Status = SubscriptionStatus.Unpaid }; var (providerId, newSubscription, _, parsedEvent) = CreateProviderTestInputsForUpdatedActiveSubscriptionStatus(previousSubscription); _stripeEventService .GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(newSubscription); - _stripeEventUtilityService - .GetIdsFromMetadata(Arg.Any>()) - .Returns(Tuple.Create(null, null, providerId)); _providerRepository .GetByIdAsync(Arg.Any()) .ReturnsNull(); @@ -921,9 +1066,6 @@ public class SubscriptionUpdatedHandlerTests await _stripeEventService .Received(1) .GetSubscription(parsedEvent, true, Arg.Any>()); - _stripeEventUtilityService - .Received(1) - .GetIdsFromMetadata(newSubscription.Metadata); await _providerRepository .Received(1) .GetByIdAsync(providerId); @@ -936,18 +1078,16 @@ public class SubscriptionUpdatedHandlerTests } [Fact] - public async Task HandleAsync_ActiveProviderSubscriptionEvent_WithNoPreviousAttributes_EnableProvider() + public async Task HandleAsync_ActiveProviderSubscriptionEvent_WithNonMatchingPreviousStatus_DoesNotEnableProvider() { - // Arrange + // Arrange - Using a previous status (Canceled) that doesn't trigger SubscriptionBecameActive + var previousSubscription = new Subscription { Id = "sub_123", Status = SubscriptionStatus.Canceled }; var (providerId, newSubscription, provider, parsedEvent) = - CreateProviderTestInputsForUpdatedActiveSubscriptionStatus(null); + CreateProviderTestInputsForUpdatedActiveSubscriptionStatus(previousSubscription); _stripeEventService .GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(newSubscription); - _stripeEventUtilityService - .GetIdsFromMetadata(Arg.Any>()) - .Returns(Tuple.Create(null, null, providerId)); _providerRepository .GetByIdAsync(Arg.Any()) .Returns(provider); @@ -955,19 +1095,14 @@ public class SubscriptionUpdatedHandlerTests // Act await _sut.HandleAsync(parsedEvent); - // Assert + // Assert - Canceled is not a valid transition source, so no enable logic is triggered await _stripeEventService .Received(1) .GetSubscription(parsedEvent, true, Arg.Any>()); - _stripeEventUtilityService - .Received(1) - .GetIdsFromMetadata(newSubscription.Metadata); - await _providerRepository - .Received(1) - .GetByIdAsync(Arg.Any()); + await _providerRepository.DidNotReceive().GetByIdAsync(Arg.Any()); await _providerService - .Received(1) - .UpdateAsync(Arg.Is(p => p.Id == providerId && p.Enabled == true)); + .DidNotReceive() + .UpdateAsync(Arg.Any()); await _stripeFacade .DidNotReceive() .UpdateSubscription(Arg.Any()); @@ -987,8 +1122,9 @@ public class SubscriptionUpdatedHandlerTests new SubscriptionItem { CurrentPeriodEnd = DateTime.UtcNow.AddDays(30) } ] }, - Status = StripeSubscriptionStatus.Active, - Metadata = new Dictionary { { "providerId", providerId.ToString() } } + Status = SubscriptionStatus.Active, + Metadata = new Dictionary { { "providerId", providerId.ToString() } }, + LatestInvoice = new Invoice { BillingReason = BillingReasons.SubscriptionCycle } }; var provider = new Provider { Id = providerId, Enabled = false }; @@ -1005,144 +1141,63 @@ public class SubscriptionUpdatedHandlerTests } [Fact] - public async Task HandleAsync_IncompleteUserSubscriptionWithOpenInvoice_CancelsSubscriptionAndDisablesPremium() + public async Task HandleAsync_IncompleteUserSubscription_OnlyUpdatesExpiration() { // Arrange var userId = Guid.NewGuid(); var subscriptionId = "sub_123"; var currentPeriodEnd = DateTime.UtcNow.AddDays(30); - var openInvoice = new Invoice + + // Previous status that doesn't trigger enable/disable logic (already was incomplete) + var previousSubscription = new Subscription { - Id = "inv_123", - Status = StripeInvoiceStatus.Open + Id = subscriptionId, + Status = SubscriptionStatus.Incomplete }; + var subscription = new Subscription { Id = subscriptionId, - Status = StripeSubscriptionStatus.Incomplete, + Status = SubscriptionStatus.Incomplete, Metadata = new Dictionary { { "userId", userId.ToString() } }, - LatestInvoice = openInvoice, + LatestInvoice = new Invoice { Status = "open" }, Items = new StripeList { Data = [ - new SubscriptionItem - { - CurrentPeriodEnd = currentPeriodEnd, - Price = new Price { Id = IStripeEventUtilityService.PremiumPlanId } - } + new SubscriptionItem { CurrentPeriodEnd = currentPeriodEnd } ] } }; - var parsedEvent = new Event { Data = new EventData() }; - - var premiumPlan = new PremiumPlan + var parsedEvent = new Event { - Name = "Premium", - Available = true, - LegacyYear = null, - Seat = new PremiumPurchasable { Price = 10M, StripePriceId = IStripeEventUtilityService.PremiumPlanId }, - Storage = new PremiumPurchasable { Price = 4M, StripePriceId = "storage-plan-personal" } - }; - _pricingClient.ListPremiumPlans().Returns(new List { premiumPlan }); - - _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) - .Returns(subscription); - - _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) - .Returns(Tuple.Create(null, userId, null)); - - _stripeFacade.ListInvoices(Arg.Any()) - .Returns(new StripeList { Data = new List { openInvoice } }); - - // Act - await _sut.HandleAsync(parsedEvent); - - // Assert - await _userService.Received(1) - .DisablePremiumAsync(userId, currentPeriodEnd); - await _stripeFacade.Received(1) - .CancelSubscription(subscriptionId, Arg.Any()); - await _stripeFacade.Received(1) - .ListInvoices(Arg.Is(o => - o.Status == StripeInvoiceStatus.Open && o.Subscription == subscriptionId)); - await _stripeFacade.Received(1) - .VoidInvoice(openInvoice.Id); - } - - [Fact] - public async Task HandleAsync_IncompleteUserSubscriptionWithoutOpenInvoice_DoesNotCancelSubscription() - { - // Arrange - var userId = Guid.NewGuid(); - var subscriptionId = "sub_123"; - var currentPeriodEnd = DateTime.UtcNow.AddDays(30); - var paidInvoice = new Invoice - { - Id = "inv_123", - Status = StripeInvoiceStatus.Paid - }; - var subscription = new Subscription - { - Id = subscriptionId, - Status = StripeSubscriptionStatus.Incomplete, - Metadata = new Dictionary { { "userId", userId.ToString() } }, - LatestInvoice = paidInvoice, - Items = new StripeList + Data = new EventData { - Data = - [ - new SubscriptionItem - { - CurrentPeriodEnd = currentPeriodEnd, - Price = new Price { Id = IStripeEventUtilityService.PremiumPlanId } - } - ] + Object = subscription, + PreviousAttributes = JObject.FromObject(previousSubscription) } }; - var parsedEvent = new Event { Data = new EventData() }; - - var premiumPlan = new PremiumPlan - { - Name = "Premium", - Available = true, - LegacyYear = null, - Seat = new PremiumPurchasable { Price = 10M, StripePriceId = IStripeEventUtilityService.PremiumPlanId }, - Storage = new PremiumPurchasable { Price = 4M, StripePriceId = "storage-plan-personal" } - }; - _pricingClient.ListPremiumPlans().Returns(new List { premiumPlan }); - _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) .Returns(subscription); - _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) - .Returns(Tuple.Create(null, userId, null)); - // Act await _sut.HandleAsync(parsedEvent); - // Assert - await _userService.DidNotReceive() - .DisablePremiumAsync(Arg.Any(), Arg.Any()); - await _stripeFacade.DidNotReceive() - .CancelSubscription(Arg.Any(), Arg.Any()); - await _stripeFacade.DidNotReceive() - .ListInvoices(Arg.Any()); + // Assert - Incomplete status is no longer handled specially, only expiration is updated + await _userService.DidNotReceive().DisablePremiumAsync(Arg.Any(), Arg.Any()); + await _userService.Received(1).UpdatePremiumExpirationAsync(userId, currentPeriodEnd); + await _stripeFacade.DidNotReceive().UpdateSubscription(Arg.Any(), Arg.Any()); } - public static IEnumerable GetNonActiveSubscriptions() + public static IEnumerable GetValidTransitionToActiveSubscriptions() { + // Only Incomplete and Unpaid are valid previous statuses for SubscriptionBecameActive return new List { - new object[] { new Subscription { Id = "sub_123", Status = StripeSubscriptionStatus.Unpaid } }, - new object[] { new Subscription { Id = "sub_123", Status = StripeSubscriptionStatus.Incomplete } }, - new object[] - { - new Subscription { Id = "sub_123", Status = StripeSubscriptionStatus.IncompleteExpired } - }, - new object[] { new Subscription { Id = "sub_123", Status = StripeSubscriptionStatus.Paused } } + new object[] { new Subscription { Id = "sub_123", Status = SubscriptionStatus.Unpaid } }, + new object[] { new Subscription { Id = "sub_123", Status = SubscriptionStatus.Incomplete } } }; } } diff --git a/test/Billing.Test/Services/UpcomingInvoiceHandlerTests.cs b/test/Billing.Test/Services/UpcomingInvoiceHandlerTests.cs index 3b133c7d37..82d6c8acfd 100644 --- a/test/Billing.Test/Services/UpcomingInvoiceHandlerTests.cs +++ b/test/Billing.Test/Services/UpcomingInvoiceHandlerTests.cs @@ -280,7 +280,7 @@ public class UpcomingInvoiceHandlerTests email.ToEmails.Contains("user@example.com") && email.Subject == "Your Bitwarden Premium renewal is updating" && email.View.BaseMonthlyRenewalPrice == (plan.Seat.Price / 12).ToString("C", new CultureInfo("en-US")) && - email.View.DiscountedMonthlyRenewalPrice == (discountedPrice / 12).ToString("C", new CultureInfo("en-US")) && + email.View.DiscountedAnnualRenewalPrice == discountedPrice.ToString("C", new CultureInfo("en-US")) && email.View.DiscountAmount == $"{coupon.PercentOff}%" )); } @@ -2436,7 +2436,7 @@ public class UpcomingInvoiceHandlerTests email.Subject == "Your Bitwarden Premium renewal is updating" && email.View.BaseMonthlyRenewalPrice == (plan.Seat.Price / 12).ToString("C", new CultureInfo("en-US")) && email.View.DiscountAmount == "30%" && - email.View.DiscountedMonthlyRenewalPrice == (expectedDiscountedPrice / 12).ToString("C", new CultureInfo("en-US")) + email.View.DiscountedAnnualRenewalPrice == expectedDiscountedPrice.ToString("C", new CultureInfo("en-US")) )); await _mailService.DidNotReceive().SendInvoiceUpcoming( diff --git a/test/Core.Test/AdminConsole/AutoFixture/PolicyFixtures.cs b/test/Core.Test/AdminConsole/AutoFixture/PolicyFixtures.cs index 09b112c43c..01ffb86a7d 100644 --- a/test/Core.Test/AdminConsole/AutoFixture/PolicyFixtures.cs +++ b/test/Core.Test/AdminConsole/AutoFixture/PolicyFixtures.cs @@ -3,6 +3,7 @@ using AutoFixture; using AutoFixture.Xunit2; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; namespace Bit.Core.Test.AdminConsole.AutoFixture; @@ -10,19 +11,30 @@ internal class PolicyCustomization : ICustomization { public PolicyType Type { get; set; } public bool Enabled { get; set; } + public string? Data { get; set; } - public PolicyCustomization(PolicyType type, bool enabled) + public PolicyCustomization(PolicyType type, bool enabled, string? data) { Type = type; Enabled = enabled; + Data = data; } public void Customize(IFixture fixture) { + var orgId = Guid.NewGuid(); + fixture.Customize(composer => composer - .With(o => o.OrganizationId, Guid.NewGuid()) + .With(o => o.OrganizationId, orgId) .With(o => o.Type, Type) - .With(o => o.Enabled, Enabled)); + .With(o => o.Enabled, Enabled) + .With(o => o.Data, Data)); + + fixture.Customize(composer => composer + .With(o => o.OrganizationId, orgId) + .With(o => o.Type, Type) + .With(o => o.Enabled, Enabled) + .With(o => o.Data, Data)); } } @@ -30,15 +42,17 @@ public class PolicyAttribute : CustomizeAttribute { private readonly PolicyType _type; private readonly bool _enabled; + private readonly string? _data; - public PolicyAttribute(PolicyType type, bool enabled = true) + public PolicyAttribute(PolicyType type, bool enabled = true, string? data = null) { _type = type; _enabled = enabled; + _data = data; } public override ICustomization GetCustomization(ParameterInfo parameter) { - return new PolicyCustomization(_type, _enabled); + return new PolicyCustomization(_type, _enabled, _data); } } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommandTests.cs index 88025301b6..3095907a22 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/AccountRecovery/AdminRecoverAccountCommandTests.cs @@ -1,14 +1,16 @@ using AutoFixture; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.AccountRecovery; -using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Platform.Push; using Bit.Core.Repositories; using Bit.Core.Services; +using Bit.Core.Test.AdminConsole.AutoFixture; using Bit.Core.Test.AutoFixture.OrganizationUserFixtures; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; @@ -29,11 +31,12 @@ public class AdminRecoverAccountCommandTests Organization organization, OrganizationUser organizationUser, User user, + [Policy(PolicyType.ResetPassword, true)] PolicyStatus policy, SutProvider sutProvider) { // Arrange SetupValidOrganization(sutProvider, organization); - SetupValidPolicy(sutProvider, organization); + SetupValidPolicy(sutProvider, organization, policy); SetupValidOrganizationUser(organizationUser, organization.Id); SetupValidUser(sutProvider, user, organizationUser); SetupSuccessfulPasswordUpdate(sutProvider, user, newMasterPassword); @@ -87,25 +90,18 @@ public class AdminRecoverAccountCommandTests Assert.Equal("Organization does not allow password reset.", exception.Message); } - public static IEnumerable InvalidPolicies => new object[][] - { - [new Policy { Type = PolicyType.ResetPassword, Enabled = false }], [null] - }; - [Theory] - [BitMemberAutoData(nameof(InvalidPolicies))] + [BitAutoData] public async Task RecoverAccountAsync_InvalidPolicy_ThrowsBadRequest( - Policy resetPasswordPolicy, string newMasterPassword, string key, Organization organization, + [Policy(PolicyType.ResetPassword, false)] PolicyStatus policy, SutProvider sutProvider) { // Arrange SetupValidOrganization(sutProvider, organization); - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword) - .Returns(resetPasswordPolicy); + SetupValidPolicy(sutProvider, organization, policy); // Act & Assert var exception = await Assert.ThrowsAsync(() => @@ -171,11 +167,12 @@ public class AdminRecoverAccountCommandTests Organization organization, string newMasterPassword, string key, + [Policy(PolicyType.ResetPassword, true)] PolicyStatus policy, SutProvider sutProvider) { // Arrange SetupValidOrganization(sutProvider, organization); - SetupValidPolicy(sutProvider, organization); + SetupValidPolicy(sutProvider, organization, policy); // Act & Assert var exception = await Assert.ThrowsAsync(() => @@ -190,11 +187,12 @@ public class AdminRecoverAccountCommandTests string key, Organization organization, OrganizationUser organizationUser, + [Policy(PolicyType.ResetPassword, true)] PolicyStatus policy, SutProvider sutProvider) { // Arrange SetupValidOrganization(sutProvider, organization); - SetupValidPolicy(sutProvider, organization); + SetupValidPolicy(sutProvider, organization, policy); SetupValidOrganizationUser(organizationUser, organization.Id); sutProvider.GetDependency() .GetUserByIdAsync(organizationUser.UserId!.Value) @@ -213,11 +211,12 @@ public class AdminRecoverAccountCommandTests Organization organization, OrganizationUser organizationUser, User user, + [Policy(PolicyType.ResetPassword, true)] PolicyStatus policy, SutProvider sutProvider) { // Arrange SetupValidOrganization(sutProvider, organization); - SetupValidPolicy(sutProvider, organization); + SetupValidPolicy(sutProvider, organization, policy); SetupValidOrganizationUser(organizationUser, organization.Id); user.UsesKeyConnector = true; sutProvider.GetDependency() @@ -238,11 +237,10 @@ public class AdminRecoverAccountCommandTests .Returns(organization); } - private static void SetupValidPolicy(SutProvider sutProvider, Organization organization) + private static void SetupValidPolicy(SutProvider sutProvider, Organization organization, PolicyStatus policy) { - var policy = new Policy { Type = PolicyType.ResetPassword, Enabled = true }; - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.ResetPassword) + sutProvider.GetDependency() + .RunAsync(organization.Id, PolicyType.ResetPassword) .Returns(policy); } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommandTests.cs index ef4c2c941e..730489a9fc 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommandTests.cs @@ -282,6 +282,7 @@ public class VerifyOrganizationDomainCommandTests await sutProvider.GetDependency().Received().SendClaimedDomainUserEmailAsync( Arg.Is(x => x.EmailList.Count(e => e.EndsWith(domain.DomainName)) == mockedUsers.Count && - x.Organization.Id == organization.Id)); + x.Organization.Id == organization.Id && + x.DomainName == domain.DomainName)); } } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmOrganizationUsersValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmOrganizationUsersValidatorTests.cs index c3fb52ecbe..50e40b9803 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmOrganizationUsersValidatorTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmOrganizationUsersValidatorTests.cs @@ -7,7 +7,6 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.DeleteClaimed using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Enforcement.AutoConfirm; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; -using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; using Bit.Core.Billing.Enums; using Bit.Core.Entities; @@ -120,7 +119,7 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests [Organization(useAutomaticUserConfirmation: true, planType: PlanType.EnterpriseAnnually)] Organization organization, [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, User user, - [Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy) + [Policy(PolicyType.AutomaticUserConfirmation)] PolicyStatus autoConfirmPolicy) { // Arrange organizationUser.UserId = user.Id; @@ -137,8 +136,8 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests Key = "test-key" }; - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation) + sutProvider.GetDependency() + .RunAsync(organization.Id, PolicyType.AutomaticUserConfirmation) .Returns(autoConfirmPolicy); sutProvider.GetDependency() @@ -280,7 +279,7 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests [Organization(useAutomaticUserConfirmation: true)] Organization organization, [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, Guid userId, - [Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy) + [Policy(PolicyType.AutomaticUserConfirmation)] PolicyStatus autoConfirmPolicy) { // Arrange organizationUser.UserId = userId; @@ -303,8 +302,8 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests PolicyType = PolicyType.TwoFactorAuthentication }; - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation) + sutProvider.GetDependency() + .RunAsync(organization.Id, PolicyType.AutomaticUserConfirmation) .Returns(autoConfirmPolicy); sutProvider.GetDependency() @@ -334,7 +333,7 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests [Organization(useAutomaticUserConfirmation: true)] Organization organization, [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, User user, - [Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy) + [Policy(PolicyType.AutomaticUserConfirmation)] PolicyStatus autoConfirmPolicy) { // Arrange organizationUser.UserId = user.Id; @@ -351,8 +350,8 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests Key = "test-key" }; - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation) + sutProvider.GetDependency() + .RunAsync(organization.Id, PolicyType.AutomaticUserConfirmation) .Returns(autoConfirmPolicy); sutProvider.GetDependency() @@ -389,7 +388,7 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests [Organization(useAutomaticUserConfirmation: true)] Organization organization, [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, User user, - [Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy) + [Policy(PolicyType.AutomaticUserConfirmation)] PolicyStatus autoConfirmPolicy) { // Arrange organizationUser.UserId = user.Id; @@ -406,8 +405,8 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests Key = "test-key" }; - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation) + sutProvider.GetDependency() + .RunAsync(organization.Id, PolicyType.AutomaticUserConfirmation) .Returns(autoConfirmPolicy); sutProvider.GetDependency() @@ -448,7 +447,7 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests [Organization(useAutomaticUserConfirmation: true)] Organization organization, [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, User user, - [Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy) + [Policy(PolicyType.AutomaticUserConfirmation)] PolicyStatus autoConfirmPolicy) { // Arrange organizationUser.UserId = user.Id; @@ -465,8 +464,8 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests Key = "test-key" }; - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation) + sutProvider.GetDependency() + .RunAsync(organization.Id, PolicyType.AutomaticUserConfirmation) .Returns(autoConfirmPolicy); sutProvider.GetDependency() @@ -501,7 +500,8 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests SutProvider sutProvider, Organization organization, [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, - Guid userId) + Guid userId, + [Policy(PolicyType.AutomaticUserConfirmation, false)] PolicyStatus policy) { // Arrange organizationUser.UserId = userId; @@ -518,9 +518,9 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests Key = "test-key" }; - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation) - .Returns((Policy)null); + sutProvider.GetDependency() + .RunAsync(organization.Id, PolicyType.AutomaticUserConfirmation) + .Returns(policy); sutProvider.GetDependency() .TwoFactorIsEnabledAsync(Arg.Any>()) @@ -545,7 +545,7 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests [Organization(useAutomaticUserConfirmation: false)] Organization organization, [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, Guid userId, - [Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy) + [Policy(PolicyType.AutomaticUserConfirmation)] PolicyStatus autoConfirmPolicy) { // Arrange organizationUser.UserId = userId; @@ -562,8 +562,8 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests Key = "test-key" }; - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation) + sutProvider.GetDependency() + .RunAsync(organization.Id, PolicyType.AutomaticUserConfirmation) .Returns(autoConfirmPolicy); sutProvider.GetDependency() @@ -589,7 +589,7 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests [Organization(useAutomaticUserConfirmation: true)] Organization organization, [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser organizationUser, User user, - [Policy(PolicyType.AutomaticUserConfirmation)] Policy autoConfirmPolicy) + [Policy(PolicyType.AutomaticUserConfirmation)] PolicyStatus autoConfirmPolicy) { // Arrange organizationUser.UserId = user.Id; @@ -606,8 +606,8 @@ public class AutomaticallyConfirmOrganizationUsersValidatorTests Key = "test-key" }; - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(organization.Id, PolicyType.AutomaticUserConfirmation) + sutProvider.GetDependency() + .RunAsync(organization.Id, PolicyType.AutomaticUserConfirmation) .Returns(autoConfirmPolicy); sutProvider.GetDependency() diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmUsersCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmUsersCommandTests.cs index 180750a9d0..252fb89c87 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmUsersCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmUsersCommandTests.cs @@ -10,7 +10,6 @@ using Bit.Core.AdminConsole.Utilities.v2; using Bit.Core.AdminConsole.Utilities.v2.Validation; using Bit.Core.Entities; using Bit.Core.Enums; -using Bit.Core.Models.Data; using Bit.Core.Platform.Push; using Bit.Core.Repositories; using Bit.Core.Services; @@ -204,14 +203,10 @@ public class AutomaticallyConfirmUsersCommandTests await sutProvider.GetDependency() .Received(1) - .CreateAsync( - Arg.Is(c => - c.OrganizationId == organization.Id && - c.Name == defaultCollectionName && - c.Type == CollectionType.DefaultUserCollection), - Arg.Is>(groups => groups == null), - Arg.Is>(access => - access.FirstOrDefault(x => x.Id == organizationUser.Id && x.Manage) != null)); + .CreateDefaultCollectionsAsync( + organization.Id, + Arg.Is>(ids => ids.Single() == organizationUser.Id), + defaultCollectionName); } [Theory] @@ -253,9 +248,7 @@ public class AutomaticallyConfirmUsersCommandTests await sutProvider.GetDependency() .DidNotReceive() - .CreateAsync(Arg.Any(), - Arg.Any>(), - Arg.Any>()); + .CreateDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } [Theory] @@ -291,9 +284,7 @@ public class AutomaticallyConfirmUsersCommandTests var collectionException = new Exception("Collection creation failed"); sutProvider.GetDependency() - .CreateAsync(Arg.Any(), - Arg.Any>(), - Arg.Any>()) + .CreateDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()) .ThrowsAsync(collectionException); // Act diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs index 65359b8304..6643f26eb5 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs @@ -13,7 +13,6 @@ using Bit.Core.Billing.Enums; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; -using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Platform.Push; using Bit.Core.Repositories; @@ -493,15 +492,10 @@ public class ConfirmOrganizationUserCommandTests await sutProvider.GetDependency() .Received(1) - .CreateAsync( - Arg.Is(c => - c.Name == collectionName && - c.OrganizationId == organization.Id && - c.Type == CollectionType.DefaultUserCollection), - Arg.Any>(), - Arg.Is>(cu => - cu.Single().Id == orgUser.Id && - cu.Single().Manage)); + .CreateDefaultCollectionsAsync( + organization.Id, + Arg.Is>(ids => ids.Single() == orgUser.Id), + collectionName); } [Theory, BitAutoData] @@ -522,7 +516,7 @@ public class ConfirmOrganizationUserCommandTests await sutProvider.GetDependency() .DidNotReceive() - .UpsertDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + .CreateDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } [Theory, BitAutoData] @@ -539,24 +533,15 @@ public class ConfirmOrganizationUserCommandTests sutProvider.GetDependency().GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser }); sutProvider.GetDependency().GetManyAsync(default).ReturnsForAnyArgs(new[] { user }); - var policyDetails = new PolicyDetails - { - OrganizationId = org.Id, - OrganizationUserId = orgUser.Id, - IsProvider = false, - OrganizationUserStatus = orgUser.Status, - OrganizationUserType = orgUser.Type, - PolicyType = PolicyType.OrganizationDataOwnership - }; sutProvider.GetDependency() .GetAsync(orgUser.UserId!.Value) - .Returns(new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Disabled, [policyDetails])); + .Returns(new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Disabled, [])); await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, collectionName); await sutProvider.GetDependency() .DidNotReceive() - .UpsertDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + .CreateDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } [Theory, BitAutoData] diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/SendOrganizationInvitesCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/SendOrganizationInvitesCommandTests.cs index 23c1a32c03..ddede2d191 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/SendOrganizationInvitesCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/InviteUsers/SendOrganizationInvitesCommandTests.cs @@ -1,7 +1,9 @@ using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.InviteUsers.Models; -using Bit.Core.AdminConsole.Repositories; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.Auth.Entities; using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Auth.Repositories; @@ -9,6 +11,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Entities; using Bit.Core.Models.Mail; using Bit.Core.Services; +using Bit.Core.Test.AdminConsole.AutoFixture; using Bit.Core.Test.AutoFixture.OrganizationFixtures; using Bit.Core.Tokens; using Bit.Test.Common.AutoFixture; @@ -31,6 +34,7 @@ public class SendOrganizationInvitesCommandTests Organization organization, SsoConfig ssoConfig, OrganizationUser invite, + [Policy(PolicyType.RequireSso, false)] PolicyStatus policy, SutProvider sutProvider) { // Setup FakeDataProtectorTokenFactory for creating new tokens - this must come first in order to avoid resetting mocks @@ -45,7 +49,9 @@ public class SendOrganizationInvitesCommandTests sutProvider.GetDependency().GetByOrganizationIdAsync(organization.Id).Returns(ssoConfig); // Return null policy to mimic new org that's never turned on the require sso policy - sutProvider.GetDependency().GetManyByOrganizationIdAsync(organization.Id).ReturnsNull(); + sutProvider.GetDependency() + .RunAsync(organization.Id, PolicyType.RequireSso) + .Returns(policy); // Mock tokenable factory to return a token that expires in 5 days sutProvider.GetDependency() diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/RestoreOrganizationUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/RestoreOrganizationUserCommandTests.cs index 4fa5e92abe..29c996cee9 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/RestoreOrganizationUserCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/RestoreOrganizationUserCommandTests.cs @@ -37,7 +37,7 @@ public class RestoreOrganizationUserCommandTests Sponsored = 0, Users = 1 }); - await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id); + await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null); await sutProvider.GetDependency() .Received(1) @@ -81,7 +81,7 @@ public class RestoreOrganizationUserCommandTests RestoreUser_Setup(organization, owner, organizationUser, sutProvider); var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id)); + () => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null)); Assert.Contains("you cannot restore yourself", exception.Message.ToLowerInvariant()); @@ -107,7 +107,7 @@ public class RestoreOrganizationUserCommandTests RestoreUser_Setup(organization, restoringUser, organizationUser, sutProvider); var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.RestoreUserAsync(organizationUser, restoringUser.Id)); + () => sutProvider.Sut.RestoreUserAsync(organizationUser, restoringUser.Id, null)); Assert.Contains("only owners can restore other owners", exception.Message.ToLowerInvariant()); @@ -133,7 +133,7 @@ public class RestoreOrganizationUserCommandTests RestoreUser_Setup(organization, owner, organizationUser, sutProvider); var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id)); + () => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null)); Assert.Contains("already active", exception.Message.ToLowerInvariant()); @@ -172,7 +172,7 @@ public class RestoreOrganizationUserCommandTests sutProvider.GetDependency().GetByIdAsync(organizationUser.UserId.Value).Returns(user); var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id)); + () => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null)); Assert.Contains("test@bitwarden.com belongs to an organization that doesn't allow them to join multiple organizations", exception.Message.ToLowerInvariant()); @@ -216,7 +216,7 @@ public class RestoreOrganizationUserCommandTests sutProvider.GetDependency().GetByIdAsync(organizationUser.UserId.Value).Returns(user); var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id)); + () => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null)); Assert.Contains("test@bitwarden.com is not compliant with the two-step login policy", exception.Message.ToLowerInvariant()); @@ -272,7 +272,7 @@ public class RestoreOrganizationUserCommandTests sutProvider.GetDependency().GetByIdAsync(organizationUser.UserId.Value).Returns(user); var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id)); + () => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null)); Assert.Contains("test@bitwarden.com is not compliant with the two-step login policy", exception.Message.ToLowerInvariant()); @@ -309,7 +309,7 @@ public class RestoreOrganizationUserCommandTests Sponsored = 0, Users = 1 }); - await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id); + await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null); await sutProvider.GetDependency() .Received(1) @@ -349,7 +349,7 @@ public class RestoreOrganizationUserCommandTests } ])); - await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id); + await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null); await sutProvider.GetDependency() .Received(1) @@ -395,7 +395,7 @@ public class RestoreOrganizationUserCommandTests sutProvider.GetDependency().GetByIdAsync(organizationUser.UserId.Value).Returns(user); var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id)); + () => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null)); Assert.Contains("test@bitwarden.com is not compliant with the single organization policy", exception.Message.ToLowerInvariant()); @@ -447,7 +447,7 @@ public class RestoreOrganizationUserCommandTests sutProvider.GetDependency().GetByIdAsync(organizationUser.UserId.Value).Returns(user); var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id)); + () => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null)); Assert.Contains("test@bitwarden.com is not compliant with the single organization and two-step login policy", exception.Message.ToLowerInvariant()); @@ -509,7 +509,7 @@ public class RestoreOrganizationUserCommandTests sutProvider.GetDependency().GetByIdAsync(organizationUser.UserId.Value).Returns(user); var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id)); + () => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null)); Assert.Contains("test@bitwarden.com is not compliant with the single organization and two-step login policy", exception.Message.ToLowerInvariant()); @@ -548,7 +548,7 @@ public class RestoreOrganizationUserCommandTests .TwoFactorIsEnabledAsync(Arg.Is>(i => i.Contains(organizationUser.UserId.Value))) .Returns(new List<(Guid userId, bool twoFactorIsEnabled)> { (organizationUser.UserId.Value, true) }); - await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id); + await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null); await sutProvider.GetDependency() .Received(1) @@ -599,7 +599,7 @@ public class RestoreOrganizationUserCommandTests .Returns(new List<(Guid userId, bool twoFactorIsEnabled)> { (organizationUser.UserId.Value, true) }); var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id)); + () => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null)); Assert.Equal("User is an owner/admin of another free organization. Please have them upgrade to a paid plan to restore their account.", exception.Message); } @@ -651,7 +651,7 @@ public class RestoreOrganizationUserCommandTests .TwoFactorIsEnabledAsync(Arg.Is>(i => i.Contains(organizationUser.UserId.Value))) .Returns(new List<(Guid userId, bool twoFactorIsEnabled)> { (organizationUser.UserId.Value, true) }); - await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id); + await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null); await organizationUserRepository .Received(1) @@ -707,7 +707,7 @@ public class RestoreOrganizationUserCommandTests .TwoFactorIsEnabledAsync(Arg.Is>(i => i.Contains(organizationUser.UserId.Value))) .Returns(new List<(Guid userId, bool twoFactorIsEnabled)> { (organizationUser.UserId.Value, true) }); - await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id); + await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null); await organizationUserRepository .Received(1) @@ -715,6 +715,39 @@ public class RestoreOrganizationUserCommandTests Arg.Is(x => x != OrganizationUserStatusType.Revoked)); } + [Theory, BitAutoData] + public async Task RestoreUser_InvitedUserInFreeOrganization_Success( + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser, + SutProvider sutProvider) + { + organization.PlanType = PlanType.Free; + organizationUser.UserId = null; + organizationUser.Key = null; + organizationUser.Status = OrganizationUserStatusType.Revoked; + + RestoreUser_Setup(organization, owner, organizationUser, sutProvider); + sutProvider.GetDependency() + .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id).Returns(new OrganizationSeatCounts + { + Sponsored = 0, + Users = 1 + }); + + await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, ""); + + await sutProvider.GetDependency() + .Received(1) + .RestoreAsync(organizationUser.Id, OrganizationUserStatusType.Invited); + await sutProvider.GetDependency() + .Received(1) + .LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Restored); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .PushSyncOrgKeysAsync(Arg.Any()); + } + [Theory, BitAutoData] public async Task RestoreUsers_Success(Organization organization, [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, @@ -749,7 +782,7 @@ public class RestoreOrganizationUserCommandTests }); // Act - var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, new[] { orgUser1.Id, orgUser2.Id }, owner.Id, userService); + var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, new[] { orgUser1.Id, orgUser2.Id }, owner.Id, userService, null); // Assert Assert.Equal(2, result.Count); @@ -810,7 +843,7 @@ public class RestoreOrganizationUserCommandTests }); // Act - var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, [orgUser1.Id, orgUser2.Id, orgUser3.Id], owner.Id, userService); + var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, [orgUser1.Id, orgUser2.Id, orgUser3.Id], owner.Id, userService, null); // Assert Assert.Equal(3, result.Count); @@ -881,7 +914,7 @@ public class RestoreOrganizationUserCommandTests }); // Act - var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, [orgUser1.Id, orgUser2.Id, orgUser3.Id], owner.Id, userService); + var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, [orgUser1.Id, orgUser2.Id, orgUser3.Id], owner.Id, userService, null); // Assert Assert.Equal(3, result.Count); @@ -959,7 +992,7 @@ public class RestoreOrganizationUserCommandTests }); // Act - var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, [orgUser1.Id, orgUser2.Id, orgUser3.Id], owner.Id, userService); + var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, [orgUser1.Id, orgUser2.Id, orgUser3.Id], owner.Id, userService, null); // Assert Assert.Equal(3, result.Count); @@ -1023,7 +1056,7 @@ public class RestoreOrganizationUserCommandTests }); // Act - var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, [orgUser1.Id], owner.Id, userService); + var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, [orgUser1.Id], owner.Id, userService, null); // Assert Assert.Single(result); @@ -1074,7 +1107,7 @@ public class RestoreOrganizationUserCommandTests .Returns([new OrganizationUserPolicyDetails { OrganizationId = organization.Id, PolicyType = PolicyType.TwoFactorAuthentication }]); // Act - var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, [orgUser1.Id], owner.Id, userService); + var result = await sutProvider.Sut.RestoreUsersAsync(organization.Id, [orgUser1.Id], owner.Id, userService, null); Assert.Single(result); Assert.Equal(string.Empty, result[0].Item2); @@ -1105,5 +1138,408 @@ public class RestoreOrganizationUserCommandTests sutProvider.GetDependency().OrganizationOwner(organization.Id).Returns(requestingOrganizationUser != null && requestingOrganizationUser.Type is OrganizationUserType.Owner); sutProvider.GetDependency().ManageUsers(organization.Id).Returns(requestingOrganizationUser != null && (requestingOrganizationUser.Type is OrganizationUserType.Owner or OrganizationUserType.Admin)); + + // Setup default disabled OrganizationDataOwnershipPolicyRequirement for any user + sutProvider.GetDependency() + .GetAsync(Arg.Any()) + .Returns(new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Disabled, [])); } + + private static void SetupOrganizationDataOwnershipPolicy( + SutProvider sutProvider, + Guid userId, + Guid organizationId, + OrganizationUserStatusType orgUserStatus, + bool policyEnabled) + { + var policyDetails = policyEnabled + ? new List + { + new() + { + OrganizationId = organizationId, + OrganizationUserId = Guid.NewGuid(), + OrganizationUserStatus = orgUserStatus, + PolicyType = PolicyType.OrganizationDataOwnership + } + } + : new List(); + + var policyRequirement = new OrganizationDataOwnershipPolicyRequirement( + policyEnabled ? OrganizationDataOwnershipState.Enabled : OrganizationDataOwnershipState.Disabled, + policyDetails); + + sutProvider.GetDependency() + .GetAsync(userId) + .Returns(policyRequirement); + } + + #region Single User Restore - Default Collection Tests + + [Theory, BitAutoData] + public async Task RestoreUser_WithDataOwnershipPolicyEnabled_AndConfirmedUser_CreatesDefaultCollection( + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser, + string defaultCollectionName, + SutProvider sutProvider) + { + // Arrange + organizationUser.Email = null; // This causes user to restore to Confirmed status + RestoreUser_Setup(organization, owner, organizationUser, sutProvider); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.DefaultUserCollectionRestore) + .Returns(true); + + SetupOrganizationDataOwnershipPolicy( + sutProvider, + organizationUser.UserId!.Value, + organization.Id, + OrganizationUserStatusType.Revoked, + policyEnabled: true); + + // Act + await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, defaultCollectionName); + + // Assert + await sutProvider.GetDependency() + .Received(1) + .CreateDefaultCollectionsAsync( + organization.Id, + Arg.Is>(ids => ids.Single() == organizationUser.Id), + defaultCollectionName); + } + + [Theory, BitAutoData] + public async Task RestoreUser_WithDataOwnershipPolicyDisabled_DoesNotCreateDefaultCollection( + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser, + string defaultCollectionName, + SutProvider sutProvider) + { + // Arrange + organizationUser.Email = null; // This causes user to restore to Confirmed status + RestoreUser_Setup(organization, owner, organizationUser, sutProvider); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.DefaultUserCollectionRestore) + .Returns(true); + + SetupOrganizationDataOwnershipPolicy( + sutProvider, + organizationUser.UserId!.Value, + organization.Id, + OrganizationUserStatusType.Revoked, + policyEnabled: false); + + // Act + await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, defaultCollectionName); + + // Assert + await sutProvider.GetDependency() + .DidNotReceive() + .CreateDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + } + + [Theory, BitAutoData] + public async Task RestoreUser_WithNullDefaultCollectionName_DoesNotCreateDefaultCollection( + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser, + SutProvider sutProvider) + { + // Arrange + organizationUser.Email = null; // This causes user to restore to Confirmed status + RestoreUser_Setup(organization, owner, organizationUser, sutProvider); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.DefaultUserCollectionRestore) + .Returns(true); + + SetupOrganizationDataOwnershipPolicy( + sutProvider, + organizationUser.UserId!.Value, + organization.Id, + OrganizationUserStatusType.Revoked, + policyEnabled: true); + + // Act + await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, null); + + // Assert + await sutProvider.GetDependency() + .DidNotReceive() + .CreateDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + } + + [Theory] + [BitAutoData("")] + [BitAutoData(" ")] + public async Task RestoreUser_WithEmptyOrWhitespaceDefaultCollectionName_DoesNotCreateDefaultCollection( + string defaultCollectionName, + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser, + SutProvider sutProvider) + { + // Arrange + organizationUser.Email = null; // This causes user to restore to Confirmed status + RestoreUser_Setup(organization, owner, organizationUser, sutProvider); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.DefaultUserCollectionRestore) + .Returns(true); + + SetupOrganizationDataOwnershipPolicy( + sutProvider, + organizationUser.UserId!.Value, + organization.Id, + OrganizationUserStatusType.Revoked, + policyEnabled: true); + + // Act + await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, defaultCollectionName); + + // Assert + await sutProvider.GetDependency() + .DidNotReceive() + .CreateDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + } + + [Theory, BitAutoData] + public async Task RestoreUser_UserRestoredToInvitedStatus_DoesNotCreateDefaultCollection( + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser, + string defaultCollectionName, + SutProvider sutProvider) + { + // Arrange + organization.PlanType = PlanType.EnterpriseAnnually; // Non-Free plan to avoid ownership check requiring UserId + organizationUser.Email = "test@example.com"; // Non-null email means user restores to Invited status + organizationUser.UserId = null; // User not linked to account yet + organizationUser.Key = null; + RestoreUser_Setup(organization, owner, organizationUser, sutProvider); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.DefaultUserCollectionRestore) + .Returns(true); + + // Act + await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, defaultCollectionName); + + // Assert - User was restored to Invited status, so no collection should be created + await sutProvider.GetDependency() + .DidNotReceive() + .CreateDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + } + + [Theory, BitAutoData] + public async Task RestoreUser_WithNoUserId_DoesNotCreateDefaultCollection( + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser, + string defaultCollectionName, + SutProvider sutProvider) + { + // Arrange + organization.PlanType = PlanType.EnterpriseAnnually; // Non-Free plan to avoid ownership check requiring UserId + organizationUser.UserId = null; // No linked user account + organizationUser.Email = "test@example.com"; + organizationUser.Key = null; + RestoreUser_Setup(organization, owner, organizationUser, sutProvider); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.DefaultUserCollectionRestore) + .Returns(true); + + // Act + await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id, defaultCollectionName); + + // Assert + await sutProvider.GetDependency() + .DidNotReceive() + .CreateDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + } + + #endregion + + #region Bulk User Restore - Default Collection Tests + + [Theory, BitAutoData] + public async Task RestoreUsers_Bulk_WithDataOwnershipPolicy_CreatesCollectionsForEligibleUsers( + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser orgUser1, + [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser orgUser2, + string defaultCollectionName, + SutProvider sutProvider) + { + // Arrange + RestoreUser_Setup(organization, owner, orgUser1, sutProvider); + var organizationUserRepository = sutProvider.GetDependency(); + var userService = Substitute.For(); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.DefaultUserCollectionRestore) + .Returns(true); + + // orgUser1: Will restore to Confirmed (Email = null) + orgUser1.Email = null; + orgUser1.OrganizationId = organization.Id; + + // orgUser2: Will restore to Invited (Email not null) + orgUser2.Email = "test@example.com"; + orgUser2.UserId = null; + orgUser2.Key = null; + orgUser2.OrganizationId = organization.Id; + + organizationUserRepository + .GetManyAsync(Arg.Is>(ids => ids.Contains(orgUser1.Id) && ids.Contains(orgUser2.Id))) + .Returns([orgUser1, orgUser2]); + + // Setup bulk policy query - returns org user IDs with policy enabled + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(organization.Id) + .Returns([orgUser1.Id]); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns(new List<(Guid userId, bool twoFactorIsEnabled)> + { + (orgUser1.UserId!.Value, true) + }); + + // Act + var result = await sutProvider.Sut.RestoreUsersAsync( + organization.Id, + [orgUser1.Id, orgUser2.Id], + owner.Id, + userService, + defaultCollectionName); + + // Assert - Only orgUser1 should have a collection created (Confirmed with policy enabled) + await sutProvider.GetDependency() + .Received(1) + .CreateDefaultCollectionsAsync( + organization.Id, + Arg.Is>(ids => ids.Single() == orgUser1.Id), + defaultCollectionName); + } + + [Theory, BitAutoData] + public async Task RestoreUsers_Bulk_WithMixedPolicyStates_OnlyCreatesForEnabledPolicy( + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser orgUser1, + [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser orgUser2, + string defaultCollectionName, + SutProvider sutProvider) + { + // Arrange + RestoreUser_Setup(organization, owner, orgUser1, sutProvider); + var organizationUserRepository = sutProvider.GetDependency(); + var userService = Substitute.For(); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.DefaultUserCollectionRestore) + .Returns(true); + + // Both users will restore to Confirmed + orgUser1.Email = null; + orgUser1.OrganizationId = organization.Id; + orgUser2.Email = null; + orgUser2.OrganizationId = organization.Id; + + organizationUserRepository + .GetManyAsync(Arg.Is>(ids => ids.Contains(orgUser1.Id) && ids.Contains(orgUser2.Id))) + .Returns([orgUser1, orgUser2]); + + // Setup bulk policy query - only orgUser1 has policy enabled + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(organization.Id) + .Returns([orgUser1.Id]); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns(new List<(Guid userId, bool twoFactorIsEnabled)> + { + (orgUser1.UserId!.Value, true), + (orgUser2.UserId!.Value, true) + }); + + // Act + var result = await sutProvider.Sut.RestoreUsersAsync( + organization.Id, + [orgUser1.Id, orgUser2.Id], + owner.Id, + userService, + defaultCollectionName); + + // Assert - Only orgUser1 should have a collection created (policy enabled) + await sutProvider.GetDependency() + .Received(1) + .CreateDefaultCollectionsAsync( + organization.Id, + Arg.Is>(ids => ids.Single() == orgUser1.Id), + defaultCollectionName); + } + + [Theory, BitAutoData] + public async Task RestoreUsers_Bulk_WithNullCollectionName_DoesNotCreateAnyCollections( + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser orgUser1, + [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser orgUser2, + SutProvider sutProvider) + { + // Arrange + RestoreUser_Setup(organization, owner, orgUser1, sutProvider); + var organizationUserRepository = sutProvider.GetDependency(); + var userService = Substitute.For(); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.DefaultUserCollectionRestore) + .Returns(true); + + // Both users will restore to Confirmed + orgUser1.Email = null; + orgUser1.OrganizationId = organization.Id; + orgUser2.Email = null; + orgUser2.OrganizationId = organization.Id; + + organizationUserRepository + .GetManyAsync(Arg.Is>(ids => ids.Contains(orgUser1.Id) && ids.Contains(orgUser2.Id))) + .Returns([orgUser1, orgUser2]); + + // Setup bulk policy query - both users have policy enabled + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(organization.Id) + .Returns([orgUser1.Id, orgUser2.Id]); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns(new List<(Guid userId, bool twoFactorIsEnabled)> + { + (orgUser1.UserId!.Value, true), + (orgUser2.UserId!.Value, true) + }); + + // Act + var result = await sutProvider.Sut.RestoreUsersAsync( + organization.Id, + [orgUser1.Id, orgUser2.Id], + owner.Id, + userService, + null); // Null collection name + + // Assert - No collections should be created + await sutProvider.GetDependency() + .DidNotReceive() + .CreateDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + } + + #endregion } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidatorTests.cs index 93cbde89ec..dd2f1d76e8 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidatorTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidatorTests.cs @@ -38,7 +38,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceive() - .UpsertDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + .CreateDefaultCollectionsBulkAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } [Theory, BitAutoData] @@ -60,7 +60,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceive() - .UpsertDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + .CreateDefaultCollectionsBulkAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } [Theory, BitAutoData] @@ -86,7 +86,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await collectionRepository .DidNotReceive() - .UpsertDefaultCollectionsAsync( + .CreateDefaultCollectionsBulkAsync( Arg.Any(), Arg.Any>(), Arg.Any()); @@ -172,10 +172,10 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Act await sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState); - // Assert + // Assert - Should call with all user IDs (repository does internal filtering) await collectionRepository .Received(1) - .UpsertDefaultCollectionsAsync( + .CreateDefaultCollectionsBulkAsync( policyUpdate.OrganizationId, Arg.Is>(ids => ids.Count() == 3), _defaultUserCollectionName); @@ -210,7 +210,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceive() - .UpsertDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + .CreateDefaultCollectionsBulkAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } private static IPolicyRepository ArrangePolicyRepository(IEnumerable policyDetails) @@ -251,7 +251,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() - .UpsertDefaultCollectionsAsync(default, default, default); + .CreateDefaultCollectionsBulkAsync(default, default, default); } [Theory, BitAutoData] @@ -273,7 +273,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() - .UpsertDefaultCollectionsAsync(default, default, default); + .CreateDefaultCollectionsBulkAsync(default, default, default); } [Theory, BitAutoData] @@ -299,7 +299,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await collectionRepository .DidNotReceiveWithAnyArgs() - .UpsertDefaultCollectionsAsync( + .CreateDefaultCollectionsBulkAsync( default, default, default); @@ -336,10 +336,10 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Act await sut.ExecutePostUpsertSideEffectAsync(policyRequest, postUpdatedPolicy, previousPolicyState); - // Assert + // Assert - Should call with all user IDs (repository does internal filtering) await collectionRepository .Received(1) - .UpsertDefaultCollectionsAsync( + .CreateDefaultCollectionsBulkAsync( policyUpdate.OrganizationId, Arg.Is>(ids => ids.Count() == 3), _defaultUserCollectionName); @@ -367,6 +367,6 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() - .UpsertDefaultCollectionsAsync(default, default, default); + .CreateDefaultCollectionsBulkAsync(default, default, default); } } diff --git a/test/Core.Test/AdminConsole/Utilities/PolicyDataValidatorTests.cs b/test/Core.Test/AdminConsole/Utilities/PolicyDataValidatorTests.cs index 43725d23e0..dcc4ceb246 100644 --- a/test/Core.Test/AdminConsole/Utilities/PolicyDataValidatorTests.cs +++ b/test/Core.Test/AdminConsole/Utilities/PolicyDataValidatorTests.cs @@ -19,12 +19,17 @@ public class PolicyDataValidatorTests [Fact] public void ValidateAndSerialize_ValidData_ReturnsSerializedJson() { - var data = new Dictionary { { "minLength", 12 } }; + var data = new Dictionary + { + { "minLength", 12 }, + { "minComplexity", 4 } + }; var result = PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword); Assert.NotNull(result); Assert.Contains("\"minLength\":12", result); + Assert.Contains("\"minComplexity\":4", result); } [Fact] @@ -56,4 +61,122 @@ public class PolicyDataValidatorTests Assert.IsType(result); } + + [Fact] + public void ValidateAndSerialize_ExcessiveMinLength_ThrowsBadRequestException() + { + var data = new Dictionary { { "minLength", 129 } }; + + var exception = Assert.Throws(() => + PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword)); + + Assert.Contains("Invalid data for MasterPassword policy", exception.Message); + } + + [Fact] + public void ValidateAndSerialize_ExcessiveMinComplexity_ThrowsBadRequestException() + { + var data = new Dictionary { { "minComplexity", 5 } }; + + var exception = Assert.Throws(() => + PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword)); + + Assert.Contains("Invalid data for MasterPassword policy", exception.Message); + } + + [Fact] + public void ValidateAndSerialize_MinLengthAtMinimum_Succeeds() + { + var data = new Dictionary { { "minLength", 12 } }; + + var result = PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword); + + Assert.NotNull(result); + Assert.Contains("\"minLength\":12", result); + } + + [Fact] + public void ValidateAndSerialize_MinLengthAtMaximum_Succeeds() + { + var data = new Dictionary { { "minLength", 128 } }; + + var result = PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword); + + Assert.NotNull(result); + Assert.Contains("\"minLength\":128", result); + } + + [Fact] + public void ValidateAndSerialize_MinLengthBelowMinimum_ThrowsBadRequestException() + { + var data = new Dictionary { { "minLength", 11 } }; + + var exception = Assert.Throws(() => + PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword)); + + Assert.Contains("Invalid data for MasterPassword policy", exception.Message); + } + + [Fact] + public void ValidateAndSerialize_MinComplexityAtMinimum_Succeeds() + { + var data = new Dictionary { { "minComplexity", 0 } }; + + var result = PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword); + + Assert.NotNull(result); + Assert.Contains("\"minComplexity\":0", result); + } + + [Fact] + public void ValidateAndSerialize_MinComplexityAtMaximum_Succeeds() + { + var data = new Dictionary { { "minComplexity", 4 } }; + + var result = PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword); + + Assert.NotNull(result); + Assert.Contains("\"minComplexity\":4", result); + } + + [Fact] + public void ValidateAndSerialize_MinComplexityBelowMinimum_ThrowsBadRequestException() + { + var data = new Dictionary { { "minComplexity", -1 } }; + + var exception = Assert.Throws(() => + PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword)); + + Assert.Contains("Invalid data for MasterPassword policy", exception.Message); + } + + [Fact] + public void ValidateAndSerialize_NullMinLength_Succeeds() + { + var data = new Dictionary + { + { "minComplexity", 2 } + // minLength is omitted, should be null + }; + + var result = PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword); + + Assert.NotNull(result); + Assert.Contains("\"minComplexity\":2", result); + } + + [Fact] + public void ValidateAndSerialize_MultipleInvalidFields_ThrowsBadRequestException() + { + var data = new Dictionary + { + { "minLength", 200 }, + { "minComplexity", 10 } + }; + + var exception = Assert.Throws(() => + PolicyDataValidator.ValidateAndSerialize(data, PolicyType.MasterPassword)); + + Assert.Contains("Invalid data for MasterPassword policy", exception.Message); + } } diff --git a/test/Core.Test/Auth/AutoFixture/RegisterFinishRequestModelFixtures.cs b/test/Core.Test/Auth/AutoFixture/RegisterFinishRequestModelFixtures.cs index a751a16f31..22fca7ab59 100644 --- a/test/Core.Test/Auth/AutoFixture/RegisterFinishRequestModelFixtures.cs +++ b/test/Core.Test/Auth/AutoFixture/RegisterFinishRequestModelFixtures.cs @@ -29,7 +29,9 @@ internal class RegisterFinishRequestModelCustomization : ICustomization .With(o => o.OrgInviteToken, OrgInviteToken) .With(o => o.OrgSponsoredFreeFamilyPlanToken, OrgSponsoredFreeFamilyPlanToken) .With(o => o.AcceptEmergencyAccessInviteToken, AcceptEmergencyAccessInviteToken) - .With(o => o.ProviderInviteToken, ProviderInviteToken)); + .With(o => o.ProviderInviteToken, ProviderInviteToken) + .Without(o => o.MasterPasswordAuthentication) + .Without(o => o.MasterPasswordUnlock)); } } diff --git a/test/Core.Test/Auth/Models/Api/Request/Accounts/RegisterFinishRequestModelTests.cs b/test/Core.Test/Auth/Models/Api/Request/Accounts/RegisterFinishRequestModelTests.cs index 588ca878fc..3c099ce962 100644 --- a/test/Core.Test/Auth/Models/Api/Request/Accounts/RegisterFinishRequestModelTests.cs +++ b/test/Core.Test/Auth/Models/Api/Request/Accounts/RegisterFinishRequestModelTests.cs @@ -1,5 +1,6 @@ using Bit.Core.Auth.Models.Api.Request.Accounts; using Bit.Core.Enums; +using Bit.Core.KeyManagement.Models.Api.Request; using Bit.Test.Common.AutoFixture.Attributes; using Xunit; @@ -7,6 +8,17 @@ namespace Bit.Core.Test.Auth.Models.Api.Request.Accounts; public class RegisterFinishRequestModelTests { + private static List Validate(RegisterFinishRequestModel model) + { + var results = new List(); + System.ComponentModel.DataAnnotations.Validator.TryValidateObject( + model, + new System.ComponentModel.DataAnnotations.ValidationContext(model), + results, + true); + return results; + } + [Theory] [BitAutoData] public void GetTokenType_Returns_EmailVerification(string email, string masterPasswordHash, @@ -170,4 +182,175 @@ public class RegisterFinishRequestModelTests Assert.Equal(userAsymmetricKeys.PublicKey, result.PublicKey); Assert.Equal(userAsymmetricKeys.EncryptedPrivateKey, result.PrivateKey); } + + [Fact] + public void Validate_WhenBothAuthAndRootHashProvidedButNotEqual_ReturnsMismatchError() + { + var model = new RegisterFinishRequestModel + { + Email = "user@example.com", + MasterPasswordHash = "root-hash", + UserAsymmetricKeys = new KeysRequestModel { PublicKey = "pk", EncryptedPrivateKey = "sk" }, + // Provide both unlock and authentication with valid KDF so only the mismatch rule fires + MasterPasswordUnlock = new MasterPasswordUnlockDataRequestModel + { + Kdf = new KdfRequestModel { KdfType = KdfType.PBKDF2_SHA256, Iterations = AuthConstants.PBKDF2_ITERATIONS.Default }, + MasterKeyWrappedUserKey = "wrapped", + Salt = "salt" + }, + MasterPasswordAuthentication = new MasterPasswordAuthenticationDataRequestModel + { + Kdf = new KdfRequestModel { KdfType = KdfType.PBKDF2_SHA256, Iterations = AuthConstants.PBKDF2_ITERATIONS.Default }, + MasterPasswordAuthenticationHash = "auth-hash", // different than root + Salt = "salt" + }, + // Provide any valid token so we don't fail token validation + EmailVerificationToken = "token" + }; + + var results = Validate(model); + + Assert.Contains(results, r => + r.ErrorMessage == $"{nameof(MasterPasswordAuthenticationDataRequestModel.MasterPasswordAuthenticationHash)} and root level {nameof(RegisterFinishRequestModel.MasterPasswordHash)} provided and are not equal. Only provide one."); + } + + [Fact] + public void Validate_WhenAuthProvidedButUnlockMissing_ReturnsUnlockMissingError() + { + var model = new RegisterFinishRequestModel + { + Email = "user@example.com", + UserAsymmetricKeys = new KeysRequestModel { PublicKey = "pk", EncryptedPrivateKey = "sk" }, + MasterPasswordAuthentication = new MasterPasswordAuthenticationDataRequestModel + { + Kdf = new KdfRequestModel { KdfType = KdfType.PBKDF2_SHA256, Iterations = AuthConstants.PBKDF2_ITERATIONS.Default }, + MasterPasswordAuthenticationHash = "auth-hash", + Salt = "salt" + }, + EmailVerificationToken = "token" + }; + + var results = Validate(model); + + Assert.Contains(results, r => r.ErrorMessage == "MasterPasswordUnlock not found on RequestModel"); + } + + [Fact] + public void Validate_WhenUnlockProvidedButAuthMissing_ReturnsAuthMissingError() + { + var model = new RegisterFinishRequestModel + { + Email = "user@example.com", + UserAsymmetricKeys = new KeysRequestModel { PublicKey = "pk", EncryptedPrivateKey = "sk" }, + MasterPasswordUnlock = new MasterPasswordUnlockDataRequestModel + { + Kdf = new KdfRequestModel { KdfType = KdfType.PBKDF2_SHA256, Iterations = AuthConstants.PBKDF2_ITERATIONS.Default }, + MasterKeyWrappedUserKey = "wrapped", + Salt = "salt" + }, + EmailVerificationToken = "token" + }; + + var results = Validate(model); + + Assert.Contains(results, r => r.ErrorMessage == "MasterPasswordAuthentication not found on RequestModel"); + } + + [Fact] + public void Validate_WhenNeitherAuthNorUnlock_AndRootKdfMissing_ReturnsBothRootKdfErrors() + { + var model = new RegisterFinishRequestModel + { + Email = "user@example.com", + UserAsymmetricKeys = new KeysRequestModel { PublicKey = "pk", EncryptedPrivateKey = "sk" }, + // No MasterPasswordUnlock, no MasterPasswordAuthentication + // No root Kdf and KdfIterations to trigger both errors + EmailVerificationToken = "token" + }; + + var results = Validate(model); + + Assert.Contains(results, r => r.ErrorMessage == $"{nameof(RegisterFinishRequestModel.Kdf)} not found on RequestModel"); + Assert.Contains(results, r => r.ErrorMessage == $"{nameof(RegisterFinishRequestModel.KdfIterations)} not found on RequestModel"); + } + + [Fact] + public void Validate_WhenAuthAndRootHashBothMissing_ReturnsMissingHashErrorOnly() + { + var model = new RegisterFinishRequestModel + { + Email = "user@example.com", + UserAsymmetricKeys = new KeysRequestModel { PublicKey = "pk", EncryptedPrivateKey = "sk" }, + // Both MasterPasswordAuthentication and MasterPasswordHash are missing + MasterPasswordAuthentication = null, + MasterPasswordHash = null, + // Provide valid root KDF to avoid root KDF errors + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = AuthConstants.PBKDF2_ITERATIONS.Default, + EmailVerificationToken = "token" // avoid token error + }; + + var results = Validate(model); + + // Only the new missing hash error should be present + Assert.Single(results); + Assert.Equal($"{nameof(MasterPasswordAuthenticationDataRequestModel.MasterPasswordAuthenticationHash)} and {nameof(RegisterFinishRequestModel.MasterPasswordHash)} not found on request, one needs to be defined.", results[0].ErrorMessage); + Assert.Contains(nameof(MasterPasswordAuthenticationDataRequestModel.MasterPasswordAuthenticationHash), results[0].MemberNames); + Assert.Contains(nameof(RegisterFinishRequestModel.MasterPasswordHash), results[0].MemberNames); + } + + [Fact] + public void Validate_WhenAllFieldsValidWithSubModels_IsValid() + { + var model = new RegisterFinishRequestModel + { + Email = "user@example.com", + UserAsymmetricKeys = new KeysRequestModel { PublicKey = "pk", EncryptedPrivateKey = "sk" }, + MasterPasswordUnlock = new MasterPasswordUnlockDataRequestModel + { + Kdf = new KdfRequestModel { KdfType = KdfType.PBKDF2_SHA256, Iterations = AuthConstants.PBKDF2_ITERATIONS.Default }, + MasterKeyWrappedUserKey = "wrapped", + Salt = "salt" + }, + MasterPasswordAuthentication = new MasterPasswordAuthenticationDataRequestModel + { + Kdf = new KdfRequestModel { KdfType = KdfType.PBKDF2_SHA256, Iterations = AuthConstants.PBKDF2_ITERATIONS.Default }, + MasterPasswordAuthenticationHash = "auth-hash", + Salt = "salt" + }, + EmailVerificationToken = "token" + }; + + var results = Validate(model); + + Assert.Empty(results); + } + + [Fact] + public void Validate_WhenNoValidRegistrationTokenProvided_ReturnsTokenErrorOnly() + { + var model = new RegisterFinishRequestModel + { + Email = "user@example.com", + UserAsymmetricKeys = new KeysRequestModel { PublicKey = "pk", EncryptedPrivateKey = "sk" }, + MasterPasswordUnlock = new MasterPasswordUnlockDataRequestModel + { + Kdf = new KdfRequestModel { KdfType = KdfType.PBKDF2_SHA256, Iterations = AuthConstants.PBKDF2_ITERATIONS.Default }, + MasterKeyWrappedUserKey = "wrapped", + Salt = "salt" + }, + MasterPasswordAuthentication = new MasterPasswordAuthenticationDataRequestModel + { + Kdf = new KdfRequestModel { KdfType = KdfType.PBKDF2_SHA256, Iterations = AuthConstants.PBKDF2_ITERATIONS.Default }, + MasterPasswordAuthenticationHash = "auth-hash", + Salt = "salt" + } + // No token fields set + }; + + var results = Validate(model); + + Assert.Single(results); + Assert.Equal("No valid registration token provided", results[0].ErrorMessage); + } } diff --git a/test/Core.Test/Auth/Services/SsoConfigServiceTests.cs b/test/Core.Test/Auth/Services/SsoConfigServiceTests.cs index 2f4d00a7fa..ca4378e6ec 100644 --- a/test/Core.Test/Auth/Services/SsoConfigServiceTests.cs +++ b/test/Core.Test/Auth/Services/SsoConfigServiceTests.cs @@ -2,9 +2,9 @@ using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyUpdateEvents.Interfaces; -using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.Entities; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models.Data; @@ -13,6 +13,7 @@ using Bit.Core.Auth.Services; using Bit.Core.Exceptions; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; +using Bit.Core.Test.AdminConsole.AutoFixture; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; @@ -163,7 +164,8 @@ public class SsoConfigServiceTests [Theory, BitAutoData] public async Task SaveAsync_KeyConnector_SingleOrgNotEnabled_Throws(SutProvider sutProvider, - Organization organization) + Organization organization, + [Policy(PolicyType.SingleOrg, false)] PolicyStatus policy) { var utcNow = DateTime.UtcNow; @@ -180,6 +182,9 @@ public class SsoConfigServiceTests RevisionDate = utcNow.AddDays(-10), }; + sutProvider.GetDependency().RunAsync( + Arg.Any(), PolicyType.SingleOrg).Returns(policy); + var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.SaveAsync(ssoConfig, organization)); @@ -191,7 +196,9 @@ public class SsoConfigServiceTests [Theory, BitAutoData] public async Task SaveAsync_KeyConnector_SsoPolicyNotEnabled_Throws(SutProvider sutProvider, - Organization organization) + Organization organization, + [Policy(PolicyType.SingleOrg, true)] PolicyStatus singleOrgPolicy, + [Policy(PolicyType.RequireSso, false)] PolicyStatus requireSsoPolicy) { var utcNow = DateTime.UtcNow; @@ -208,11 +215,10 @@ public class SsoConfigServiceTests RevisionDate = utcNow.AddDays(-10), }; - sutProvider.GetDependency().GetByOrganizationIdTypeAsync( - Arg.Any(), PolicyType.SingleOrg).Returns(new Policy - { - Enabled = true - }); + sutProvider.GetDependency().RunAsync( + Arg.Any(), PolicyType.SingleOrg).Returns(singleOrgPolicy); + sutProvider.GetDependency().RunAsync( + Arg.Any(), PolicyType.RequireSso).Returns(requireSsoPolicy); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.SaveAsync(ssoConfig, organization)); @@ -225,7 +231,8 @@ public class SsoConfigServiceTests [Theory, BitAutoData] public async Task SaveAsync_KeyConnector_SsoConfigNotEnabled_Throws(SutProvider sutProvider, - Organization organization) + Organization organization, + [Policy(PolicyType.SingleOrg, true)] PolicyStatus policy) { var utcNow = DateTime.UtcNow; @@ -242,11 +249,8 @@ public class SsoConfigServiceTests RevisionDate = utcNow.AddDays(-10), }; - sutProvider.GetDependency().GetByOrganizationIdTypeAsync( - Arg.Any(), Arg.Any()).Returns(new Policy - { - Enabled = true - }); + sutProvider.GetDependency().RunAsync( + Arg.Any(), Arg.Any()).Returns(policy); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.SaveAsync(ssoConfig, organization)); @@ -259,7 +263,8 @@ public class SsoConfigServiceTests [Theory, BitAutoData] public async Task SaveAsync_KeyConnector_KeyConnectorAbilityNotEnabled_Throws(SutProvider sutProvider, - Organization organization) + Organization organization, + [Policy(PolicyType.SingleOrg, true)] PolicyStatus policy) { var utcNow = DateTime.UtcNow; @@ -277,11 +282,8 @@ public class SsoConfigServiceTests RevisionDate = utcNow.AddDays(-10), }; - sutProvider.GetDependency().GetByOrganizationIdTypeAsync( - Arg.Any(), Arg.Any()).Returns(new Policy - { - Enabled = true, - }); + sutProvider.GetDependency().RunAsync( + Arg.Any(), Arg.Any()).Returns(policy); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.SaveAsync(ssoConfig, organization)); @@ -294,7 +296,8 @@ public class SsoConfigServiceTests [Theory, BitAutoData] public async Task SaveAsync_KeyConnector_Success(SutProvider sutProvider, - Organization organization) + Organization organization, + [Policy(PolicyType.SingleOrg, true)] PolicyStatus policy) { var utcNow = DateTime.UtcNow; @@ -312,11 +315,8 @@ public class SsoConfigServiceTests RevisionDate = utcNow.AddDays(-10), }; - sutProvider.GetDependency().GetByOrganizationIdTypeAsync( - Arg.Any(), Arg.Any()).Returns(new Policy - { - Enabled = true, - }); + sutProvider.GetDependency().RunAsync( + Arg.Any(), Arg.Any()).Returns(policy); await sutProvider.Sut.SaveAsync(ssoConfig, organization); diff --git a/test/Core.Test/Auth/UserFeatures/EmergencyAccess/DeleteEmergencyAccessCommandTests.cs b/test/Core.Test/Auth/UserFeatures/EmergencyAccess/DeleteEmergencyAccessCommandTests.cs new file mode 100644 index 0000000000..057357970b --- /dev/null +++ b/test/Core.Test/Auth/UserFeatures/EmergencyAccess/DeleteEmergencyAccessCommandTests.cs @@ -0,0 +1,253 @@ +using Bit.Core.Auth.Models.Data; +using Bit.Core.Auth.UserFeatures.EmergencyAccess.Commands; +using Bit.Core.Auth.UserFeatures.EmergencyAccess.Mail; +using Bit.Core.Exceptions; +using Bit.Core.Platform.Mail.Mailer; +using Bit.Core.Repositories; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.Auth.UserFeatures.EmergencyAccess; + +[SutProviderCustomize] +public class DeleteEmergencyAccessCommandTests +{ + /// + /// Verifies that attempting to delete a non-existent emergency access record + /// throws a and does not call delete or send email. + /// + [Theory, BitAutoData] + public async Task DeleteByIdGrantorIdAsync_EmergencyAccessNotFound_ThrowsBadRequest( + SutProvider sutProvider, + Guid emergencyAccessId, + Guid grantorId) + { + sutProvider.GetDependency() + .GetDetailsByIdGrantorIdAsync(emergencyAccessId, grantorId) + .Returns((EmergencyAccessDetails)null); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.DeleteByIdGrantorIdAsync(emergencyAccessId, grantorId)); + + Assert.Contains("Emergency Access not valid.", exception.Message); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .DeleteAsync(default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SendEmail(default); + } + + /// + /// Verifies successful deletion of an emergency access record by ID and grantor ID, + /// and ensures that a notification email is sent to the grantor. + /// + [Theory, BitAutoData] + public async Task DeleteByIdGrantorIdAsync_DeletesEmergencyAccessAndSendsEmail( + SutProvider sutProvider, + EmergencyAccessDetails emergencyAccessDetails) + { + sutProvider.GetDependency() + .GetDetailsByIdGrantorIdAsync(emergencyAccessDetails.Id, emergencyAccessDetails.GrantorId) + .Returns(emergencyAccessDetails); + + var result = await sutProvider.Sut.DeleteByIdGrantorIdAsync(emergencyAccessDetails.Id, emergencyAccessDetails.GrantorId); + + await sutProvider.GetDependency() + .Received(1) + .DeleteManyAsync(Arg.Any>()); + await sutProvider.GetDependency() + .Received(1) + .SendEmail(Arg.Any()); + } + + /// + /// Verifies that when a grantor has no emergency access records, the method returns + /// an empty collection and does not attempt to delete or send email. + /// + [Theory, BitAutoData] + public async Task DeleteAllByGrantorIdAsync_NoEmergencyAccessRecords_ReturnsEmptyCollection( + SutProvider sutProvider, + Guid grantorId) + { + sutProvider.GetDependency() + .GetManyDetailsByGrantorIdAsync(grantorId) + .Returns([]); + + var result = await sutProvider.Sut.DeleteAllByGrantorIdAsync(grantorId); + + Assert.NotNull(result); + Assert.Empty(result); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .DeleteManyAsync(default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SendEmail(default); + } + + /// + /// Verifies that when a grantor has multiple emergency access records, all records are deleted, + /// the details are returned, and a single notification email is sent to the grantor. + /// + [Theory, BitAutoData] + public async Task DeleteAllByGrantorIdAsync_MultipleRecords_DeletesAllReturnsDetailsSendsSingleEmail( + SutProvider sutProvider, + EmergencyAccessDetails emergencyAccessDetails1, + EmergencyAccessDetails emergencyAccessDetails2, + EmergencyAccessDetails emergencyAccessDetails3) + { + // Arrange + // link all details to the same grantor + emergencyAccessDetails2.GrantorId = emergencyAccessDetails1.GrantorId; + emergencyAccessDetails2.GrantorEmail = emergencyAccessDetails1.GrantorEmail; + emergencyAccessDetails3.GrantorId = emergencyAccessDetails1.GrantorId; + emergencyAccessDetails3.GrantorEmail = emergencyAccessDetails1.GrantorEmail; + + var allDetails = new List + { + emergencyAccessDetails1, + emergencyAccessDetails2, + emergencyAccessDetails3 + }; + + sutProvider.GetDependency() + .GetManyDetailsByGrantorIdAsync(emergencyAccessDetails1.GrantorId) + .Returns(allDetails); + + // Act + var result = await sutProvider.Sut.DeleteAllByGrantorIdAsync(emergencyAccessDetails1.GrantorId); + + // Assert + Assert.NotNull(result); + Assert.Equal(3, result.Count); + await sutProvider.GetDependency() + .Received(1) + .DeleteManyAsync(Arg.Any>()); + await sutProvider.GetDependency() + .Received(1) + .SendEmail(Arg.Any()); + } + + /// + /// Verifies that when a grantor has a single emergency access record, it is deleted, + /// the details are returned, and a notification email is sent. + /// + [Theory, BitAutoData] + public async Task DeleteAllByGrantorIdAsync_SingleRecord_DeletesAndReturnsDetailsSendsSingleEmail( + SutProvider sutProvider, + EmergencyAccessDetails emergencyAccessDetails, + Guid grantorId) + { + sutProvider.GetDependency() + .GetManyDetailsByGrantorIdAsync(grantorId) + .Returns([emergencyAccessDetails]); + + var result = await sutProvider.Sut.DeleteAllByGrantorIdAsync(grantorId); + + Assert.NotNull(result); + Assert.Single(result); + Assert.Equal(emergencyAccessDetails.Id, result.First().Id); + await sutProvider.GetDependency() + .Received(1) + .DeleteManyAsync(Arg.Any>()); + await sutProvider.GetDependency() + .Received(1) + .SendEmail(Arg.Any()); + } + + /// + /// Verifies that when a grantee has no emergency access records, the method returns + /// an empty collection and does not attempt to delete or send email. + /// + [Theory, BitAutoData] + public async Task DeleteAllByGranteeIdAsync_NoEmergencyAccessRecords_ReturnsEmptyCollection( + SutProvider sutProvider, + Guid granteeId) + { + sutProvider.GetDependency() + .GetManyDetailsByGranteeIdAsync(granteeId) + .Returns([]); + + var result = await sutProvider.Sut.DeleteAllByGranteeIdAsync(granteeId); + + Assert.NotNull(result); + Assert.Empty(result); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .DeleteManyAsync(default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .SendEmail(default); + } + + /// + /// Verifies that when a grantee has a single emergency access record, it is deleted, + /// the details are returned, and a notification email is sent to the grantor. + /// + [Theory, BitAutoData] + public async Task DeleteAllByGranteeIdAsync_SingleRecord_DeletesAndReturnsDetailsSendsSingleEmail( + SutProvider sutProvider, + EmergencyAccessDetails emergencyAccessDetails, + Guid granteeId) + { + sutProvider.GetDependency() + .GetManyDetailsByGranteeIdAsync(granteeId) + .Returns([emergencyAccessDetails]); + + var result = await sutProvider.Sut.DeleteAllByGranteeIdAsync(granteeId); + + Assert.NotNull(result); + Assert.Single(result); + Assert.Equal(emergencyAccessDetails.Id, result.First().Id); + await sutProvider.GetDependency() + .Received(1) + .DeleteManyAsync(Arg.Any>()); + await sutProvider.GetDependency() + .Received(1) + .SendEmail(Arg.Any()); + } + + /// + /// Verifies that when a grantee has multiple emergency access records from different grantors, + /// all records are deleted, the details are returned, and a single notification email is sent + /// to all affected grantors individually. + /// + [Theory, BitAutoData] + public async Task DeleteAllByGranteeIdAsync_MultipleRecords_DeletesAllReturnsDetailsSendsMultipleEmails( + SutProvider sutProvider, + EmergencyAccessDetails emergencyAccessDetails1, + EmergencyAccessDetails emergencyAccessDetails2, + EmergencyAccessDetails emergencyAccessDetails3) + { + // link all details to the same grantee + emergencyAccessDetails2.GranteeId = emergencyAccessDetails1.GranteeId; + emergencyAccessDetails2.GranteeEmail = emergencyAccessDetails1.GranteeEmail; + emergencyAccessDetails3.GranteeId = emergencyAccessDetails1.GranteeId; + emergencyAccessDetails3.GranteeEmail = emergencyAccessDetails1.GranteeEmail; + + var allDetails = new List + { + emergencyAccessDetails1, + emergencyAccessDetails2, + emergencyAccessDetails3 + }; + + sutProvider.GetDependency() + .GetManyDetailsByGranteeIdAsync((Guid)emergencyAccessDetails1.GranteeId) + .Returns(allDetails); + + var result = await sutProvider.Sut.DeleteAllByGranteeIdAsync((Guid)emergencyAccessDetails1.GranteeId); + + Assert.NotNull(result); + Assert.Equal(3, result.Count); + await sutProvider.GetDependency() + .Received(1) + .DeleteManyAsync(Arg.Any>()); + await sutProvider.GetDependency() + .Received(allDetails.Count) + .SendEmail(Arg.Any()); + } +} diff --git a/test/Core.Test/Auth/UserFeatures/EmergencyAccess/EmergencyAccessMailTests.cs b/test/Core.Test/Auth/UserFeatures/EmergencyAccess/EmergencyAccessMailTests.cs new file mode 100644 index 0000000000..60c3644dae --- /dev/null +++ b/test/Core.Test/Auth/UserFeatures/EmergencyAccess/EmergencyAccessMailTests.cs @@ -0,0 +1,153 @@ +using Bit.Core.Auth.UserFeatures.EmergencyAccess.Mail; +using Bit.Core.Models.Mail; +using Bit.Core.Platform.Mail.Delivery; +using Bit.Core.Platform.Mail.Mailer; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.Extensions.Logging; +using NSubstitute; +using Xunit; +using GlobalSettings = Bit.Core.Settings.GlobalSettings; + +namespace Bit.Core.Test.Auth.UserFeatures.EmergencyAccess; + +[SutProviderCustomize] +public class EmergencyAccessMailTests +{ + // Constant values for all Emergency Access emails + private const string _emergencyAccessHelpUrl = "https://bitwarden.com/help/emergency-access/"; + private const string _emergencyAccessMailSubject = "Emergency contacts removed"; + + /// + /// Documents how to construct and send the emergency access removal email. + /// 1. Inject IMailer into their command/service + /// 2. Construct EmergencyAccessRemoveGranteesMail as shown below + /// 3. Call mailer.SendEmail(mail) + /// + [Theory, BitAutoData] + public async Task SendEmergencyAccessRemoveGranteesEmail_SingleGrantee_Success( + string grantorEmail, + string granteeEmail) + { + // Arrange + var logger = Substitute.For>(); + var globalSettings = new GlobalSettings { SelfHosted = false }; + var deliveryService = Substitute.For(); + var mailer = new Mailer( + new HandlebarMailRenderer(logger, globalSettings), + deliveryService); + + var mail = new EmergencyAccessRemoveGranteesMail + { + ToEmails = [grantorEmail], + View = new EmergencyAccessRemoveGranteesMailView + { + RemovedGranteeEmails = [granteeEmail] + } + }; + + MailMessage sentMessage = null; + await deliveryService.SendEmailAsync(Arg.Do(message => + sentMessage = message + )); + + // Act + await mailer.SendEmail(mail); + + // Assert + Assert.NotNull(sentMessage); + Assert.Contains(grantorEmail, sentMessage.ToEmails); + + // Verify the content contains the grantee name + Assert.Contains(granteeEmail, sentMessage.TextContent); + Assert.Contains(granteeEmail, sentMessage.HtmlContent); + } + + /// + /// Documents handling multiple removed grantees in a single email. + /// + [Theory, BitAutoData] + public async Task SendEmergencyAccessRemoveGranteesEmail_MultipleGrantees_RendersAllNames( + string grantorEmail) + { + // Arrange + var logger = Substitute.For>(); + var globalSettings = new GlobalSettings { SelfHosted = false }; + var deliveryService = Substitute.For(); + var mailer = new Mailer( + new HandlebarMailRenderer(logger, globalSettings), + deliveryService); + + var granteeEmails = new[] { "Alice@test.dev", "Bob@test.dev", "Carol@test.dev" }; + + var mail = new EmergencyAccessRemoveGranteesMail + { + ToEmails = [grantorEmail], + View = new EmergencyAccessRemoveGranteesMailView + { + RemovedGranteeEmails = granteeEmails + } + }; + + MailMessage sentMessage = null; + await deliveryService.SendEmailAsync(Arg.Do(message => + sentMessage = message + )); + + // Act + await mailer.SendEmail(mail); + + // Assert - All grantee names should appear in the email + Assert.NotNull(sentMessage); + foreach (var granteeEmail in granteeEmails) + { + Assert.Contains(granteeEmail, sentMessage.TextContent); + Assert.Contains(granteeEmail, sentMessage.HtmlContent); + } + } + + /// + /// Validates the required GranteeNames for the email view model. + /// + [Theory, BitAutoData] + public void EmergencyAccessRemoveGranteesMailView_GranteeNames_AreRequired( + string grantorEmail) + { + // Arrange - Shows the minimum required to construct the email + var mail = new EmergencyAccessRemoveGranteesMail + { + ToEmails = [grantorEmail], // Required: who to send to + View = new EmergencyAccessRemoveGranteesMailView + { + // Required: at least one removed grantee name + RemovedGranteeEmails = ["Example Grantee"] + } + }; + + // Assert + Assert.NotNull(mail); + Assert.NotNull(mail.View); + Assert.NotEmpty(mail.View.RemovedGranteeEmails); + } + + /// + /// Ensure consistency with help pages link and email subject. + /// + /// + /// + [Theory, BitAutoData] + public void EmergencyAccessRemoveGranteesMailView_SubjectAndHelpLink_MatchesExpectedValues(string grantorEmail, string granteeName) + { + // Arrange + var mail = new EmergencyAccessRemoveGranteesMail + { + ToEmails = [grantorEmail], + View = new EmergencyAccessRemoveGranteesMailView { RemovedGranteeEmails = [granteeName] } + }; + + // Assert + Assert.NotNull(mail); + Assert.NotNull(mail.View); + Assert.Equal(_emergencyAccessMailSubject, mail.Subject); + Assert.Equal(_emergencyAccessHelpUrl, EmergencyAccessRemoveGranteesMailView.EmergencyAccessHelpPageUrl); + } +} diff --git a/test/Core.Test/Auth/Services/EmergencyAccessServiceTests.cs b/test/Core.Test/Auth/UserFeatures/EmergencyAccess/EmergencyAccessServiceTests.cs similarity index 92% rename from test/Core.Test/Auth/Services/EmergencyAccessServiceTests.cs rename to test/Core.Test/Auth/UserFeatures/EmergencyAccess/EmergencyAccessServiceTests.cs index 006515aafd..83585e6667 100644 --- a/test/Core.Test/Auth/Services/EmergencyAccessServiceTests.cs +++ b/test/Core.Test/Auth/UserFeatures/EmergencyAccess/EmergencyAccessServiceTests.cs @@ -1,11 +1,10 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Auth.Entities; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models; using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Auth.Models.Data; -using Bit.Core.Auth.Services; +using Bit.Core.Auth.UserFeatures.EmergencyAccess; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -17,7 +16,7 @@ using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -namespace Bit.Core.Test.Auth.Services; +namespace Bit.Core.Test.Auth.UserFeatures.EmergencyAccess; [SutProviderCustomize] public class EmergencyAccessServiceTests @@ -68,13 +67,13 @@ public class EmergencyAccessServiceTests Assert.Equal(EmergencyAccessStatusType.Invited, result.Status); await sutProvider.GetDependency() .Received(1) - .CreateAsync(Arg.Any()); + .CreateAsync(Arg.Any()); sutProvider.GetDependency>() .Received(1) .Protect(Arg.Any()); await sutProvider.GetDependency() .Received(1) - .SendEmergencyAccessInviteEmailAsync(Arg.Any(), Arg.Any(), Arg.Any()); + .SendEmergencyAccessInviteEmailAsync(Arg.Any(), Arg.Any(), Arg.Any()); } [Theory, BitAutoData] @@ -98,7 +97,7 @@ public class EmergencyAccessServiceTests User invitingUser, Guid emergencyAccessId) { - EmergencyAccess emergencyAccess = null; + Core.Auth.Entities.EmergencyAccess emergencyAccess = null; sutProvider.GetDependency() .GetByIdAsync(Arg.Any()) @@ -119,7 +118,7 @@ public class EmergencyAccessServiceTests User invitingUser, Guid emergencyAccessId) { - var emergencyAccess = new EmergencyAccess + var emergencyAccess = new Core.Auth.Entities.EmergencyAccess { Status = EmergencyAccessStatusType.Invited, GrantorId = Guid.NewGuid(), @@ -148,7 +147,7 @@ public class EmergencyAccessServiceTests User invitingUser, Guid emergencyAccessId) { - var emergencyAccess = new EmergencyAccess + var emergencyAccess = new Core.Auth.Entities.EmergencyAccess { Status = statusType, GrantorId = invitingUser.Id, @@ -172,7 +171,7 @@ public class EmergencyAccessServiceTests User invitingUser, Guid emergencyAccessId) { - var emergencyAccess = new EmergencyAccess + var emergencyAccess = new Core.Auth.Entities.EmergencyAccess { Status = EmergencyAccessStatusType.Invited, GrantorId = invitingUser.Id, @@ -194,7 +193,7 @@ public class EmergencyAccessServiceTests public async Task AcceptUserAsync_EmergencyAccessNull_ThrowsBadRequest( SutProvider sutProvider, User acceptingUser, string token) { - EmergencyAccess emergencyAccess = null; + Core.Auth.Entities.EmergencyAccess emergencyAccess = null; sutProvider.GetDependency() .GetByIdAsync(Arg.Any()) .Returns(emergencyAccess); @@ -209,7 +208,7 @@ public class EmergencyAccessServiceTests public async Task AcceptUserAsync_CannotUnprotectToken_ThrowsBadRequest( SutProvider sutProvider, User acceptingUser, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, string token) { sutProvider.GetDependency() @@ -230,8 +229,8 @@ public class EmergencyAccessServiceTests public async Task AcceptUserAsync_TokenDataInvalid_ThrowsBadRequest( SutProvider sutProvider, User acceptingUser, - EmergencyAccess emergencyAccess, - EmergencyAccess wrongEmergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess wrongEmergencyAccess, string token) { sutProvider.GetDependency() @@ -257,7 +256,7 @@ public class EmergencyAccessServiceTests public async Task AcceptUserAsync_AcceptedStatus_ThrowsBadRequest( SutProvider sutProvider, User acceptingUser, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, string token) { emergencyAccess.Status = EmergencyAccessStatusType.Accepted; @@ -284,7 +283,7 @@ public class EmergencyAccessServiceTests public async Task AcceptUserAsync_NotInvitedStatus_ThrowsBadRequest( SutProvider sutProvider, User acceptingUser, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, string token) { emergencyAccess.Status = EmergencyAccessStatusType.Confirmed; @@ -311,7 +310,7 @@ public class EmergencyAccessServiceTests public async Task AcceptUserAsync_EmergencyAccessEmailDoesNotMatch_ThrowsBadRequest( SutProvider sutProvider, User acceptingUser, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, string token) { emergencyAccess.Status = EmergencyAccessStatusType.Invited; @@ -339,7 +338,7 @@ public class EmergencyAccessServiceTests SutProvider sutProvider, User acceptingUser, User invitingUser, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, string token) { emergencyAccess.Status = EmergencyAccessStatusType.Invited; @@ -364,7 +363,7 @@ public class EmergencyAccessServiceTests await sutProvider.GetDependency() .Received(1) - .ReplaceAsync(Arg.Is(x => x.Status == EmergencyAccessStatusType.Accepted)); + .ReplaceAsync(Arg.Is(x => x.Status == EmergencyAccessStatusType.Accepted)); await sutProvider.GetDependency() .Received(1) @@ -375,11 +374,11 @@ public class EmergencyAccessServiceTests public async Task DeleteAsync_EmergencyAccessNull_ThrowsBadRequest( SutProvider sutProvider, User invitingUser, - EmergencyAccess emergencyAccess) + Core.Auth.Entities.EmergencyAccess emergencyAccess) { sutProvider.GetDependency() .GetByIdAsync(Arg.Any()) - .Returns((EmergencyAccess)null); + .Returns((Core.Auth.Entities.EmergencyAccess)null); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.DeleteAsync(emergencyAccess.Id, invitingUser.Id)); @@ -391,7 +390,7 @@ public class EmergencyAccessServiceTests public async Task DeleteAsync_EmergencyAccessGrantorIdNotEqual_ThrowsBadRequest( SutProvider sutProvider, User invitingUser, - EmergencyAccess emergencyAccess) + Core.Auth.Entities.EmergencyAccess emergencyAccess) { emergencyAccess.GrantorId = Guid.NewGuid(); sutProvider.GetDependency() @@ -408,7 +407,7 @@ public class EmergencyAccessServiceTests public async Task DeleteAsync_EmergencyAccessGranteeIdNotEqual_ThrowsBadRequest( SutProvider sutProvider, User invitingUser, - EmergencyAccess emergencyAccess) + Core.Auth.Entities.EmergencyAccess emergencyAccess) { emergencyAccess.GranteeId = Guid.NewGuid(); sutProvider.GetDependency() @@ -425,7 +424,7 @@ public class EmergencyAccessServiceTests public async Task DeleteAsync_EmergencyAccessIsDeleted_Success( SutProvider sutProvider, User user, - EmergencyAccess emergencyAccess) + Core.Auth.Entities.EmergencyAccess emergencyAccess) { emergencyAccess.GranteeId = user.Id; emergencyAccess.GrantorId = user.Id; @@ -443,7 +442,7 @@ public class EmergencyAccessServiceTests [Theory, BitAutoData] public async Task ConfirmUserAsync_EmergencyAccessNull_ThrowsBadRequest( SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, string key, User grantorUser) { @@ -451,7 +450,7 @@ public class EmergencyAccessServiceTests emergencyAccess.Status = EmergencyAccessStatusType.RecoveryInitiated; sutProvider.GetDependency() .GetByIdAsync(Arg.Any()) - .Returns((EmergencyAccess)null); + .Returns((Core.Auth.Entities.EmergencyAccess)null); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.ConfirmUserAsync(emergencyAccess.Id, key, grantorUser.Id)); @@ -463,7 +462,7 @@ public class EmergencyAccessServiceTests [Theory, BitAutoData] public async Task ConfirmUserAsync_EmergencyAccessStatusIsNotAccepted_ThrowsBadRequest( SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, string key, User grantorUser) { @@ -484,7 +483,7 @@ public class EmergencyAccessServiceTests [Theory, BitAutoData] public async Task ConfirmUserAsync_EmergencyAccessGrantorIdNotEqualToConfirmingUserId_ThrowsBadRequest( SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, string key, User grantorUser) { @@ -505,7 +504,7 @@ public class EmergencyAccessServiceTests SutProvider sutProvider, User confirmingUser, string key) { confirmingUser.UsesKeyConnector = true; - var emergencyAccess = new EmergencyAccess + var emergencyAccess = new Core.Auth.Entities.EmergencyAccess { Status = EmergencyAccessStatusType.Accepted, GrantorId = confirmingUser.Id, @@ -530,7 +529,7 @@ public class EmergencyAccessServiceTests [Theory, BitAutoData] public async Task ConfirmUserAsync_ConfirmsAndReplacesEmergencyAccess_Success( SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, string key, User grantorUser, User granteeUser) @@ -553,7 +552,7 @@ public class EmergencyAccessServiceTests await sutProvider.GetDependency() .Received(1) - .ReplaceAsync(Arg.Is(x => x.Status == EmergencyAccessStatusType.Confirmed)); + .ReplaceAsync(Arg.Is(x => x.Status == EmergencyAccessStatusType.Confirmed)); await sutProvider.GetDependency() .Received(1) @@ -564,7 +563,7 @@ public class EmergencyAccessServiceTests public async Task SaveAsync_PremiumCannotUpdate_ThrowsBadRequest( SutProvider sutProvider, User savingUser) { - var emergencyAccess = new EmergencyAccess + var emergencyAccess = new Core.Auth.Entities.EmergencyAccess { Type = EmergencyAccessType.Takeover, GrantorId = savingUser.Id, @@ -586,7 +585,7 @@ public class EmergencyAccessServiceTests SutProvider sutProvider, User savingUser) { savingUser.Premium = true; - var emergencyAccess = new EmergencyAccess + var emergencyAccess = new Core.Auth.Entities.EmergencyAccess { Type = EmergencyAccessType.Takeover, GrantorId = new Guid(), @@ -611,7 +610,7 @@ public class EmergencyAccessServiceTests SutProvider sutProvider, User grantorUser) { grantorUser.UsesKeyConnector = true; - var emergencyAccess = new EmergencyAccess + var emergencyAccess = new Core.Auth.Entities.EmergencyAccess { Type = EmergencyAccessType.Takeover, GrantorId = grantorUser.Id, @@ -633,7 +632,7 @@ public class EmergencyAccessServiceTests SutProvider sutProvider, User grantorUser) { grantorUser.UsesKeyConnector = true; - var emergencyAccess = new EmergencyAccess + var emergencyAccess = new Core.Auth.Entities.EmergencyAccess { Type = EmergencyAccessType.View, GrantorId = grantorUser.Id, @@ -655,7 +654,7 @@ public class EmergencyAccessServiceTests SutProvider sutProvider, User grantorUser) { grantorUser.UsesKeyConnector = false; - var emergencyAccess = new EmergencyAccess + var emergencyAccess = new Core.Auth.Entities.EmergencyAccess { Type = EmergencyAccessType.Takeover, GrantorId = grantorUser.Id, @@ -678,7 +677,7 @@ public class EmergencyAccessServiceTests { sutProvider.GetDependency() .GetByIdAsync(Arg.Any()) - .Returns((EmergencyAccess)null); + .Returns((Core.Auth.Entities.EmergencyAccess)null); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.InitiateAsync(new Guid(), initiatingUser)); @@ -692,7 +691,7 @@ public class EmergencyAccessServiceTests [Theory, BitAutoData] public async Task InitiateAsync_EmergencyAccessGranteeIdNotEqual_ThrowBadRequest( SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User initiatingUser) { emergencyAccess.GranteeId = new Guid(); @@ -712,7 +711,7 @@ public class EmergencyAccessServiceTests [Theory, BitAutoData] public async Task InitiateAsync_EmergencyAccessStatusIsNotConfirmed_ThrowBadRequest( SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User initiatingUser) { emergencyAccess.GranteeId = initiatingUser.Id; @@ -735,7 +734,7 @@ public class EmergencyAccessServiceTests SutProvider sutProvider, User initiatingUser, User grantor) { grantor.UsesKeyConnector = true; - var emergencyAccess = new EmergencyAccess + var emergencyAccess = new Core.Auth.Entities.EmergencyAccess { Status = EmergencyAccessStatusType.Confirmed, GranteeId = initiatingUser.Id, @@ -764,7 +763,7 @@ public class EmergencyAccessServiceTests SutProvider sutProvider, User initiatingUser, User grantor) { grantor.UsesKeyConnector = true; - var emergencyAccess = new EmergencyAccess + var emergencyAccess = new Core.Auth.Entities.EmergencyAccess { Status = EmergencyAccessStatusType.Confirmed, GranteeId = initiatingUser.Id, @@ -783,14 +782,14 @@ public class EmergencyAccessServiceTests await sutProvider.GetDependency() .Received(1) - .ReplaceAsync(Arg.Is(x => x.Status == EmergencyAccessStatusType.RecoveryInitiated)); + .ReplaceAsync(Arg.Is(x => x.Status == EmergencyAccessStatusType.RecoveryInitiated)); } [Theory, BitAutoData] public async Task InitiateAsync_RequestIsCorrect_Success( SutProvider sutProvider, User initiatingUser, User grantor) { - var emergencyAccess = new EmergencyAccess + var emergencyAccess = new Core.Auth.Entities.EmergencyAccess { Status = EmergencyAccessStatusType.Confirmed, GranteeId = initiatingUser.Id, @@ -809,7 +808,7 @@ public class EmergencyAccessServiceTests await sutProvider.GetDependency() .Received(1) - .ReplaceAsync(Arg.Is(x => x.Status == EmergencyAccessStatusType.RecoveryInitiated)); + .ReplaceAsync(Arg.Is(x => x.Status == EmergencyAccessStatusType.RecoveryInitiated)); } [Theory, BitAutoData] @@ -818,7 +817,7 @@ public class EmergencyAccessServiceTests { sutProvider.GetDependency() .GetByIdAsync(Arg.Any()) - .Returns((EmergencyAccess)null); + .Returns((Core.Auth.Entities.EmergencyAccess)null); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.ApproveAsync(new Guid(), null)); @@ -829,7 +828,7 @@ public class EmergencyAccessServiceTests [Theory, BitAutoData] public async Task ApproveAsync_EmergencyAccessGrantorIdNotEquatToApproving_ThrowsBadRequest( SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User grantorUser) { emergencyAccess.Status = EmergencyAccessStatusType.RecoveryInitiated; @@ -851,7 +850,7 @@ public class EmergencyAccessServiceTests public async Task ApproveAsync_EmergencyAccessStatusNotRecoveryInitiated_ThrowsBadRequest( EmergencyAccessStatusType statusType, SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User grantorUser) { emergencyAccess.GrantorId = grantorUser.Id; @@ -869,7 +868,7 @@ public class EmergencyAccessServiceTests [Theory, BitAutoData] public async Task ApproveAsync_Success( SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User grantorUser, User granteeUser) { @@ -885,20 +884,20 @@ public class EmergencyAccessServiceTests await sutProvider.Sut.ApproveAsync(emergencyAccess.Id, grantorUser); await sutProvider.GetDependency() .Received(1) - .ReplaceAsync(Arg.Is(x => x.Status == EmergencyAccessStatusType.RecoveryApproved)); + .ReplaceAsync(Arg.Is(x => x.Status == EmergencyAccessStatusType.RecoveryApproved)); } [Theory, BitAutoData] public async Task RejectAsync_EmergencyAccessIdNull_ThrowsBadRequest( SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User GrantorUser) { emergencyAccess.GrantorId = GrantorUser.Id; emergencyAccess.Status = EmergencyAccessStatusType.Accepted; sutProvider.GetDependency() .GetByIdAsync(Arg.Any()) - .Returns((EmergencyAccess)null); + .Returns((Core.Auth.Entities.EmergencyAccess)null); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.RejectAsync(emergencyAccess.Id, GrantorUser)); @@ -909,7 +908,7 @@ public class EmergencyAccessServiceTests [Theory, BitAutoData] public async Task RejectAsync_EmergencyAccessGrantorIdNotEqualToRequestUser_ThrowsBadRequest( SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User GrantorUser) { emergencyAccess.Status = EmergencyAccessStatusType.Accepted; @@ -930,7 +929,7 @@ public class EmergencyAccessServiceTests public async Task RejectAsync_EmergencyAccessStatusNotValid_ThrowsBadRequest( EmergencyAccessStatusType statusType, SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User GrantorUser) { emergencyAccess.GrantorId = GrantorUser.Id; @@ -951,7 +950,7 @@ public class EmergencyAccessServiceTests public async Task RejectAsync_Success( EmergencyAccessStatusType statusType, SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User GrantorUser, User GranteeUser) { @@ -968,7 +967,7 @@ public class EmergencyAccessServiceTests await sutProvider.GetDependency() .Received(1) - .ReplaceAsync(Arg.Is(x => x.Status == EmergencyAccessStatusType.Confirmed)); + .ReplaceAsync(Arg.Is(x => x.Status == EmergencyAccessStatusType.Confirmed)); } [Theory, BitAutoData] @@ -977,7 +976,7 @@ public class EmergencyAccessServiceTests { sutProvider.GetDependency() .GetByIdAsync(Arg.Any()) - .Returns((EmergencyAccess)null); + .Returns((Core.Auth.Entities.EmergencyAccess)null); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.GetPoliciesAsync(default, default)); @@ -992,7 +991,7 @@ public class EmergencyAccessServiceTests public async Task GetPoliciesAsync_RequestNotValidStatusType_ThrowsBadRequest( EmergencyAccessStatusType statusType, SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User granteeUser) { emergencyAccess.GranteeId = granteeUser.Id; @@ -1010,7 +1009,7 @@ public class EmergencyAccessServiceTests [Theory, BitAutoData] public async Task GetPoliciesAsync_RequestNotValidType_ThrowsBadRequest( SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User granteeUser) { emergencyAccess.GranteeId = granteeUser.Id; @@ -1032,7 +1031,7 @@ public class EmergencyAccessServiceTests public async Task GetPoliciesAsync_OrganizationUserTypeNotOwner_ReturnsNull( OrganizationUserType userType, SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User granteeUser, User grantorUser, OrganizationUser grantorOrganizationUser) @@ -1062,7 +1061,7 @@ public class EmergencyAccessServiceTests [Theory, BitAutoData] public async Task GetPoliciesAsync_OrganizationUserEmpty_ReturnsNull( SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User granteeUser, User grantorUser) { @@ -1090,7 +1089,7 @@ public class EmergencyAccessServiceTests [Theory, BitAutoData] public async Task GetPoliciesAsync_ReturnsNotNull( SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User granteeUser, User grantorUser, OrganizationUser grantorOrganizationUser) @@ -1127,7 +1126,7 @@ public class EmergencyAccessServiceTests { sutProvider.GetDependency() .GetByIdAsync(Arg.Any()) - .Returns((EmergencyAccess)null); + .Returns((Core.Auth.Entities.EmergencyAccess)null); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.TakeoverAsync(default, default)); @@ -1138,7 +1137,7 @@ public class EmergencyAccessServiceTests [Theory, BitAutoData] public async Task TakeoverAsync_RequestNotValid_GranteeNotEqualToRequestingUser_ThrowsBadRequest( SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User granteeUser) { emergencyAccess.Status = EmergencyAccessStatusType.RecoveryApproved; @@ -1161,7 +1160,7 @@ public class EmergencyAccessServiceTests public async Task TakeoverAsync_RequestNotValid_StatusType_ThrowsBadRequest( EmergencyAccessStatusType statusType, SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User granteeUser) { emergencyAccess.GranteeId = granteeUser.Id; @@ -1180,7 +1179,7 @@ public class EmergencyAccessServiceTests [Theory, BitAutoData] public async Task TakeoverAsync_RequestNotValid_TypeIsView_ThrowsBadRequest( SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User granteeUser) { emergencyAccess.GranteeId = granteeUser.Id; @@ -1203,7 +1202,7 @@ public class EmergencyAccessServiceTests User grantor) { grantor.UsesKeyConnector = true; - var emergencyAccess = new EmergencyAccess + var emergencyAccess = new Core.Auth.Entities.EmergencyAccess { GrantorId = grantor.Id, GranteeId = granteeUser.Id, @@ -1232,7 +1231,7 @@ public class EmergencyAccessServiceTests User grantor) { grantor.UsesKeyConnector = false; - var emergencyAccess = new EmergencyAccess + var emergencyAccess = new Core.Auth.Entities.EmergencyAccess { GrantorId = grantor.Id, GranteeId = granteeUser.Id, @@ -1260,7 +1259,7 @@ public class EmergencyAccessServiceTests { sutProvider.GetDependency() .GetByIdAsync(Arg.Any()) - .Returns((EmergencyAccess)null); + .Returns((Core.Auth.Entities.EmergencyAccess)null); var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.PasswordAsync(default, default, default, default)); @@ -1271,7 +1270,7 @@ public class EmergencyAccessServiceTests [Theory, BitAutoData] public async Task PasswordAsync_RequestNotValid_GranteeNotEqualToRequestingUser_ThrowsBadRequest( SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User granteeUser) { emergencyAccess.Status = EmergencyAccessStatusType.RecoveryApproved; @@ -1294,7 +1293,7 @@ public class EmergencyAccessServiceTests public async Task PasswordAsync_RequestNotValid_StatusType_ThrowsBadRequest( EmergencyAccessStatusType statusType, SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User granteeUser) { emergencyAccess.GranteeId = granteeUser.Id; @@ -1313,7 +1312,7 @@ public class EmergencyAccessServiceTests [Theory, BitAutoData] public async Task PasswordAsync_RequestNotValid_TypeIsView_ThrowsBadRequest( SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User granteeUser) { emergencyAccess.GranteeId = granteeUser.Id; @@ -1332,7 +1331,7 @@ public class EmergencyAccessServiceTests [Theory, BitAutoData] public async Task PasswordAsync_NonOrgUser_Success( SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User granteeUser, User grantorUser, string key, @@ -1367,7 +1366,7 @@ public class EmergencyAccessServiceTests public async Task PasswordAsync_OrgUser_NotOrganizationOwner_RemovedFromOrganization_Success( OrganizationUserType userType, SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User granteeUser, User grantorUser, OrganizationUser organizationUser, @@ -1408,7 +1407,7 @@ public class EmergencyAccessServiceTests [Theory, BitAutoData] public async Task PasswordAsync_OrgUser_IsOrganizationOwner_NotRemovedFromOrganization_Success( SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User granteeUser, User grantorUser, OrganizationUser organizationUser, @@ -1459,7 +1458,7 @@ public class EmergencyAccessServiceTests Enabled = true } }); - var emergencyAccess = new EmergencyAccess + var emergencyAccess = new Core.Auth.Entities.EmergencyAccess { GrantorId = grantor.Id, GranteeId = requestingUser.Id, @@ -1484,7 +1483,7 @@ public class EmergencyAccessServiceTests [Theory, BitAutoData] public async Task ViewAsync_EmergencyAccessTypeNotView_ThrowsBadRequest( SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User granteeUser) { emergencyAccess.GranteeId = granteeUser.Id; @@ -1500,7 +1499,7 @@ public class EmergencyAccessServiceTests [Theory, BitAutoData] public async Task GetAttachmentDownloadAsync_EmergencyAccessTypeNotView_ThrowsBadRequest( SutProvider sutProvider, - EmergencyAccess emergencyAccess, + Core.Auth.Entities.EmergencyAccess emergencyAccess, User granteeUser) { emergencyAccess.GranteeId = granteeUser.Id; diff --git a/test/Core.Test/Auth/UserFeatures/Registration/RegisterUserCommandTests.cs b/test/Core.Test/Auth/UserFeatures/Registration/RegisterUserCommandTests.cs index ae669398c5..29193bacbc 100644 --- a/test/Core.Test/Auth/UserFeatures/Registration/RegisterUserCommandTests.cs +++ b/test/Core.Test/Auth/UserFeatures/Registration/RegisterUserCommandTests.cs @@ -1,8 +1,8 @@ using System.Text; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; -using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Auth.Entities; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models; using Bit.Core.Auth.Models.Business.Tokenables; @@ -14,6 +14,7 @@ using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterpri using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; +using Bit.Core.Test.AdminConsole.AutoFixture; using Bit.Core.Tokens; using Bit.Core.Utilities; using Bit.Test.Common.AutoFixture; @@ -23,6 +24,7 @@ using Microsoft.AspNetCore.Identity; using Microsoft.AspNetCore.WebUtilities; using NSubstitute; using Xunit; +using EmergencyAccessEntity = Bit.Core.Auth.Entities.EmergencyAccess; namespace Bit.Core.Test.Auth.UserFeatures.Registration; @@ -241,7 +243,8 @@ public class RegisterUserCommandTests [BitAutoData(true, "sampleInitiationPath")] [BitAutoData(true, "Secrets Manager trial")] public async Task RegisterUserViaOrganizationInviteToken_ComplexHappyPath_Succeeds(bool addUserReferenceData, string initiationPath, - SutProvider sutProvider, User user, string masterPasswordHash, OrganizationUser orgUser, string orgInviteToken, Guid orgUserId, Policy twoFactorPolicy) + SutProvider sutProvider, User user, string masterPasswordHash, OrganizationUser orgUser, string orgInviteToken, Guid orgUserId, + [Policy(PolicyType.TwoFactorAuthentication, true)] PolicyStatus policy) { // Arrange sutProvider.GetDependency() @@ -267,10 +270,9 @@ public class RegisterUserCommandTests .GetByIdAsync(orgUserId) .Returns(orgUser); - twoFactorPolicy.Enabled = true; - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(orgUser.OrganizationId, PolicyType.TwoFactorAuthentication) - .Returns(twoFactorPolicy); + sutProvider.GetDependency() + .RunAsync(orgUser.OrganizationId, PolicyType.TwoFactorAuthentication) + .Returns(policy); sutProvider.GetDependency() .CreateUserAsync(user, masterPasswordHash) @@ -286,9 +288,9 @@ public class RegisterUserCommandTests .Received(1) .GetByIdAsync(orgUserId); - await sutProvider.GetDependency() + await sutProvider.GetDependency() .Received(1) - .GetByOrganizationIdTypeAsync(orgUser.OrganizationId, PolicyType.TwoFactorAuthentication); + .RunAsync(orgUser.OrganizationId, PolicyType.TwoFactorAuthentication); sutProvider.GetDependency() .Received(1) @@ -431,7 +433,8 @@ public class RegisterUserCommandTests [Theory] [BitAutoData] public async Task RegisterUserViaOrganizationInviteToken_BlockedDomainFromDifferentOrg_ThrowsBadRequestException( - SutProvider sutProvider, User user, string masterPasswordHash, OrganizationUser orgUser, string orgInviteToken, Guid orgUserId) + SutProvider sutProvider, User user, string masterPasswordHash, OrganizationUser orgUser, string orgInviteToken, Guid orgUserId, + [Policy(PolicyType.TwoFactorAuthentication, false)] PolicyStatus policy) { // Arrange user.Email = "user@blocked-domain.com"; @@ -463,6 +466,10 @@ public class RegisterUserCommandTests .HasVerifiedDomainWithBlockClaimedDomainPolicyAsync("blocked-domain.com", orgUser.OrganizationId) .Returns(true); + sutProvider.GetDependency() + .RunAsync(Arg.Any(), PolicyType.TwoFactorAuthentication) + .Returns(policy); + // Act & Assert var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.RegisterUserViaOrganizationInviteToken(user, masterPasswordHash, orgInviteToken, orgUserId)); @@ -472,7 +479,8 @@ public class RegisterUserCommandTests [Theory] [BitAutoData] public async Task RegisterUserViaOrganizationInviteToken_BlockedDomainFromSameOrg_Succeeds( - SutProvider sutProvider, User user, string masterPasswordHash, OrganizationUser orgUser, string orgInviteToken, Guid orgUserId) + SutProvider sutProvider, User user, string masterPasswordHash, OrganizationUser orgUser, string orgInviteToken, Guid orgUserId, + [Policy(PolicyType.TwoFactorAuthentication, false)] PolicyStatus policy) { // Arrange user.Email = "user@company-domain.com"; @@ -509,6 +517,10 @@ public class RegisterUserCommandTests .CreateUserAsync(user, masterPasswordHash) .Returns(IdentityResult.Success); + sutProvider.GetDependency() + .RunAsync(Arg.Any(), PolicyType.TwoFactorAuthentication) + .Returns(policy); + // Act var result = await sutProvider.Sut.RegisterUserViaOrganizationInviteToken(user, masterPasswordHash, orgInviteToken, orgUserId); @@ -726,7 +738,7 @@ public class RegisterUserCommandTests [BitAutoData] public async Task RegisterUserViaAcceptEmergencyAccessInviteToken_Succeeds( SutProvider sutProvider, User user, string masterPasswordHash, - EmergencyAccess emergencyAccess, string acceptEmergencyAccessInviteToken, Guid acceptEmergencyAccessId) + EmergencyAccessEntity emergencyAccess, string acceptEmergencyAccessInviteToken, Guid acceptEmergencyAccessId) { // Arrange user.Email = $"test+{Guid.NewGuid()}@example.com"; @@ -767,7 +779,7 @@ public class RegisterUserCommandTests [Theory] [BitAutoData] public async Task RegisterUserViaAcceptEmergencyAccessInviteToken_InvalidToken_ThrowsBadRequestException(SutProvider sutProvider, User user, - string masterPasswordHash, EmergencyAccess emergencyAccess, string acceptEmergencyAccessInviteToken, Guid acceptEmergencyAccessId) + string masterPasswordHash, EmergencyAccessEntity emergencyAccess, string acceptEmergencyAccessInviteToken, Guid acceptEmergencyAccessId) { // Arrange user.Email = $"test+{Guid.NewGuid()}@example.com"; @@ -1112,7 +1124,7 @@ public class RegisterUserCommandTests [BitAutoData] public async Task RegisterUserViaAcceptEmergencyAccessInviteToken_BlockedDomain_ThrowsBadRequestException( SutProvider sutProvider, User user, string masterPasswordHash, - EmergencyAccess emergencyAccess, string acceptEmergencyAccessInviteToken, Guid acceptEmergencyAccessId) + EmergencyAccessEntity emergencyAccess, string acceptEmergencyAccessInviteToken, Guid acceptEmergencyAccessId) { // Arrange user.Email = "user@blocked-domain.com"; @@ -1245,6 +1257,7 @@ public class RegisterUserCommandTests OrganizationUser orgUser, string orgInviteToken, string masterPasswordHash, + [Policy(PolicyType.TwoFactorAuthentication, false)] PolicyStatus policy, SutProvider sutProvider) { // Arrange @@ -1259,9 +1272,9 @@ public class RegisterUserCommandTests .GetByIdAsync(orgUser.Id) .Returns(orgUser); - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(Arg.Any(), PolicyType.TwoFactorAuthentication) - .Returns((Policy)null); + sutProvider.GetDependency() + .RunAsync(Arg.Any(), PolicyType.TwoFactorAuthentication) + .Returns(policy); sutProvider.GetDependency() .GetByIdAsync(orgUser.OrganizationId) @@ -1331,6 +1344,7 @@ public class RegisterUserCommandTests OrganizationUser orgUser, string masterPasswordHash, string orgInviteToken, + [Policy(PolicyType.TwoFactorAuthentication, false)] PolicyStatus policy, SutProvider sutProvider) { // Arrange @@ -1346,9 +1360,9 @@ public class RegisterUserCommandTests .GetByIdAsync(orgUser.Id) .Returns(orgUser); - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(Arg.Any(), PolicyType.TwoFactorAuthentication) - .Returns((Policy)null); + sutProvider.GetDependency() + .RunAsync(Arg.Any(), PolicyType.TwoFactorAuthentication) + .Returns(policy); sutProvider.GetDependency() .GetByIdAsync(orgUser.OrganizationId) diff --git a/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs b/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs index da42127f33..7643510e74 100644 --- a/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs +++ b/test/Core.Test/Billing/Payment/Commands/UpdatePaymentMethodCommandTests.cs @@ -4,6 +4,7 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Payment.Commands; using Bit.Core.Billing.Payment.Models; using Bit.Core.Billing.Services; +using Bit.Core.Services; using Bit.Core.Settings; using Bit.Core.Test.Billing.Extensions; using Braintree; @@ -22,6 +23,7 @@ using static StripeConstants; public class UpdatePaymentMethodCommandTests { private readonly IBraintreeGateway _braintreeGateway = Substitute.For(); + private readonly IBraintreeService _braintreeService = Substitute.For(); private readonly IGlobalSettings _globalSettings = Substitute.For(); private readonly ISetupIntentCache _setupIntentCache = Substitute.For(); private readonly IStripeAdapter _stripeAdapter = Substitute.For(); @@ -32,6 +34,7 @@ public class UpdatePaymentMethodCommandTests { _command = new UpdatePaymentMethodCommand( _braintreeGateway, + _braintreeService, _globalSettings, Substitute.For>(), _setupIntentCache, @@ -375,7 +378,6 @@ public class UpdatePaymentMethodCommandTests _subscriberService.GetCustomer(organization).Returns(customer); - var customerGateway = Substitute.For(); var braintreeCustomer = Substitute.For(); braintreeCustomer.Id.Returns("braintree_customer_id"); var existing = Substitute.For(); @@ -383,7 +385,10 @@ public class UpdatePaymentMethodCommandTests existing.IsDefault.Returns(true); existing.Token.Returns("EXISTING"); braintreeCustomer.PaymentMethods.Returns([existing]); - customerGateway.FindAsync("braintree_customer_id").Returns(braintreeCustomer); + + _braintreeService.GetCustomer(customer).Returns(braintreeCustomer); + + var customerGateway = Substitute.For(); _braintreeGateway.Customer.Returns(customerGateway); var paymentMethodGateway = Substitute.For(); @@ -471,4 +476,75 @@ public class UpdatePaymentMethodCommandTests Arg.Is(options => options.Metadata[MetadataKeys.BraintreeCustomerId] == "braintree_customer_id")); } + + [Fact] + public async Task Run_PayPal_MissingBraintreeCustomer_CreatesNewBraintreeCustomer_ReturnsMaskedPayPalAccount() + { + var organization = new Organization + { + Id = Guid.NewGuid(), + GatewayCustomerId = "cus_123" + }; + + var customer = new Customer + { + Address = new Address + { + Country = "US", + PostalCode = "12345" + }, + Id = "cus_123", + Metadata = new Dictionary + { + [MetadataKeys.BraintreeCustomerId] = "missing_braintree_customer_id" + } + }; + + _subscriberService.GetCustomer(organization).Returns(customer); + + // BraintreeService.GetCustomer returns null when the Braintree customer doesn't exist + _braintreeService.GetCustomer(customer).Returns((Braintree.Customer?)null); + + _globalSettings.BaseServiceUri.Returns(new GlobalSettings.BaseServiceUriSettings(new GlobalSettings()) + { + CloudRegion = "US" + }); + + var customerGateway = Substitute.For(); + var braintreeCustomer = Substitute.For(); + braintreeCustomer.Id.Returns("new_braintree_customer_id"); + var payPalAccount = Substitute.For(); + payPalAccount.Email.Returns("user@gmail.com"); + payPalAccount.IsDefault.Returns(true); + payPalAccount.Token.Returns("NONCE"); + braintreeCustomer.PaymentMethods.Returns([payPalAccount]); + var createResult = Substitute.For>(); + createResult.Target.Returns(braintreeCustomer); + customerGateway.CreateAsync(Arg.Is(options => + options.Id.StartsWith(organization.BraintreeCustomerIdPrefix() + organization.Id.ToString("N").ToLower()) && + options.CustomFields[organization.BraintreeIdField()] == organization.Id.ToString() && + options.CustomFields[organization.BraintreeCloudRegionField()] == "US" && + options.Email == organization.BillingEmailAddress() && + options.PaymentMethodNonce == "TOKEN")).Returns(createResult); + _braintreeGateway.Customer.Returns(customerGateway); + + var result = await _command.Run(organization, + new TokenizedPaymentMethod { Type = TokenizablePaymentMethodType.PayPal, Token = "TOKEN" }, + new BillingAddress { Country = "US", PostalCode = "12345" }); + + Assert.True(result.IsT0); + var maskedPaymentMethod = result.AsT0; + Assert.True(maskedPaymentMethod.IsT2); + var maskedPayPalAccount = maskedPaymentMethod.AsT2; + Assert.Equal("user@gmail.com", maskedPayPalAccount.Email); + + // Verify a new Braintree customer was created (not FindAsync called) + await customerGateway.DidNotReceive().FindAsync(Arg.Any()); + await customerGateway.Received(1).CreateAsync(Arg.Any()); + + // Verify Stripe metadata was updated with the new Braintree customer ID + await _stripeAdapter.Received(1).UpdateCustomerAsync(customer.Id, + Arg.Is(options => + options.Metadata[MetadataKeys.BraintreeCustomerId] == "new_braintree_customer_id")); + } } diff --git a/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs b/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs index 4e4c5199e2..8c65bf68be 100644 --- a/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs +++ b/test/Core.Test/Billing/Payment/Queries/GetPaymentMethodQueryTests.cs @@ -3,9 +3,9 @@ using Bit.Core.Billing.Caches; using Bit.Core.Billing.Constants; using Bit.Core.Billing.Payment.Queries; using Bit.Core.Billing.Services; +using Bit.Core.Services; using Bit.Core.Test.Billing.Extensions; using Braintree; -using Microsoft.Extensions.Logging; using NSubstitute; using NSubstitute.ReturnsExtensions; using Stripe; @@ -19,7 +19,7 @@ using static StripeConstants; public class GetPaymentMethodQueryTests { - private readonly IBraintreeGateway _braintreeGateway = Substitute.For(); + private readonly IBraintreeService _braintreeService = Substitute.For(); private readonly ISetupIntentCache _setupIntentCache = Substitute.For(); private readonly IStripeAdapter _stripeAdapter = Substitute.For(); private readonly ISubscriberService _subscriberService = Substitute.For(); @@ -28,8 +28,7 @@ public class GetPaymentMethodQueryTests public GetPaymentMethodQueryTests() { _query = new GetPaymentMethodQuery( - _braintreeGateway, - Substitute.For>(), + _braintreeService, _setupIntentCache, _stripeAdapter, _subscriberService); @@ -75,6 +74,34 @@ public class GetPaymentMethodQueryTests Assert.Null(maskedPaymentMethod); } + [Fact] + public async Task Run_NoPaymentMethod_BraintreeCustomerNotFound_ReturnsNull() + { + var organization = new Organization + { + Id = Guid.NewGuid() + }; + + var customer = new Customer + { + InvoiceSettings = new CustomerInvoiceSettings(), + Metadata = new Dictionary + { + [MetadataKeys.BraintreeCustomerId] = "non_existent_braintree_customer_id" + } + }; + + _subscriberService.GetCustomer(organization, + Arg.Is(options => + options.HasExpansions("default_source", "invoice_settings.default_payment_method"))).Returns(customer); + + _braintreeService.GetCustomer(customer).ReturnsNull(); + + var maskedPaymentMethod = await _query.Run(organization); + + Assert.Null(maskedPaymentMethod); + } + [Fact] public async Task Run_BankAccount_FromPaymentMethod_ReturnsMaskedBankAccount() { @@ -328,14 +355,12 @@ public class GetPaymentMethodQueryTests Arg.Is(options => options.HasExpansions("default_source", "invoice_settings.default_payment_method"))).Returns(customer); - var customerGateway = Substitute.For(); var braintreeCustomer = Substitute.For(); var payPalAccount = Substitute.For(); payPalAccount.Email.Returns("user@gmail.com"); payPalAccount.IsDefault.Returns(true); braintreeCustomer.PaymentMethods.Returns([payPalAccount]); - customerGateway.FindAsync("braintree_customer_id").Returns(braintreeCustomer); - _braintreeGateway.Customer.Returns(customerGateway); + _braintreeService.GetCustomer(customer).Returns(braintreeCustomer); var maskedPaymentMethod = await _query.Run(organization); diff --git a/test/Core.Test/Billing/Premium/Commands/PreviewPremiumUpgradeProrationCommandTests.cs b/test/Core.Test/Billing/Premium/Commands/PreviewPremiumUpgradeProrationCommandTests.cs new file mode 100644 index 0000000000..c2af07f633 --- /dev/null +++ b/test/Core.Test/Billing/Premium/Commands/PreviewPremiumUpgradeProrationCommandTests.cs @@ -0,0 +1,777 @@ +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Payment.Models; +using Bit.Core.Billing.Premium.Commands; +using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Services; +using Bit.Core.Entities; +using Bit.Core.Test.Billing.Mocks.Plans; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.Extensions.Logging; +using NSubstitute; +using Stripe; +using Xunit; +using PremiumPlan = Bit.Core.Billing.Pricing.Premium.Plan; + +namespace Bit.Core.Test.Billing.Premium.Commands; + +public class PreviewPremiumUpgradeProrationCommandTests +{ + private readonly ILogger _logger = Substitute.For>(); + private readonly IPricingClient _pricingClient = Substitute.For(); + private readonly IStripeAdapter _stripeAdapter = Substitute.For(); + private readonly PreviewPremiumUpgradeProrationCommand _command; + + public PreviewPremiumUpgradeProrationCommandTests() + { + _command = new PreviewPremiumUpgradeProrationCommand( + _logger, + _pricingClient, + _stripeAdapter); + } + + [Theory, BitAutoData] + public async Task Run_UserWithoutPremium_ReturnsBadRequest(User user, BillingAddress billingAddress) + { + // Arrange + user.Premium = false; + + // Act + var result = await _command.Run(user, PlanType.TeamsAnnually, billingAddress); + + // Assert + Assert.True(result.IsT1); + var badRequest = result.AsT1; + Assert.Equal("User does not have an active Premium subscription.", badRequest.Response); + } + + [Theory, BitAutoData] + public async Task Run_UserWithoutGatewaySubscriptionId_ReturnsBadRequest(User user, BillingAddress billingAddress) + { + // Arrange + user.Premium = true; + user.GatewaySubscriptionId = null; + + // Act + var result = await _command.Run(user, PlanType.TeamsAnnually, billingAddress); + + // Assert + Assert.True(result.IsT1); + var badRequest = result.AsT1; + Assert.Equal("User does not have an active Premium subscription.", badRequest.Response); + } + + [Theory, BitAutoData] + public async Task Run_ValidUpgrade_ReturnsProrationAmounts(User user, BillingAddress billingAddress) + { + // Arrange - Setup valid Premium user + user.Premium = true; + user.GatewaySubscriptionId = "sub_123"; + user.GatewayCustomerId = "cus_123"; + + // Setup Premium plans + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable + { + StripePriceId = "premium-annually", + Price = 10m, + Provided = 1 + }, + Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable + { + StripePriceId = "storage-gb-annually", + Price = 4m, + Provided = 1 + } + }; + + var premiumPlans = new List { premiumPlan }; + + // Setup current Stripe subscription + var now = new DateTime(2026, 1, 1, 0, 0, 0, DateTimeKind.Utc); + var currentPeriodEnd = now.AddMonths(6); + var currentSubscription = new Subscription + { + Id = "sub_123", + Customer = new Customer + { + Id = "cus_123", + Discount = null + }, + Items = new StripeList + { + Data = new List + { + new() + { + Id = "si_premium", + Price = new Price { Id = "premium-annually" }, + CurrentPeriodEnd = currentPeriodEnd + } + } + } + }; + + // Setup target organization plan + var targetPlan = new TeamsPlan(isAnnual: true); + + // Setup invoice preview response + var invoice = new Invoice + { + Total = 5000, // $50.00 + TotalTaxes = new List + { + new() { Amount = 500 } // $5.00 + }, + Lines = new StripeList + { + Data = new List + { + new() { Amount = 5000 } // $50.00 for new plan + } + }, + PeriodEnd = now + }; + + // Configure mocks + _pricingClient.ListPremiumPlans().Returns(premiumPlans); + _pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(targetPlan); + _stripeAdapter.GetSubscriptionAsync( + "sub_123", + Arg.Any()) + .Returns(currentSubscription); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()) + .Returns(invoice); + + // Act + var result = await _command.Run(user, PlanType.TeamsAnnually, billingAddress); + + // Assert + Assert.True(result.IsT0); + var proration = result.AsT0; + Assert.Equal(50.00m, proration.NewPlanProratedAmount); + Assert.Equal(0m, proration.Credit); + Assert.Equal(5.00m, proration.Tax); + Assert.Equal(50.00m, proration.Total); + Assert.Equal(6, proration.NewPlanProratedMonths); // 6 months remaining + } + + [Theory, BitAutoData] + public async Task Run_ValidUpgrade_ExtractsProrationCredit(User user, BillingAddress billingAddress) + { + // Arrange + user.Premium = true; + user.GatewaySubscriptionId = "sub_123"; + user.GatewayCustomerId = "cus_123"; + + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable + { + StripePriceId = "premium-annually", + Price = 10m, + Provided = 1 + }, + Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable + { + StripePriceId = "storage-gb-annually", + Price = 4m, + Provided = 1 + } + }; + var premiumPlans = new List { premiumPlan }; + + // Use fixed time to avoid DateTime.UtcNow differences + var now = new DateTime(2026, 1, 1, 0, 0, 0, DateTimeKind.Utc); + var currentPeriodEnd = now.AddDays(45); // 1.5 months ~ 2 months rounded + var currentSubscription = new Subscription + { + Id = "sub_123", + Customer = new Customer { Id = "cus_123", Discount = null }, + Items = new StripeList + { + Data = new List + { + new() { Id = "si_premium", Price = new Price { Id = "premium-annually" }, CurrentPeriodEnd = currentPeriodEnd } + } + } + }; + + var targetPlan = new TeamsPlan(isAnnual: true); + + // Invoice with negative line item (proration credit) + var invoice = new Invoice + { + Total = 4000, // $40.00 + TotalTaxes = new List { new() { Amount = 400 } }, // $4.00 + Lines = new StripeList + { + Data = new List + { + new() { Amount = -1000 }, // -$10.00 credit from unused Premium + new() { Amount = 5000 } // $50.00 for new plan + } + }, + PeriodEnd = now + }; + + _pricingClient.ListPremiumPlans().Returns(premiumPlans); + _pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(targetPlan); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()) + .Returns(currentSubscription); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()) + .Returns(invoice); + + // Act + var result = await _command.Run(user, PlanType.TeamsAnnually, billingAddress); + + // Assert + Assert.True(result.IsT0); + var proration = result.AsT0; + Assert.Equal(50.00m, proration.NewPlanProratedAmount); + Assert.Equal(10.00m, proration.Credit); // Proration credit + Assert.Equal(4.00m, proration.Tax); + Assert.Equal(40.00m, proration.Total); + Assert.Equal(2, proration.NewPlanProratedMonths); // 45 days rounds to 2 months + } + + [Theory, BitAutoData] + public async Task Run_ValidUpgrade_AlwaysUsesOneSeat(User user, BillingAddress billingAddress) + { + // Arrange + user.Premium = true; + user.GatewaySubscriptionId = "sub_123"; + user.GatewayCustomerId = "cus_123"; + + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable + { + StripePriceId = "premium-annually", + Price = 10m, + Provided = 1 + }, + Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable + { + StripePriceId = "storage-gb-annually", + Price = 4m, + Provided = 1 + } + }; + var premiumPlans = new List { premiumPlan }; + + var currentSubscription = new Subscription + { + Id = "sub_123", + Customer = new Customer { Id = "cus_123", Discount = null }, + Items = new StripeList + { + Data = new List + { + new() { Id = "si_premium", Price = new Price { Id = "premium-annually" }, CurrentPeriodEnd = DateTime.UtcNow.AddMonths(1) } + } + } + }; + + var targetPlan = new TeamsPlan(isAnnual: true); + + var invoice = new Invoice + { + Total = 5000, + TotalTaxes = new List { new() { Amount = 500 } }, + Lines = new StripeList { Data = new List { new() { Amount = 5000 } } }, + PeriodEnd = DateTime.UtcNow + }; + + _pricingClient.ListPremiumPlans().Returns(premiumPlans); + _pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(targetPlan); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()) + .Returns(currentSubscription); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()) + .Returns(invoice); + + // Act + await _command.Run(user, PlanType.TeamsAnnually, billingAddress); + + // Assert - Verify that the subscription item quantity is always 1 and has Id + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync( + Arg.Is(options => + options.SubscriptionDetails.Items.Any(item => + item.Id == "si_premium" && + item.Price == targetPlan.PasswordManager.StripeSeatPlanId && + item.Quantity == 1))); + } + + [Theory, BitAutoData] + public async Task Run_ValidUpgrade_DeletesPremiumSubscriptionItems(User user, BillingAddress billingAddress) + { + // Arrange + user.Premium = true; + user.GatewaySubscriptionId = "sub_123"; + user.GatewayCustomerId = "cus_123"; + + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable + { + StripePriceId = "premium-annually", + Price = 10m, + Provided = 1 + }, + Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable + { + StripePriceId = "storage-gb-annually", + Price = 4m, + Provided = 1 + } + }; + var premiumPlans = new List { premiumPlan }; + + var currentSubscription = new Subscription + { + Id = "sub_123", + Customer = new Customer { Id = "cus_123", Discount = null }, + Items = new StripeList + { + Data = new List + { + new() { Id = "si_password_manager", Price = new Price { Id = "premium-annually" }, CurrentPeriodEnd = DateTime.UtcNow.AddMonths(1) }, + new() { Id = "si_storage", Price = new Price { Id = "storage-gb-annually" }, CurrentPeriodEnd = DateTime.UtcNow.AddMonths(1) } + } + } + }; + + var targetPlan = new TeamsPlan(isAnnual: true); + + var invoice = new Invoice + { + Total = 5000, + TotalTaxes = new List { new() { Amount = 500 } }, + Lines = new StripeList { Data = new List { new() { Amount = 5000 } } }, + PeriodEnd = DateTime.UtcNow + }; + + _pricingClient.ListPremiumPlans().Returns(premiumPlans); + _pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(targetPlan); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()) + .Returns(currentSubscription); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()) + .Returns(invoice); + + // Act + await _command.Run(user, PlanType.TeamsAnnually, billingAddress); + + // Assert - Verify password manager item is modified and storage item is deleted + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync( + Arg.Is(options => + // Password manager item should be modified to new plan price, not deleted + options.SubscriptionDetails.Items.Any(item => + item.Id == "si_password_manager" && + item.Price == targetPlan.PasswordManager.StripeSeatPlanId && + item.Deleted != true) && + // Storage item should be deleted + options.SubscriptionDetails.Items.Any(item => + item.Id == "si_storage" && item.Deleted == true))); + } + + [Theory, BitAutoData] + public async Task Run_NonSeatBasedPlan_UsesStripePlanId(User user, BillingAddress billingAddress) + { + // Arrange + user.Premium = true; + user.GatewaySubscriptionId = "sub_123"; + user.GatewayCustomerId = "cus_123"; + + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable + { + StripePriceId = "premium-annually", + Price = 10m, + Provided = 1 + }, + Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable + { + StripePriceId = "storage-gb-annually", + Price = 4m, + Provided = 1 + } + }; + var premiumPlans = new List { premiumPlan }; + + var currentSubscription = new Subscription + { + Id = "sub_123", + Customer = new Customer { Id = "cus_123", Discount = null }, + Items = new StripeList + { + Data = new List + { + new() { Id = "si_premium", Price = new Price { Id = "premium-annually" }, CurrentPeriodEnd = DateTime.UtcNow.AddMonths(1) } + } + } + }; + + var targetPlan = new FamiliesPlan(); // families is non seat based + + var invoice = new Invoice + { + Total = 5000, + TotalTaxes = new List { new() { Amount = 500 } }, + Lines = new StripeList { Data = new List { new() { Amount = 5000 } } }, + PeriodEnd = DateTime.UtcNow + }; + + _pricingClient.ListPremiumPlans().Returns(premiumPlans); + _pricingClient.GetPlanOrThrow(PlanType.FamiliesAnnually).Returns(targetPlan); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()) + .Returns(currentSubscription); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()) + .Returns(invoice); + + // Act + await _command.Run(user, PlanType.FamiliesAnnually, billingAddress); + + // Assert - Verify non-seat-based plan uses StripePlanId with quantity 1 and modifies existing item + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync( + Arg.Is(options => + options.SubscriptionDetails.Items.Any(item => + item.Id == "si_premium" && + item.Price == targetPlan.PasswordManager.StripePlanId && + item.Quantity == 1))); + } + + [Theory, BitAutoData] + public async Task Run_ValidUpgrade_CreatesCorrectInvoicePreviewOptions(User user, BillingAddress billingAddress) + { + // Arrange + user.Premium = true; + user.GatewaySubscriptionId = "sub_123"; + user.GatewayCustomerId = "cus_123"; + billingAddress.Country = "US"; + billingAddress.PostalCode = "12345"; + + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable + { + StripePriceId = "premium-annually", + Price = 10m, + Provided = 1 + }, + Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable + { + StripePriceId = "storage-gb-annually", + Price = 4m, + Provided = 1 + } + }; + var premiumPlans = new List { premiumPlan }; + + var currentSubscription = new Subscription + { + Id = "sub_123", + Customer = new Customer { Id = "cus_123", Discount = null }, + Items = new StripeList + { + Data = new List + { + new() { Id = "si_premium", Price = new Price { Id = "premium-annually" }, CurrentPeriodEnd = DateTime.UtcNow.AddMonths(1) } + } + } + }; + + var targetPlan = new TeamsPlan(isAnnual: true); + + var invoice = new Invoice + { + Total = 5000, + TotalTaxes = new List { new() { Amount = 500 } }, + Lines = new StripeList { Data = new List { new() { Amount = 5000 } } }, + PeriodEnd = DateTime.UtcNow + }; + + _pricingClient.ListPremiumPlans().Returns(premiumPlans); + _pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(targetPlan); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()) + .Returns(currentSubscription); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()) + .Returns(invoice); + + // Act + await _command.Run(user, PlanType.TeamsAnnually, billingAddress); + + // Assert - Verify all invoice preview options are correct + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync( + Arg.Is(options => + options.AutomaticTax.Enabled == true && + options.Customer == "cus_123" && + options.Subscription == "sub_123" && + options.CustomerDetails.Address.Country == "US" && + options.CustomerDetails.Address.PostalCode == "12345" && + options.SubscriptionDetails.ProrationBehavior == "always_invoice")); + } + + [Theory, BitAutoData] + public async Task Run_SeatBasedPlan_UsesStripeSeatPlanId(User user, BillingAddress billingAddress) + { + // Arrange + user.Premium = true; + user.GatewaySubscriptionId = "sub_123"; + user.GatewayCustomerId = "cus_123"; + + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable + { + StripePriceId = "premium-annually", + Price = 10m, + Provided = 1 + }, + Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable + { + StripePriceId = "storage-gb-annually", + Price = 4m, + Provided = 1 + } + }; + var premiumPlans = new List { premiumPlan }; + + var currentSubscription = new Subscription + { + Id = "sub_123", + Customer = new Customer { Id = "cus_123", Discount = null }, + Items = new StripeList + { + Data = new List + { + new() { Id = "si_premium", Price = new Price { Id = "premium-annually" }, CurrentPeriodEnd = DateTime.UtcNow.AddMonths(1) } + } + } + }; + + // Use Teams which is seat-based + var targetPlan = new TeamsPlan(isAnnual: true); + + var invoice = new Invoice + { + Total = 5000, + TotalTaxes = new List { new() { Amount = 500 } }, + Lines = new StripeList { Data = new List { new() { Amount = 5000 } } }, + PeriodEnd = DateTime.UtcNow + }; + + _pricingClient.ListPremiumPlans().Returns(premiumPlans); + _pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(targetPlan); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()) + .Returns(currentSubscription); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()) + .Returns(invoice); + + // Act + await _command.Run(user, PlanType.TeamsAnnually, billingAddress); + + // Assert - Verify seat-based plan uses StripeSeatPlanId with quantity 1 and modifies existing item + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync( + Arg.Is(options => + options.SubscriptionDetails.Items.Any(item => + item.Id == "si_premium" && + item.Price == targetPlan.PasswordManager.StripeSeatPlanId && + item.Quantity == 1))); + } + + [Theory] + [InlineData(0, 1)] // Less than 15 days, minimum 1 month + [InlineData(1, 1)] // 1 day = 1 month minimum + [InlineData(14, 1)] // 14 days = 1 month minimum + [InlineData(15, 1)] // 15 days rounds to 1 month + [InlineData(30, 1)] // 30 days = 1 month + [InlineData(44, 1)] // 44 days rounds to 1 month + [InlineData(45, 2)] // 45 days rounds to 2 months + [InlineData(60, 2)] // 60 days = 2 months + [InlineData(90, 3)] // 90 days = 3 months + [InlineData(180, 6)] // 180 days = 6 months + [InlineData(365, 12)] // 365 days rounds to 12 months + public async Task Run_ValidUpgrade_CalculatesNewPlanProratedMonthsCorrectly(int daysRemaining, int expectedMonths) + { + // Arrange + var user = new User + { + Premium = true, + GatewaySubscriptionId = "sub_123", + GatewayCustomerId = "cus_123" + }; + var billingAddress = new Core.Billing.Payment.Models.BillingAddress + { + Country = "US", + PostalCode = "12345" + }; + + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable + { + StripePriceId = "premium-annually", + Price = 10m, + Provided = 1 + }, + Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable + { + StripePriceId = "storage-gb-annually", + Price = 4m, + Provided = 1 + } + }; + var premiumPlans = new List { premiumPlan }; + + // Use fixed time to avoid DateTime.UtcNow differences + var now = new DateTime(2026, 1, 1, 0, 0, 0, DateTimeKind.Utc); + var currentPeriodEnd = now.AddDays(daysRemaining); + var currentSubscription = new Subscription + { + Id = "sub_123", + Customer = new Customer { Id = "cus_123", Discount = null }, + Items = new StripeList + { + Data = new List + { + new() { Id = "si_premium", Price = new Price { Id = "premium-annually" }, CurrentPeriodEnd = currentPeriodEnd } + } + } + }; + + var targetPlan = new TeamsPlan(isAnnual: true); + + var invoice = new Invoice + { + Total = 5000, + TotalTaxes = new List { new() { Amount = 500 } }, + Lines = new StripeList + { + Data = new List { new() { Amount = 5000 } } + }, + PeriodEnd = now + }; + + _pricingClient.ListPremiumPlans().Returns(premiumPlans); + _pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(targetPlan); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()) + .Returns(currentSubscription); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()) + .Returns(invoice); + + // Act + var result = await _command.Run(user, PlanType.TeamsAnnually, billingAddress); + + // Assert + Assert.True(result.IsT0); + var proration = result.AsT0; + Assert.Equal(expectedMonths, proration.NewPlanProratedMonths); + } + + [Theory, BitAutoData] + public async Task Run_ValidUpgrade_ReturnsNewPlanProratedAmountCorrectly(User user, BillingAddress billingAddress) + { + // Arrange + user.Premium = true; + user.GatewaySubscriptionId = "sub_123"; + user.GatewayCustomerId = "cus_123"; + + var premiumPlan = new PremiumPlan + { + Name = "Premium", + Available = true, + LegacyYear = null, + Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable + { + StripePriceId = "premium-annually", + Price = 10m, + Provided = 1 + }, + Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable + { + StripePriceId = "storage-gb-annually", + Price = 4m, + Provided = 1 + } + }; + var premiumPlans = new List { premiumPlan }; + + var now = new DateTime(2026, 1, 1, 0, 0, 0, DateTimeKind.Utc); + var currentPeriodEnd = now.AddMonths(3); + var currentSubscription = new Subscription + { + Id = "sub_123", + Customer = new Customer { Id = "cus_123", Discount = null }, + Items = new StripeList + { + Data = new List + { + new() { Id = "si_premium", Price = new Price { Id = "premium-annually" }, CurrentPeriodEnd = currentPeriodEnd } + } + } + }; + + var targetPlan = new TeamsPlan(isAnnual: true); + + // Invoice showing new plan cost, credit, and net + var invoice = new Invoice + { + Total = 4500, // $45.00 net after $5 credit + TotalTaxes = new List { new() { Amount = 450 } }, // $4.50 + Lines = new StripeList + { + Data = new List + { + new() { Amount = -500 }, // -$5.00 credit + new() { Amount = 5000 } // $50.00 for new plan + } + }, + PeriodEnd = now + }; + + _pricingClient.ListPremiumPlans().Returns(premiumPlans); + _pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(targetPlan); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()) + .Returns(currentSubscription); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()) + .Returns(invoice); + + // Act + var result = await _command.Run(user, PlanType.TeamsAnnually, billingAddress); + + // Assert + Assert.True(result.IsT0); + var proration = result.AsT0; + + Assert.Equal(50.00m, proration.NewPlanProratedAmount); + Assert.Equal(5.00m, proration.Credit); + Assert.Equal(4.50m, proration.Tax); + Assert.Equal(45.00m, proration.Total); + } +} + diff --git a/test/Core.Test/Billing/Premium/Commands/UpdatePremiumStorageCommandTests.cs b/test/Core.Test/Billing/Premium/Commands/UpdatePremiumStorageCommandTests.cs index 7b9b68c757..cd9b323f9d 100644 --- a/test/Core.Test/Billing/Premium/Commands/UpdatePremiumStorageCommandTests.cs +++ b/test/Core.Test/Billing/Premium/Commands/UpdatePremiumStorageCommandTests.cs @@ -1,6 +1,7 @@ using Bit.Core.Billing.Premium.Commands; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; +using Bit.Core.Billing.Subscriptions.Models; using Bit.Core.Entities; using Bit.Core.Services; using Bit.Test.Common.AutoFixture.Attributes; @@ -8,6 +9,7 @@ using Microsoft.Extensions.Logging; using NSubstitute; using Stripe; using Xunit; +using static Bit.Core.Billing.Constants.StripeConstants; using PremiumPlan = Bit.Core.Billing.Pricing.Premium.Plan; using PremiumPurchasable = Bit.Core.Billing.Pricing.Premium.Purchasable; @@ -15,6 +17,7 @@ namespace Bit.Core.Test.Billing.Premium.Commands; public class UpdatePremiumStorageCommandTests { + private readonly IBraintreeService _braintreeService = Substitute.For(); private readonly IStripeAdapter _stripeAdapter = Substitute.For(); private readonly IUserService _userService = Substitute.For(); private readonly IPricingClient _pricingClient = Substitute.For(); @@ -33,13 +36,14 @@ public class UpdatePremiumStorageCommandTests _pricingClient.ListPremiumPlans().Returns([premiumPlan]); _command = new UpdatePremiumStorageCommand( + _braintreeService, _stripeAdapter, _userService, _pricingClient, Substitute.For>()); } - private Subscription CreateMockSubscription(string subscriptionId, int? storageQuantity = null) + private Subscription CreateMockSubscription(string subscriptionId, int? storageQuantity = null, bool isPayPal = false) { var items = new List { @@ -63,9 +67,17 @@ public class UpdatePremiumStorageCommandTests }); } + var customer = new Customer + { + Id = "cus_123", + Metadata = isPayPal ? new Dictionary { { MetadataKeys.BraintreeCustomerId, "braintree_123" } } : new Dictionary() + }; + return new Subscription { Id = subscriptionId, + CustomerId = "cus_123", + Customer = customer, Items = new StripeList { Data = items @@ -97,7 +109,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123", 4); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, -5); @@ -117,7 +129,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123", 4); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, 100); @@ -154,7 +166,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123", 9); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, 0); @@ -176,7 +188,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123", 4); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, 4); @@ -185,7 +197,7 @@ public class UpdatePremiumStorageCommandTests Assert.True(result.IsT0); // Verify subscription was fetched but NOT updated - await _stripeAdapter.Received(1).GetSubscriptionAsync("sub_123"); + await _stripeAdapter.Received(1).GetSubscriptionAsync("sub_123", Arg.Any()); await _stripeAdapter.DidNotReceive().UpdateSubscriptionAsync(Arg.Any(), Arg.Any()); await _userService.DidNotReceive().SaveUserAsync(Arg.Any()); } @@ -200,7 +212,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123", 4); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, 9); @@ -233,7 +245,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123"); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, 9); @@ -262,7 +274,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123", 9); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, 2); @@ -291,7 +303,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123", 9); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, 0); @@ -320,7 +332,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123", 4); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, 99); @@ -335,4 +347,200 @@ public class UpdatePremiumStorageCommandTests await _userService.Received(1).SaveUserAsync(Arg.Is(u => u.MaxStorageGb == 100)); } + + [Theory, BitAutoData] + public async Task Run_IncreaseStorage_PayPal_Success(User user) + { + // Arrange + user.Premium = true; + user.MaxStorageGb = 5; + user.Storage = 2L * 1024 * 1024 * 1024; + user.GatewaySubscriptionId = "sub_123"; + + var subscription = CreateMockSubscription("sub_123", 4, isPayPal: true); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); + + var draftInvoice = new Invoice { Id = "in_draft" }; + _stripeAdapter.CreateInvoiceAsync(Arg.Any()).Returns(draftInvoice); + + var finalizedInvoice = new Invoice + { + Id = "in_finalized", + Customer = new Customer { Id = "cus_123" } + }; + _stripeAdapter.FinalizeInvoiceAsync("in_draft", Arg.Any()).Returns(finalizedInvoice); + + // Act + var result = await _command.Run(user, 9); + + // Assert + Assert.True(result.IsT0); + + // Verify subscription was updated with CreateProrations + await _stripeAdapter.Received(1).UpdateSubscriptionAsync( + "sub_123", + Arg.Is(opts => + opts.Items.Count == 1 && + opts.Items[0].Id == "si_storage" && + opts.Items[0].Quantity == 9 && + opts.ProrationBehavior == "create_prorations")); + + // Verify draft invoice was created + await _stripeAdapter.Received(1).CreateInvoiceAsync( + Arg.Is(opts => + opts.Customer == "cus_123" && + opts.Subscription == "sub_123" && + opts.AutoAdvance == false && + opts.CollectionMethod == "charge_automatically")); + + // Verify invoice was finalized + await _stripeAdapter.Received(1).FinalizeInvoiceAsync( + "in_draft", + Arg.Is(opts => + opts.AutoAdvance == false && + opts.Expand.Contains("customer"))); + + // Verify Braintree payment was processed + await _braintreeService.Received(1).PayInvoice(Arg.Any(), finalizedInvoice); + + // Verify user was saved + await _userService.Received(1).SaveUserAsync(Arg.Is(u => + u.Id == user.Id && + u.MaxStorageGb == 10)); + } + + [Theory, BitAutoData] + public async Task Run_AddStorageFromZero_PayPal_Success(User user) + { + // Arrange + user.Premium = true; + user.MaxStorageGb = 1; + user.Storage = 500L * 1024 * 1024; + user.GatewaySubscriptionId = "sub_123"; + + var subscription = CreateMockSubscription("sub_123", isPayPal: true); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); + + var draftInvoice = new Invoice { Id = "in_draft" }; + _stripeAdapter.CreateInvoiceAsync(Arg.Any()).Returns(draftInvoice); + + var finalizedInvoice = new Invoice + { + Id = "in_finalized", + Customer = new Customer { Id = "cus_123" } + }; + _stripeAdapter.FinalizeInvoiceAsync("in_draft", Arg.Any()).Returns(finalizedInvoice); + + // Act + var result = await _command.Run(user, 9); + + // Assert + Assert.True(result.IsT0); + + // Verify subscription was updated with new storage item + await _stripeAdapter.Received(1).UpdateSubscriptionAsync( + "sub_123", + Arg.Is(opts => + opts.Items.Count == 1 && + opts.Items[0].Price == "price_storage" && + opts.Items[0].Quantity == 9 && + opts.ProrationBehavior == "create_prorations")); + + // Verify invoice creation and payment flow + await _stripeAdapter.Received(1).CreateInvoiceAsync(Arg.Any()); + await _stripeAdapter.Received(1).FinalizeInvoiceAsync("in_draft", Arg.Any()); + await _braintreeService.Received(1).PayInvoice(Arg.Any(), finalizedInvoice); + + await _userService.Received(1).SaveUserAsync(Arg.Is(u => u.MaxStorageGb == 10)); + } + + [Theory, BitAutoData] + public async Task Run_DecreaseStorage_PayPal_Success(User user) + { + // Arrange + user.Premium = true; + user.MaxStorageGb = 10; + user.Storage = 2L * 1024 * 1024 * 1024; + user.GatewaySubscriptionId = "sub_123"; + + var subscription = CreateMockSubscription("sub_123", 9, isPayPal: true); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); + + var draftInvoice = new Invoice { Id = "in_draft" }; + _stripeAdapter.CreateInvoiceAsync(Arg.Any()).Returns(draftInvoice); + + var finalizedInvoice = new Invoice + { + Id = "in_finalized", + Customer = new Customer { Id = "cus_123" } + }; + _stripeAdapter.FinalizeInvoiceAsync("in_draft", Arg.Any()).Returns(finalizedInvoice); + + // Act + var result = await _command.Run(user, 2); + + // Assert + Assert.True(result.IsT0); + + // Verify subscription was updated + await _stripeAdapter.Received(1).UpdateSubscriptionAsync( + "sub_123", + Arg.Is(opts => + opts.Items.Count == 1 && + opts.Items[0].Id == "si_storage" && + opts.Items[0].Quantity == 2 && + opts.ProrationBehavior == "create_prorations")); + + // Verify invoice creation and payment flow + await _stripeAdapter.Received(1).CreateInvoiceAsync(Arg.Any()); + await _stripeAdapter.Received(1).FinalizeInvoiceAsync("in_draft", Arg.Any()); + await _braintreeService.Received(1).PayInvoice(Arg.Any(), finalizedInvoice); + + await _userService.Received(1).SaveUserAsync(Arg.Is(u => u.MaxStorageGb == 3)); + } + + [Theory, BitAutoData] + public async Task Run_RemoveAllAdditionalStorage_PayPal_Success(User user) + { + // Arrange + user.Premium = true; + user.MaxStorageGb = 10; + user.Storage = 500L * 1024 * 1024; + user.GatewaySubscriptionId = "sub_123"; + + var subscription = CreateMockSubscription("sub_123", 9, isPayPal: true); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); + + var draftInvoice = new Invoice { Id = "in_draft" }; + _stripeAdapter.CreateInvoiceAsync(Arg.Any()).Returns(draftInvoice); + + var finalizedInvoice = new Invoice + { + Id = "in_finalized", + Customer = new Customer { Id = "cus_123" } + }; + _stripeAdapter.FinalizeInvoiceAsync("in_draft", Arg.Any()).Returns(finalizedInvoice); + + // Act + var result = await _command.Run(user, 0); + + // Assert + Assert.True(result.IsT0); + + // Verify subscription item was deleted + await _stripeAdapter.Received(1).UpdateSubscriptionAsync( + "sub_123", + Arg.Is(opts => + opts.Items.Count == 1 && + opts.Items[0].Id == "si_storage" && + opts.Items[0].Deleted == true && + opts.ProrationBehavior == "create_prorations")); + + // Verify invoice creation and payment flow + await _stripeAdapter.Received(1).CreateInvoiceAsync(Arg.Any()); + await _stripeAdapter.Received(1).FinalizeInvoiceAsync("in_draft", Arg.Any()); + await _braintreeService.Received(1).PayInvoice(Arg.Any(), finalizedInvoice); + + await _userService.Received(1).SaveUserAsync(Arg.Is(u => u.MaxStorageGb == 1)); + } } diff --git a/test/Core.Test/Billing/Premium/Commands/UpgradePremiumToOrganizationCommandTests.cs b/test/Core.Test/Billing/Premium/Commands/UpgradePremiumToOrganizationCommandTests.cs index e686d04009..b4fd0e2d21 100644 --- a/test/Core.Test/Billing/Premium/Commands/UpgradePremiumToOrganizationCommandTests.cs +++ b/test/Core.Test/Billing/Premium/Commands/UpgradePremiumToOrganizationCommandTests.cs @@ -37,7 +37,6 @@ public class UpgradePremiumToOrganizationCommandTests NameLocalizationKey = ""; DescriptionLocalizationKey = ""; CanBeUsedByBusiness = true; - TrialPeriodDays = null; HasSelfHost = false; HasPolicies = false; HasGroups = false; @@ -86,10 +85,8 @@ public class UpgradePremiumToOrganizationCommandTests string? stripePlanId = null, string? stripeSeatPlanId = null, string? stripePremiumAccessPlanId = null, - string? stripeStoragePlanId = null) - { - return new TestPlan(planType, stripePlanId, stripeSeatPlanId, stripePremiumAccessPlanId, stripeStoragePlanId); - } + string? stripeStoragePlanId = null) => + new TestPlan(planType, stripePlanId, stripeSeatPlanId, stripePremiumAccessPlanId, stripeStoragePlanId); private static PremiumPlan CreateTestPremiumPlan( string seatPriceId = "premium-annually", @@ -151,6 +148,9 @@ public class UpgradePremiumToOrganizationCommandTests _applicationCacheService); } + private static Core.Billing.Payment.Models.BillingAddress CreateTestBillingAddress() => + new() { Country = "US", PostalCode = "12345" }; + [Theory, BitAutoData] public async Task Run_UserNotPremium_ReturnsBadRequest(User user) { @@ -158,7 +158,7 @@ public class UpgradePremiumToOrganizationCommandTests user.Premium = false; // Act - var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually); + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); // Assert Assert.True(result.IsT1); @@ -174,7 +174,7 @@ public class UpgradePremiumToOrganizationCommandTests user.GatewaySubscriptionId = null; // Act - var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually); + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); // Assert Assert.True(result.IsT1); @@ -190,7 +190,7 @@ public class UpgradePremiumToOrganizationCommandTests user.GatewaySubscriptionId = ""; // Act - var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually); + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); // Assert Assert.True(result.IsT1); @@ -245,7 +245,7 @@ public class UpgradePremiumToOrganizationCommandTests _userService.SaveUserAsync(user).Returns(Task.CompletedTask); // Act - var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually); + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); // Assert Assert.True(result.IsT0); @@ -253,9 +253,8 @@ public class UpgradePremiumToOrganizationCommandTests await _stripeAdapter.Received(1).UpdateSubscriptionAsync( "sub_123", Arg.Is(opts => - opts.Items.Count == 2 && // 1 deleted + 1 seat (no storage) - opts.Items.Any(i => i.Deleted == true) && - opts.Items.Any(i => i.Price == "teams-seat-annually" && i.Quantity == 1))); + opts.Items.Count == 1 && // Only 1 item: modify existing password manager item (no storage to delete) + opts.Items.Any(i => i.Id == "si_premium" && i.Price == "teams-seat-annually" && i.Quantity == 1 && i.Deleted != true))); await _organizationRepository.Received(1).CreateAsync(Arg.Is(o => o.Name == "My Organization" && @@ -320,7 +319,7 @@ public class UpgradePremiumToOrganizationCommandTests _userService.SaveUserAsync(user).Returns(Task.CompletedTask); // Act - var result = await _command.Run(user, "My Families Org", "encrypted-key", PlanType.FamiliesAnnually); + var result = await _command.Run(user, "My Families Org", "encrypted-key", PlanType.FamiliesAnnually, CreateTestBillingAddress()); // Assert Assert.True(result.IsT0); @@ -328,9 +327,8 @@ public class UpgradePremiumToOrganizationCommandTests await _stripeAdapter.Received(1).UpdateSubscriptionAsync( "sub_123", Arg.Is(opts => - opts.Items.Count == 2 && // 1 deleted + 1 plan - opts.Items.Any(i => i.Deleted == true) && - opts.Items.Any(i => i.Price == "families-plan-annually" && i.Quantity == 1))); + opts.Items.Count == 1 && // Only 1 item: modify existing password manager item (no storage to delete) + opts.Items.Any(i => i.Id == "si_premium" && i.Price == "families-plan-annually" && i.Quantity == 1 && i.Deleted != true))); await _organizationRepository.Received(1).CreateAsync(Arg.Is(o => o.Name == "My Families Org")); @@ -383,7 +381,7 @@ public class UpgradePremiumToOrganizationCommandTests _userService.SaveUserAsync(user).Returns(Task.CompletedTask); // Act - var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually); + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); // Assert Assert.True(result.IsT0); @@ -392,11 +390,6 @@ public class UpgradePremiumToOrganizationCommandTests "sub_123", Arg.Is(opts => opts.Metadata.ContainsKey(StripeConstants.MetadataKeys.OrganizationId) && - opts.Metadata.ContainsKey(StripeConstants.MetadataKeys.PreviousPremiumPriceId) && - opts.Metadata[StripeConstants.MetadataKeys.PreviousPremiumPriceId] == "premium-annually" && - opts.Metadata.ContainsKey(StripeConstants.MetadataKeys.PreviousPeriodEndDate) && - opts.Metadata.ContainsKey(StripeConstants.MetadataKeys.PreviousAdditionalStorage) && - opts.Metadata[StripeConstants.MetadataKeys.PreviousAdditionalStorage] == "0" && opts.Metadata.ContainsKey(StripeConstants.MetadataKeys.UserId) && opts.Metadata[StripeConstants.MetadataKeys.UserId] == string.Empty)); // Removes userId to unlink from User } @@ -453,19 +446,18 @@ public class UpgradePremiumToOrganizationCommandTests _userService.SaveUserAsync(user).Returns(Task.CompletedTask); // Act - var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually); + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); // Assert Assert.True(result.IsT0); - // Verify that BOTH legacy items (password manager + storage) are deleted by ID + // Verify that legacy password manager item is modified and legacy storage is deleted await _stripeAdapter.Received(1).UpdateSubscriptionAsync( "sub_123", Arg.Is(opts => - opts.Items.Count == 3 && // 2 deleted (legacy PM + legacy storage) + 1 new seat - opts.Items.Count(i => i.Deleted == true && i.Id == "si_premium_legacy") == 1 && // Legacy PM deleted - opts.Items.Count(i => i.Deleted == true && i.Id == "si_storage_legacy") == 1 && // Legacy storage deleted - opts.Items.Any(i => i.Price == "teams-seat-annually" && i.Quantity == 1))); + opts.Items.Count == 2 && // 1 modified (legacy PM to new price) + 1 deleted (legacy storage) + opts.Items.Count(i => i.Id == "si_premium_legacy" && i.Price == "teams-seat-annually" && i.Quantity == 1 && i.Deleted != true) == 1 && // Legacy PM modified + opts.Items.Count(i => i.Deleted == true && i.Id == "si_storage_legacy") == 1)); // Legacy storage deleted } [Theory, BitAutoData] @@ -520,20 +512,19 @@ public class UpgradePremiumToOrganizationCommandTests _userService.SaveUserAsync(user).Returns(Task.CompletedTask); // Act - var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually); + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); // Assert Assert.True(result.IsT0); - // Verify that ONLY the premium password manager item is deleted (not other products) - // Note: We delete the specific premium item by ID, so other products are untouched + // Verify that ONLY the premium password manager item is modified (not other products) + // Note: We modify the specific premium item by ID, so other products are untouched await _stripeAdapter.Received(1).UpdateSubscriptionAsync( "sub_123", Arg.Is(opts => - opts.Items.Count == 2 && // 1 deleted (premium password manager) + 1 new seat - opts.Items.Count(i => i.Deleted == true && i.Id == "si_premium") == 1 && // Premium item deleted by ID - opts.Items.Count(i => i.Id == "si_other_product") == 0 && // Other product NOT in update (untouched) - opts.Items.Any(i => i.Price == "teams-seat-annually" && i.Quantity == 1))); + opts.Items.Count == 1 && // Only modify premium password manager item + opts.Items.Count(i => i.Id == "si_premium" && i.Price == "teams-seat-annually" && i.Quantity == 1 && i.Deleted != true) == 1 && // Premium item modified + opts.Items.Count(i => i.Id == "si_other_product") == 0)); // Other product NOT in update (untouched) } [Theory, BitAutoData] @@ -589,7 +580,7 @@ public class UpgradePremiumToOrganizationCommandTests _userService.SaveUserAsync(user).Returns(Task.CompletedTask); // Act - var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually); + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); // Assert Assert.True(result.IsT0); @@ -598,10 +589,8 @@ public class UpgradePremiumToOrganizationCommandTests await _stripeAdapter.Received(1).UpdateSubscriptionAsync( "sub_123", Arg.Is(opts => - opts.Metadata.ContainsKey(StripeConstants.MetadataKeys.PreviousAdditionalStorage) && - opts.Metadata[StripeConstants.MetadataKeys.PreviousAdditionalStorage] == "5" && - opts.Items.Count == 3 && // 2 deleted (premium + storage) + 1 new seat - opts.Items.Count(i => i.Deleted == true) == 2)); + opts.Items.Count == 2 && // 1 modified (premium to new price) + 1 deleted (storage) + opts.Items.Count(i => i.Deleted == true) == 1)); } [Theory, BitAutoData] @@ -636,11 +625,385 @@ public class UpgradePremiumToOrganizationCommandTests _pricingClient.ListPremiumPlans().Returns(mockPremiumPlans); // Act - var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually); + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); // Assert Assert.True(result.IsT1); var badRequest = result.AsT1; - Assert.Equal("Premium subscription item not found.", badRequest.Response); + Assert.Equal("Premium subscription password manager item not found.", badRequest.Response); + } + + [Theory, BitAutoData] + public async Task Run_UpdatesCustomerBillingAddress(User user) + { + // Arrange + user.Premium = true; + user.GatewaySubscriptionId = "sub_123"; + user.GatewayCustomerId = "cus_123"; + + var mockSubscription = new Subscription + { + Id = "sub_123", + Items = new StripeList + { + Data = new List + { + new SubscriptionItem + { + Id = "si_premium", + Price = new Price { Id = "premium-annually" } + } + } + }, + Metadata = new Dictionary() + }; + + var mockPremiumPlans = CreateTestPremiumPlansList(); + var mockPlan = CreateTestPlan(PlanType.TeamsAnnually, stripeSeatPlanId: "teams-seat-annually"); + + _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(mockSubscription); + _pricingClient.ListPremiumPlans().Returns(mockPremiumPlans); + _pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan); + _stripeAdapter.UpdateSubscriptionAsync(Arg.Any(), Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(new Customer())); + _organizationRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _organizationApiKeyRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _organizationUserRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any()).Returns(Task.CompletedTask); + _userService.SaveUserAsync(user).Returns(Task.CompletedTask); + + var billingAddress = new Core.Billing.Payment.Models.BillingAddress { Country = "US", PostalCode = "12345" }; + + // Act + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, billingAddress); + + // Assert + Assert.True(result.IsT0); + + await _stripeAdapter.Received(1).UpdateCustomerAsync( + "cus_123", + Arg.Is(opts => + opts.Address.Country == "US" && + opts.Address.PostalCode == "12345")); + } + + [Theory, BitAutoData] + public async Task Run_EnablesAutomaticTaxOnSubscription(User user) + { + // Arrange + user.Premium = true; + user.GatewaySubscriptionId = "sub_123"; + user.GatewayCustomerId = "cus_123"; + + var mockSubscription = new Subscription + { + Id = "sub_123", + Items = new StripeList + { + Data = new List + { + new SubscriptionItem + { + Id = "si_premium", + Price = new Price { Id = "premium-annually" } + } + } + }, + Metadata = new Dictionary() + }; + + var mockPremiumPlans = CreateTestPremiumPlansList(); + var mockPlan = CreateTestPlan(PlanType.TeamsAnnually, stripeSeatPlanId: "teams-seat-annually"); + + _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(mockSubscription); + _pricingClient.ListPremiumPlans().Returns(mockPremiumPlans); + _pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan); + _stripeAdapter.UpdateSubscriptionAsync(Arg.Any(), Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(new Customer())); + _organizationRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _organizationApiKeyRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _organizationUserRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any()).Returns(Task.CompletedTask); + _userService.SaveUserAsync(user).Returns(Task.CompletedTask); + + // Act + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); + + // Assert + Assert.True(result.IsT0); + + await _stripeAdapter.Received(1).UpdateSubscriptionAsync( + "sub_123", + Arg.Is(opts => + opts.AutomaticTax != null && + opts.AutomaticTax.Enabled == true)); + } + + [Theory, BitAutoData] + public async Task Run_UsesAlwaysInvoiceProrationBehavior(User user) + { + // Arrange + user.Premium = true; + user.GatewaySubscriptionId = "sub_123"; + user.GatewayCustomerId = "cus_123"; + + var mockSubscription = new Subscription + { + Id = "sub_123", + Items = new StripeList + { + Data = new List + { + new SubscriptionItem + { + Id = "si_premium", + Price = new Price { Id = "premium-annually" } + } + } + }, + Metadata = new Dictionary() + }; + + var mockPremiumPlans = CreateTestPremiumPlansList(); + var mockPlan = CreateTestPlan(PlanType.TeamsAnnually, stripeSeatPlanId: "teams-seat-annually"); + + _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(mockSubscription); + _pricingClient.ListPremiumPlans().Returns(mockPremiumPlans); + _pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan); + _stripeAdapter.UpdateSubscriptionAsync(Arg.Any(), Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(new Customer())); + _organizationRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _organizationApiKeyRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _organizationUserRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any()).Returns(Task.CompletedTask); + _userService.SaveUserAsync(user).Returns(Task.CompletedTask); + + // Act + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); + + // Assert + Assert.True(result.IsT0); + + await _stripeAdapter.Received(1).UpdateSubscriptionAsync( + "sub_123", + Arg.Is(opts => + opts.ProrationBehavior == "always_invoice")); + } + + [Theory, BitAutoData] + public async Task Run_ModifiesExistingSubscriptionItem_NotDeleteAndRecreate(User user) + { + // Arrange + user.Premium = true; + user.GatewaySubscriptionId = "sub_123"; + user.GatewayCustomerId = "cus_123"; + + var mockSubscription = new Subscription + { + Id = "sub_123", + Items = new StripeList + { + Data = new List + { + new SubscriptionItem + { + Id = "si_premium", + Price = new Price { Id = "premium-annually" } + } + } + }, + Metadata = new Dictionary() + }; + + var mockPremiumPlans = CreateTestPremiumPlansList(); + var mockPlan = CreateTestPlan(PlanType.TeamsAnnually, stripeSeatPlanId: "teams-seat-annually"); + + _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(mockSubscription); + _pricingClient.ListPremiumPlans().Returns(mockPremiumPlans); + _pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan); + _stripeAdapter.UpdateSubscriptionAsync(Arg.Any(), Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(new Customer())); + _organizationRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _organizationApiKeyRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _organizationUserRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any()).Returns(Task.CompletedTask); + _userService.SaveUserAsync(user).Returns(Task.CompletedTask); + + // Act + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); + + // Assert + Assert.True(result.IsT0); + + // Verify that the subscription item was modified, not deleted + await _stripeAdapter.Received(1).UpdateSubscriptionAsync( + "sub_123", + Arg.Is(opts => + // Should have an item with the original ID being modified + opts.Items.Any(item => + item.Id == "si_premium" && + item.Price == "teams-seat-annually" && + item.Quantity == 1 && + item.Deleted != true))); + } + + [Theory, BitAutoData] + public async Task Run_CreatesOrganizationWithCorrectSettings(User user) + { + // Arrange + user.Premium = true; + user.GatewaySubscriptionId = "sub_123"; + user.GatewayCustomerId = "cus_123"; + + var mockSubscription = new Subscription + { + Id = "sub_123", + Items = new StripeList + { + Data = new List + { + new SubscriptionItem + { + Id = "si_premium", + Price = new Price { Id = "premium-annually" } + } + } + }, + Metadata = new Dictionary() + }; + + var mockPremiumPlans = CreateTestPremiumPlansList(); + var mockPlan = CreateTestPlan(PlanType.TeamsAnnually, stripeSeatPlanId: "teams-seat-annually"); + + _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(mockSubscription); + _pricingClient.ListPremiumPlans().Returns(mockPremiumPlans); + _pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan); + _stripeAdapter.UpdateSubscriptionAsync(Arg.Any(), Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(new Customer())); + _organizationRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _organizationApiKeyRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _organizationUserRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any()).Returns(Task.CompletedTask); + _userService.SaveUserAsync(user).Returns(Task.CompletedTask); + + // Act + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); + + // Assert + Assert.True(result.IsT0); + + await _organizationRepository.Received(1).CreateAsync( + Arg.Is(org => + org.Name == "My Organization" && + org.BillingEmail == user.Email && + org.PlanType == PlanType.TeamsAnnually && + org.Seats == 1 && + org.Gateway == GatewayType.Stripe && + org.GatewayCustomerId == "cus_123" && + org.GatewaySubscriptionId == "sub_123" && + org.Enabled == true)); + } + + [Theory, BitAutoData] + public async Task Run_CreatesOrganizationApiKeyWithCorrectType(User user) + { + // Arrange + user.Premium = true; + user.GatewaySubscriptionId = "sub_123"; + user.GatewayCustomerId = "cus_123"; + + var mockSubscription = new Subscription + { + Id = "sub_123", + Items = new StripeList + { + Data = new List + { + new SubscriptionItem + { + Id = "si_premium", + Price = new Price { Id = "premium-annually" } + } + } + }, + Metadata = new Dictionary() + }; + + var mockPremiumPlans = CreateTestPremiumPlansList(); + var mockPlan = CreateTestPlan(PlanType.TeamsAnnually, stripeSeatPlanId: "teams-seat-annually"); + + _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(mockSubscription); + _pricingClient.ListPremiumPlans().Returns(mockPremiumPlans); + _pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan); + _stripeAdapter.UpdateSubscriptionAsync(Arg.Any(), Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(new Customer())); + _organizationRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _organizationApiKeyRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _organizationUserRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any()).Returns(Task.CompletedTask); + _userService.SaveUserAsync(user).Returns(Task.CompletedTask); + + // Act + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); + + // Assert + Assert.True(result.IsT0); + + await _organizationApiKeyRepository.Received(1).CreateAsync( + Arg.Is(apiKey => + apiKey.Type == OrganizationApiKeyType.Default && + !string.IsNullOrEmpty(apiKey.ApiKey))); + } + + [Theory, BitAutoData] + public async Task Run_CreatesOrganizationUserAsOwnerWithAllPermissions(User user) + { + // Arrange + user.Premium = true; + user.GatewaySubscriptionId = "sub_123"; + user.GatewayCustomerId = "cus_123"; + + var mockSubscription = new Subscription + { + Id = "sub_123", + Items = new StripeList + { + Data = new List + { + new SubscriptionItem + { + Id = "si_premium", + Price = new Price { Id = "premium-annually" } + } + } + }, + Metadata = new Dictionary() + }; + + var mockPremiumPlans = CreateTestPremiumPlansList(); + var mockPlan = CreateTestPlan(PlanType.TeamsAnnually, stripeSeatPlanId: "teams-seat-annually"); + + _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(mockSubscription); + _pricingClient.ListPremiumPlans().Returns(mockPremiumPlans); + _pricingClient.GetPlanOrThrow(PlanType.TeamsAnnually).Returns(mockPlan); + _stripeAdapter.UpdateSubscriptionAsync(Arg.Any(), Arg.Any()).Returns(mockSubscription); + _stripeAdapter.UpdateCustomerAsync(Arg.Any(), Arg.Any()).Returns(Task.FromResult(new Customer())); + _organizationRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _organizationApiKeyRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _organizationUserRepository.CreateAsync(Arg.Any()).Returns(callInfo => Task.FromResult(callInfo.Arg())); + _applicationCacheService.UpsertOrganizationAbilityAsync(Arg.Any()).Returns(Task.CompletedTask); + _userService.SaveUserAsync(user).Returns(Task.CompletedTask); + + // Act + var result = await _command.Run(user, "My Organization", "encrypted-key", PlanType.TeamsAnnually, CreateTestBillingAddress()); + + // Assert + Assert.True(result.IsT0); + + await _organizationUserRepository.Received(1).CreateAsync( + Arg.Is(orgUser => + orgUser.UserId == user.Id && + orgUser.Type == OrganizationUserType.Owner && + orgUser.Status == OrganizationUserStatusType.Confirmed)); } } diff --git a/test/Core.Test/Billing/Subscriptions/Entities/SubscriptionDiscountTests.cs b/test/Core.Test/Billing/Subscriptions/Entities/SubscriptionDiscountTests.cs new file mode 100644 index 0000000000..8da3b5ea1d --- /dev/null +++ b/test/Core.Test/Billing/Subscriptions/Entities/SubscriptionDiscountTests.cs @@ -0,0 +1,109 @@ +using System.Text.Json; +using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Subscriptions.Entities; +using Xunit; + +namespace Bit.Core.Test.Billing.Subscriptions.Entities; + +public class SubscriptionDiscountTests +{ + [Fact] + public void StripeProductIds_CanSerializeToJson() + { + // Arrange + var discount = new SubscriptionDiscount + { + StripeCouponId = "test-coupon", + StripeProductIds = new List { "prod_123", "prod_456" }, + Duration = "once", + StartDate = DateTime.UtcNow, + EndDate = DateTime.UtcNow.AddDays(30), + AudienceType = DiscountAudienceType.UserHasNoPreviousSubscriptions + }; + + // Act + var json = JsonSerializer.Serialize(discount.StripeProductIds); + + // Assert + Assert.Equal("[\"prod_123\",\"prod_456\"]", json); + } + + [Fact] + public void StripeProductIds_CanDeserializeFromJson() + { + // Arrange + var json = "[\"prod_123\",\"prod_456\"]"; + + // Act + var result = JsonSerializer.Deserialize>(json); + + // Assert + Assert.NotNull(result); + Assert.Equal(2, result.Count); + Assert.Contains("prod_123", result); + Assert.Contains("prod_456", result); + } + + [Fact] + public void StripeProductIds_HandlesNull() + { + // Arrange + var discount = new SubscriptionDiscount + { + StripeCouponId = "test-coupon", + StripeProductIds = null, + Duration = "once", + StartDate = DateTime.UtcNow, + EndDate = DateTime.UtcNow.AddDays(30), + AudienceType = DiscountAudienceType.UserHasNoPreviousSubscriptions + }; + + // Act + var json = JsonSerializer.Serialize(discount.StripeProductIds); + + // Assert + Assert.Equal("null", json); + } + + [Fact] + public void StripeProductIds_HandlesEmptyCollection() + { + // Arrange + var discount = new SubscriptionDiscount + { + StripeCouponId = "test-coupon", + StripeProductIds = new List(), + Duration = "once", + StartDate = DateTime.UtcNow, + EndDate = DateTime.UtcNow.AddDays(30), + AudienceType = DiscountAudienceType.UserHasNoPreviousSubscriptions + }; + + // Act + var json = JsonSerializer.Serialize(discount.StripeProductIds); + + // Assert + Assert.Equal("[]", json); + } + + [Fact] + public void Validate_RejectsEndDateBeforeStartDate() + { + // Arrange + var discount = new SubscriptionDiscount + { + StripeCouponId = "test-coupon", + Duration = "once", + StartDate = DateTime.UtcNow.AddDays(30), + EndDate = DateTime.UtcNow, // EndDate before StartDate + AudienceType = DiscountAudienceType.UserHasNoPreviousSubscriptions + }; + + // Act + var validationResults = discount.Validate(new System.ComponentModel.DataAnnotations.ValidationContext(discount)).ToList(); + + // Assert + Assert.Single(validationResults); + Assert.Contains("EndDate", validationResults[0].MemberNames); + } +} diff --git a/test/Core.Test/Billing/Subscriptions/Queries/GetBitwardenSubscriptionQueryTests.cs b/test/Core.Test/Billing/Subscriptions/Queries/GetBitwardenSubscriptionQueryTests.cs index a12a0e4cb0..e0a11741b3 100644 --- a/test/Core.Test/Billing/Subscriptions/Queries/GetBitwardenSubscriptionQueryTests.cs +++ b/test/Core.Test/Billing/Subscriptions/Queries/GetBitwardenSubscriptionQueryTests.cs @@ -461,6 +461,77 @@ public class GetBitwardenSubscriptionQueryTests Assert.Equal(PlanCadenceType.Annually, result.Cart.Cadence); } + [Fact] + public async Task Run_UserOnLegacyPricing_ReturnsCostFromPricingService() + { + var user = CreateUser(); + var subscription = CreateSubscription(SubscriptionStatus.Active, legacyPricing: true); + var premiumPlans = CreatePremiumPlans(); + var availablePlan = premiumPlans.First(p => p.Available); + + _stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any()) + .Returns(subscription); + _pricingClient.ListPremiumPlans().Returns(premiumPlans); + + var previewInvoice = CreateInvoicePreview(totalTax: 150); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()) + .Returns(previewInvoice); + + var result = await _query.Run(user); + + Assert.NotNull(result); + Assert.Equal(availablePlan.Seat.Price, result.Cart.PasswordManager.Seats.Cost); + Assert.Equal(1.50m, result.Cart.EstimatedTax); + } + + [Fact] + public async Task Run_UserOnLegacyPricing_CallsPreviewInvoiceWithRebuiltSubscription() + { + var user = CreateUser(); + var subscription = CreateSubscription(SubscriptionStatus.Active, legacyPricing: true); + var premiumPlans = CreatePremiumPlans(); + var availablePlan = premiumPlans.First(p => p.Available); + + _stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any()) + .Returns(subscription); + _pricingClient.ListPremiumPlans().Returns(premiumPlans); + + var previewInvoice = CreateInvoicePreview(); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()) + .Returns(previewInvoice); + + await _query.Run(user); + + await _stripeAdapter.Received(1).CreateInvoicePreviewAsync( + Arg.Is(opts => + opts.Subscription == null && + opts.AutomaticTax != null && + opts.AutomaticTax.Enabled == true && + opts.SubscriptionDetails != null && + opts.SubscriptionDetails.Items.Any(i => + i.Price == availablePlan.Seat.StripePriceId && + i.Quantity == 1))); + } + + [Fact] + public async Task Run_UserOnCurrentPricing_ReturnsCostFromSubscriptionItem() + { + var user = CreateUser(); + var subscription = CreateSubscription(SubscriptionStatus.Active, legacyPricing: false); + var premiumPlans = CreatePremiumPlans(); + + _stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, Arg.Any()) + .Returns(subscription); + _pricingClient.ListPremiumPlans().Returns(premiumPlans); + _stripeAdapter.CreateInvoicePreviewAsync(Arg.Any()) + .Returns(CreateInvoicePreview()); + + var result = await _query.Run(user); + + Assert.NotNull(result); + Assert.Equal(19.80m, result.Cart.PasswordManager.Seats.Cost); + } + #region Helper Methods private static User CreateUser() @@ -477,11 +548,14 @@ public class GetBitwardenSubscriptionQueryTests private static Subscription CreateSubscription( string status, bool includeStorage = false, + bool legacyPricing = false, DateTime? cancelAt = null, DateTime? canceledAt = null, string collectionMethod = "charge_automatically") { var currentPeriodEnd = DateTime.UtcNow.AddMonths(1); + var seatPriceId = legacyPricing ? "price_legacy_premium_seat" : "price_premium_seat"; + var seatUnitAmount = legacyPricing ? 1000 : 1980; var items = new List { new() @@ -489,8 +563,8 @@ public class GetBitwardenSubscriptionQueryTests Id = "si_premium_seat", Price = new Price { - Id = "price_premium_seat", - UnitAmountDecimal = 1000, + Id = seatPriceId, + UnitAmountDecimal = seatUnitAmount, Product = new Product { Id = "prod_premium_seat" } }, Quantity = 1, @@ -521,6 +595,7 @@ public class GetBitwardenSubscriptionQueryTests Id = "sub_test123", Status = status, Created = DateTime.UtcNow.AddMonths(-1), + AutomaticTax = new SubscriptionAutomaticTax { Enabled = true }, Customer = new Customer { Id = "cus_test123", @@ -548,6 +623,24 @@ public class GetBitwardenSubscriptionQueryTests Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable { StripePriceId = "price_premium_seat", + Price = 19.80m, + Provided = 1 + }, + Storage = new Bit.Core.Billing.Pricing.Premium.Purchasable + { + StripePriceId = "price_storage", + Price = 4.0m, + Provided = 1 + } + }, + new() + { + Name = "Premium", + Available = false, + LegacyYear = 2024, + Seat = new Bit.Core.Billing.Pricing.Premium.Purchasable + { + StripePriceId = "price_legacy_premium_seat", Price = 10.0m, Provided = 1 }, diff --git a/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpgradeOrganizationPlanCommandTests.cs b/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpgradeOrganizationPlanCommandTests.cs index 223047ee07..b4f1fe2d98 100644 --- a/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpgradeOrganizationPlanCommandTests.cs +++ b/test/Core.Test/OrganizationFeatures/OrganizationSubscriptionUpdate/UpgradeOrganizationPlanCommandTests.cs @@ -1,4 +1,7 @@ -using Bit.Core.Billing.Enums; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.Billing.Enums; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; using Bit.Core.Exceptions; @@ -9,6 +12,7 @@ using Bit.Core.OrganizationFeatures.OrganizationSubscriptions; using Bit.Core.Repositories; using Bit.Core.SecretsManager.Repositories; using Bit.Core.Services; +using Bit.Core.Test.AdminConsole.AutoFixture; using Bit.Core.Test.AutoFixture.OrganizationFixtures; using Bit.Core.Test.Billing.Mocks; using Bit.Test.Common.AutoFixture; @@ -72,8 +76,12 @@ public class UpgradeOrganizationPlanCommandTests [Theory] [FreeOrganizationUpgradeCustomize, BitAutoData] public async Task UpgradePlan_Passes(Organization organization, OrganizationUpgrade upgrade, + [Policy(PolicyType.ResetPassword, false)] PolicyStatus policy, SutProvider sutProvider) { + sutProvider.GetDependency() + .RunAsync(Arg.Any(), Arg.Any()) + .Returns(policy); sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType)); upgrade.AdditionalSmSeats = 10; @@ -100,6 +108,7 @@ public class UpgradeOrganizationPlanCommandTests PlanType planType, Organization organization, OrganizationUpgrade organizationUpgrade, + [Policy(PolicyType.ResetPassword, false)] PolicyStatus policy, SutProvider sutProvider) { sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); @@ -116,6 +125,9 @@ public class UpgradeOrganizationPlanCommandTests organizationUpgrade.Plan = planType; sutProvider.GetDependency().GetPlanOrThrow(organizationUpgrade.Plan).Returns(MockPlans.Get(organizationUpgrade.Plan)); + sutProvider.GetDependency() + .RunAsync(Arg.Any(), Arg.Any()) + .Returns(policy); sutProvider.GetDependency() .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id).Returns(new OrganizationSeatCounts { @@ -141,15 +153,20 @@ public class UpgradeOrganizationPlanCommandTests [BitAutoData(PlanType.TeamsAnnually)] [BitAutoData(PlanType.TeamsStarter)] public async Task UpgradePlan_SM_Passes(PlanType planType, Organization organization, OrganizationUpgrade upgrade, + [Policy(PolicyType.ResetPassword, false)] PolicyStatus policy, SutProvider sutProvider) { - sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType)); upgrade.Plan = planType; sutProvider.GetDependency().GetPlanOrThrow(upgrade.Plan).Returns(MockPlans.Get(upgrade.Plan)); var plan = MockPlans.Get(upgrade.Plan); + sutProvider.GetDependency() + .RunAsync(Arg.Any(), Arg.Any()) + .Returns(policy); + + sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType)); sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); upgrade.AdditionalSeats = 15; @@ -180,6 +197,7 @@ public class UpgradeOrganizationPlanCommandTests [BitAutoData(PlanType.TeamsAnnually)] [BitAutoData(PlanType.TeamsStarter)] public async Task UpgradePlan_SM_NotEnoughSmSeats_Throws(PlanType planType, Organization organization, OrganizationUpgrade upgrade, + [Policy(PolicyType.ResetPassword, false)] PolicyStatus policy, SutProvider sutProvider) { upgrade.Plan = planType; @@ -191,6 +209,10 @@ public class UpgradeOrganizationPlanCommandTests organization.SmSeats = 2; sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType)); + sutProvider.GetDependency() + .RunAsync(Arg.Any(), Arg.Any()) + .Returns(policy); + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); sutProvider.GetDependency() .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id).Returns(new OrganizationSeatCounts @@ -214,7 +236,9 @@ public class UpgradeOrganizationPlanCommandTests [BitAutoData(PlanType.TeamsAnnually, 51)] [BitAutoData(PlanType.TeamsStarter, 51)] public async Task UpgradePlan_SM_NotEnoughServiceAccounts_Throws(PlanType planType, int currentServiceAccounts, - Organization organization, OrganizationUpgrade upgrade, SutProvider sutProvider) + Organization organization, OrganizationUpgrade upgrade, + [Policy(PolicyType.ResetPassword, false)] PolicyStatus policy, + SutProvider sutProvider) { upgrade.Plan = planType; upgrade.AdditionalSeats = 15; @@ -226,6 +250,10 @@ public class UpgradeOrganizationPlanCommandTests organization.SmServiceAccounts = currentServiceAccounts; sutProvider.GetDependency().GetPlanOrThrow(organization.PlanType).Returns(MockPlans.Get(organization.PlanType)); + sutProvider.GetDependency() + .RunAsync(Arg.Any(), Arg.Any()) + .Returns(policy); + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); sutProvider.GetDependency() .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id).Returns(new OrganizationSeatCounts @@ -251,6 +279,7 @@ public class UpgradeOrganizationPlanCommandTests OrganizationUpgrade upgrade, string newPublicKey, string newPrivateKey, + [Policy(PolicyType.ResetPassword, false)] PolicyStatus policy, SutProvider sutProvider) { organization.PublicKey = null; @@ -262,6 +291,9 @@ public class UpgradeOrganizationPlanCommandTests publicKey: newPublicKey); upgrade.AdditionalSeats = 10; + sutProvider.GetDependency() + .RunAsync(Arg.Any(), Arg.Any()) + .Returns(policy); sutProvider.GetDependency() .GetByIdAsync(organization.Id) .Returns(organization); @@ -291,6 +323,7 @@ public class UpgradeOrganizationPlanCommandTests public async Task UpgradePlan_WhenOrganizationAlreadyHasPublicAndPrivateKeys_DoesNotOverwriteWithNull( Organization organization, OrganizationUpgrade upgrade, + [Policy(PolicyType.ResetPassword, false)] PolicyStatus policy, SutProvider sutProvider) { // Arrange @@ -304,6 +337,9 @@ public class UpgradeOrganizationPlanCommandTests upgrade.Keys = null; upgrade.AdditionalSeats = 10; + sutProvider.GetDependency() + .RunAsync(Arg.Any(), Arg.Any()) + .Returns(policy); sutProvider.GetDependency() .GetByIdAsync(organization.Id) .Returns(organization); @@ -333,6 +369,7 @@ public class UpgradeOrganizationPlanCommandTests public async Task UpgradePlan_WhenOrganizationAlreadyHasPublicAndPrivateKeys_DoesNotBackfillWithNewKeys( Organization organization, OrganizationUpgrade upgrade, + [Policy(PolicyType.ResetPassword, false)] PolicyStatus policy, SutProvider sutProvider) { // Arrange @@ -343,6 +380,9 @@ public class UpgradeOrganizationPlanCommandTests organization.PublicKey = existingPublicKey; organization.PrivateKey = existingPrivateKey; + sutProvider.GetDependency() + .RunAsync(Arg.Any(), Arg.Any()) + .Returns(policy); upgrade.Plan = PlanType.TeamsAnnually; upgrade.Keys = new PublicKeyEncryptionKeyPairData( diff --git a/test/Core.Test/OrganizationFeatures/Policies/PolicyQueryTests.cs b/test/Core.Test/OrganizationFeatures/Policies/PolicyQueryTests.cs new file mode 100644 index 0000000000..ac33a5e5a6 --- /dev/null +++ b/test/Core.Test/OrganizationFeatures/Policies/PolicyQueryTests.cs @@ -0,0 +1,55 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Implementations; +using Bit.Core.AdminConsole.Repositories; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using NSubstitute.ReturnsExtensions; +using Xunit; + +namespace Bit.Core.Test.OrganizationFeatures.Policies; + +[SutProviderCustomize] +public class PolicyQueryTests +{ + [Theory, BitAutoData] + public async Task RunAsync_WithExistingPolicy_ReturnsPolicy(SutProvider sutProvider, + Policy policy) + { + // Arrange + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(policy.OrganizationId, policy.Type) + .Returns(policy); + + // Act + var policyData = await sutProvider.Sut.RunAsync(policy.OrganizationId, policy.Type); + + // Assert + Assert.Equal(policy.Data, policyData.Data); + Assert.Equal(policy.Type, policyData.Type); + Assert.Equal(policy.Enabled, policyData.Enabled); + Assert.Equal(policy.OrganizationId, policyData.OrganizationId); + } + + [Theory, BitAutoData] + public async Task RunAsync_WithNonExistentPolicy_ReturnsDefaultDisabledPolicy( + SutProvider sutProvider, + Guid organizationId, + PolicyType policyType) + { + // Arrange + sutProvider.GetDependency() + .GetByOrganizationIdTypeAsync(organizationId, policyType) + .ReturnsNull(); + + // Act + var policyData = await sutProvider.Sut.RunAsync(organizationId, policyType); + + // Assert + Assert.Equal(organizationId, policyData.OrganizationId); + Assert.Equal(policyType, policyData.Type); + Assert.False(policyData.Enabled); + Assert.Null(policyData.Data); + } +} diff --git a/test/Core.Test/Platform/Mail/DomainClaimedEmailRenderTest.cs b/test/Core.Test/Platform/Mail/DomainClaimedEmailRenderTest.cs new file mode 100644 index 0000000000..57e5f43d8a --- /dev/null +++ b/test/Core.Test/Platform/Mail/DomainClaimedEmailRenderTest.cs @@ -0,0 +1,195 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Models.Data.Organizations; +using Bit.Core.Platform.Mail.Delivery; +using Bit.Core.Platform.Mail.Enqueuing; +using Bit.Core.Services.Mail; +using Bit.Core.Settings; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Logging; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.Platform.Mail; + +public class DomainClaimedEmailRenderTest +{ + [Fact] + public async Task RenderDomainClaimedEmail_ToVerifyTemplate() + { + var globalSettings = new GlobalSettings + { + Mail = new GlobalSettings.MailSettings + { + ReplyToEmail = "no-reply@bitwarden.com", + Smtp = new GlobalSettings.MailSettings.SmtpSettings + { + Host = "localhost", + Port = 1025, + StartTls = false, + Ssl = false + } + }, + SiteName = "Bitwarden" + }; + + var mailDeliveryService = Substitute.For(); + var mailEnqueuingService = new BlockingMailEnqueuingService(); + var distributedCache = Substitute.For(); + var logger = Substitute.For>(); + + var mailService = new HandlebarsMailService( + globalSettings, + mailDeliveryService, + mailEnqueuingService, + distributedCache, + logger + ); + + var organization = new Organization + { + Id = Guid.NewGuid(), + Name = "Acme Corporation" + }; + + var testEmails = new List + { + "alice@acme.com", + "bob@acme.com", + "charlie@acme.com" + }; + + var emailList = new ClaimedUserDomainClaimedEmails( + testEmails, + organization, + "acme.com" + ); + + await mailService.SendClaimedDomainUserEmailAsync(emailList); + + await mailDeliveryService.Received(3).SendEmailAsync(Arg.Any()); + + var calls = mailDeliveryService.ReceivedCalls() + .Where(call => call.GetMethodInfo().Name == "SendEmailAsync") + .ToList(); + + Assert.Equal(3, calls.Count); + + foreach (var call in calls) + { + var mailMessage = call.GetArguments()[0] as Bit.Core.Models.Mail.MailMessage; + Assert.NotNull(mailMessage); + + var recipient = mailMessage.ToEmails.First(); + + Assert.Contains("@acme.com", mailMessage.HtmlContent); + Assert.Contains(recipient, mailMessage.HtmlContent); + Assert.DoesNotContain("[at]", mailMessage.HtmlContent); + Assert.DoesNotContain("[dot]", mailMessage.HtmlContent); + } + } + + [Fact(Skip = "For local development - requires MailCatcher at localhost:10250")] + public async Task SendDomainClaimedEmail_ToMailCatcher() + { + var globalSettings = new GlobalSettings + { + Mail = new GlobalSettings.MailSettings + { + ReplyToEmail = "no-reply@bitwarden.com", + Smtp = new GlobalSettings.MailSettings.SmtpSettings + { + Host = "localhost", + Port = 10250, + StartTls = false, + Ssl = false + } + }, + SiteName = "Bitwarden" + }; + + var mailDeliveryLogger = Substitute.For>(); + var mailDeliveryService = new MailKitSmtpMailDeliveryService(globalSettings, mailDeliveryLogger); + var mailEnqueuingService = new BlockingMailEnqueuingService(); + var distributedCache = Substitute.For(); + var logger = Substitute.For>(); + + var mailService = new HandlebarsMailService( + globalSettings, + mailDeliveryService, + mailEnqueuingService, + distributedCache, + logger + ); + + var organization = new Organization + { + Id = Guid.NewGuid(), + Name = "Acme Corporation" + }; + + var testEmails = new List + { + "alice@acme.com", + "bob@acme.com" + }; + + var emailList = new ClaimedUserDomainClaimedEmails( + testEmails, + organization, + "acme.com" + ); + + await mailService.SendClaimedDomainUserEmailAsync(emailList); + } + + [Fact(Skip = "This test sends actual emails and is for manual template verification only")] + public async Task RenderDomainClaimedEmail_WithSpecialCharacters() + { + var globalSettings = new GlobalSettings + { + Mail = new GlobalSettings.MailSettings + { + Smtp = new GlobalSettings.MailSettings.SmtpSettings + { + Host = "localhost", + Port = 1025, + StartTls = false, + Ssl = false + } + }, + SiteName = "Bitwarden" + }; + + var mailDeliveryService = Substitute.For(); + var mailEnqueuingService = new BlockingMailEnqueuingService(); + var distributedCache = Substitute.For(); + var logger = Substitute.For>(); + + var mailService = new HandlebarsMailService( + globalSettings, + mailDeliveryService, + mailEnqueuingService, + distributedCache, + logger + ); + + var organization = new Organization + { + Id = Guid.NewGuid(), + Name = "Test Corp & Co." + }; + + var testEmails = new List + { + "test.user+tag@example.com" + }; + + var emailList = new ClaimedUserDomainClaimedEmails( + testEmails, + organization, + "example.com" + ); + + await mailService.SendClaimedDomainUserEmailAsync(emailList); + } +} diff --git a/test/Core.Test/Services/HandlebarsMailServiceTests.cs b/test/Core.Test/Services/HandlebarsMailServiceTests.cs index b98c4580f5..4ff0868c7e 100644 --- a/test/Core.Test/Services/HandlebarsMailServiceTests.cs +++ b/test/Core.Test/Services/HandlebarsMailServiceTests.cs @@ -254,21 +254,6 @@ public class HandlebarsMailServiceTests } } - [Fact] - public async Task SendSendEmailOtpEmailAsync_SendsEmail() - { - // Arrange - var email = "test@example.com"; - var token = "aToken"; - var subject = string.Format("Your Bitwarden Send verification code is {0}", token); - - // Act - await _sut.SendSendEmailOtpEmailAsync(email, token, subject); - - // Assert - await _mailDeliveryService.Received(1).SendEmailAsync(Arg.Any()); - } - [Fact] public async Task SendIndividualUserWelcomeEmailAsync_SendsCorrectEmail() { diff --git a/test/Core.Test/Services/Implementations/BraintreeServiceTests.cs b/test/Core.Test/Services/Implementations/BraintreeServiceTests.cs new file mode 100644 index 0000000000..ba62c79021 --- /dev/null +++ b/test/Core.Test/Services/Implementations/BraintreeServiceTests.cs @@ -0,0 +1,118 @@ +using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Services; +using Bit.Core.Services; +using Bit.Core.Settings; +using Braintree; +using Braintree.Exceptions; +using Microsoft.Extensions.Logging; +using NSubstitute; +using Xunit; + +using BraintreeService = Bit.Core.Services.Implementations.BraintreeService; +using Customer = Stripe.Customer; + +namespace Bit.Core.Test.Services.Implementations; + +public class BraintreeServiceTests +{ + private readonly ICustomerGateway _customerGateway; + private readonly BraintreeService _sut; + + public BraintreeServiceTests() + { + var braintreeGateway = Substitute.For(); + _customerGateway = Substitute.For(); + braintreeGateway.Customer.Returns(_customerGateway); + + var globalSettings = Substitute.For(); + var logger = Substitute.For>(); + var mailService = Substitute.For(); + var stripeAdapter = Substitute.For(); + + _sut = new BraintreeService( + braintreeGateway, + globalSettings, + logger, + mailService, + stripeAdapter); + } + + #region GetCustomer + + [Fact] + public async Task GetCustomer_NoBraintreeCustomerIdInMetadata_ReturnsNull() + { + // Arrange + var stripeCustomer = new Customer + { + Id = "cus_123", + Metadata = new Dictionary() + }; + + // Act + var result = await _sut.GetCustomer(stripeCustomer); + + // Assert + Assert.Null(result); + await _customerGateway.DidNotReceiveWithAnyArgs().FindAsync(Arg.Any()); + } + + [Fact] + public async Task GetCustomer_BraintreeCustomerFound_ReturnsCustomer() + { + // Arrange + const string braintreeCustomerId = "bt_customer_123"; + + var stripeCustomer = new Customer + { + Id = "cus_123", + Metadata = new Dictionary + { + [StripeConstants.MetadataKeys.BraintreeCustomerId] = braintreeCustomerId + } + }; + + var braintreeCustomer = Substitute.For(); + + _customerGateway + .FindAsync(braintreeCustomerId) + .Returns(braintreeCustomer); + + // Act + var result = await _sut.GetCustomer(stripeCustomer); + + // Assert + Assert.NotNull(result); + Assert.Same(braintreeCustomer, result); + await _customerGateway.Received(1).FindAsync(braintreeCustomerId); + } + + [Fact] + public async Task GetCustomer_BraintreeCustomerNotFound_LogsWarningAndReturnsNull() + { + // Arrange + const string braintreeCustomerId = "bt_non_existent_customer"; + + var stripeCustomer = new Customer + { + Id = "cus_123", + Metadata = new Dictionary + { + [StripeConstants.MetadataKeys.BraintreeCustomerId] = braintreeCustomerId + } + }; + + _customerGateway + .FindAsync(braintreeCustomerId) + .Returns(_ => throw new NotFoundException()); + + // Act + var result = await _sut.GetCustomer(stripeCustomer); + + // Assert + Assert.Null(result); + await _customerGateway.Received(1).FindAsync(braintreeCustomerId); + } + + #endregion +} diff --git a/test/Core.Test/Tools/ImportFeatures/ImportCiphersAsyncCommandTests.cs b/test/Core.Test/Tools/ImportFeatures/ImportCiphersAsyncCommandTests.cs index b92477e73d..f6b1bd200a 100644 --- a/test/Core.Test/Tools/ImportFeatures/ImportCiphersAsyncCommandTests.cs +++ b/test/Core.Test/Tools/ImportFeatures/ImportCiphersAsyncCommandTests.cs @@ -135,6 +135,43 @@ public class ImportCiphersAsyncCommandTests Assert.Equal("You cannot import items into your personal vault because you are a member of an organization which forbids it.", exception.Message); } + [Theory, BitAutoData] + public async Task ImportIntoIndividualVaultAsync_FavoriteCiphers_PersistsFavoriteInfo( + Guid importingUserId, + List ciphers, + SutProvider sutProvider + ) + { + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.PolicyRequirements) + .Returns(true); + + sutProvider.GetDependency() + .GetAsync(importingUserId) + .Returns(new OrganizationDataOwnershipPolicyRequirement( + OrganizationDataOwnershipState.Disabled, + [])); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(importingUserId) + .Returns(new List()); + + var folders = new List(); + var folderRelationships = new List>(); + + ciphers.ForEach(c => + { + c.UserId = importingUserId; + c.Favorite = true; + }); + + await sutProvider.Sut.ImportIntoIndividualVaultAsync(folders, ciphers, folderRelationships, importingUserId); + + await sutProvider.GetDependency() + .Received(1) + .CreateAsync(importingUserId, Arg.Is>(ciphers => ciphers.All(c => c.Favorites == $"{{\"{importingUserId.ToString().ToUpperInvariant()}\":true}}")), Arg.Any>()); + } + [Theory, BitAutoData] public async Task ImportIntoOrganizationalVaultAsync_Success( Organization organization, @@ -289,4 +326,101 @@ public class ImportCiphersAsyncCommandTests await sutProvider.GetDependency().Received(1).PushSyncVaultAsync(importingUserId); } + + [Theory, BitAutoData] + public async Task ImportIntoIndividualVaultAsync_WithArchivedCiphers_PreservesArchiveStatus( + Guid importingUserId, + List ciphers, + SutProvider sutProvider) + { + var archivedDate = DateTime.UtcNow.AddDays(-1); + ciphers[0].UserId = importingUserId; + ciphers[0].ArchivedDate = archivedDate; + + sutProvider.GetDependency() + .AnyPoliciesApplicableToUserAsync(importingUserId, PolicyType.OrganizationDataOwnership) + .Returns(false); + + sutProvider.GetDependency() + .GetManyByUserIdAsync(importingUserId) + .Returns(new List()); + + var folders = new List(); + var folderRelationships = new List>(); + + await sutProvider.Sut.ImportIntoIndividualVaultAsync(folders, ciphers, folderRelationships, importingUserId); + + await sutProvider.GetDependency() + .Received(1) + .CreateAsync(importingUserId, + Arg.Is>(c => + c[0].Archives != null && + c[0].Archives.Contains(importingUserId.ToString().ToUpperInvariant()) && + c[0].Archives.Contains(archivedDate.ToString("yyyy-MM-ddTHH:mm:ss.fffffffZ"))), + Arg.Any>()); + } + + /* + * Archive functionality is a per-user function. When importing archived ciphers into an organization vault, + * the Archives field should be set for the importing user only. This allows the importing user to see + * items as archived, while other organization members will not see them as archived. + */ + [Theory, BitAutoData] + public async Task ImportIntoOrganizationalVaultAsync_WithArchivedCiphers_SetsArchivesForImportingUserOnly( + Organization organization, + Guid importingUserId, + OrganizationUser importingOrganizationUser, + List collections, + List ciphers, + SutProvider sutProvider) + { + var archivedDate = DateTime.UtcNow.AddDays(-1); + organization.MaxCollections = null; + importingOrganizationUser.OrganizationId = organization.Id; + + foreach (var collection in collections) + { + collection.OrganizationId = organization.Id; + } + + foreach (var cipher in ciphers) + { + cipher.OrganizationId = organization.Id; + } + + ciphers[0].ArchivedDate = archivedDate; + ciphers[0].Archives = null; + + KeyValuePair[] collectionRelationships = { + new(0, 0), + new(1, 1), + new(2, 2) + }; + + sutProvider.GetDependency() + .GetByIdAsync(organization.Id) + .Returns(organization); + + sutProvider.GetDependency() + .GetByOrganizationAsync(organization.Id, importingUserId) + .Returns(importingOrganizationUser); + + sutProvider.GetDependency() + .GetManyByOrganizationIdAsync(organization.Id) + .Returns(new List()); + + await sutProvider.Sut.ImportIntoOrganizationalVaultAsync(collections, ciphers, collectionRelationships, importingUserId); + + await sutProvider.GetDependency() + .Received(1) + .CreateAsync( + Arg.Is>(c => + c[0].ArchivedDate == archivedDate && + c[0].Archives != null && + c[0].Archives.Contains(importingUserId.ToString().ToUpperInvariant()) && + c[0].Archives.Contains(archivedDate.ToString("yyyy-MM-ddTHH:mm:ss.fffffffZ"))), + Arg.Any>(), + Arg.Any>(), + Arg.Any>()); + } } diff --git a/test/Core.Test/Tools/Services/NonAnonymousSendCommandTests.cs b/test/Core.Test/Tools/Services/NonAnonymousSendCommandTests.cs index 1ad6a08516..9bebe5560c 100644 --- a/test/Core.Test/Tools/Services/NonAnonymousSendCommandTests.cs +++ b/test/Core.Test/Tools/Services/NonAnonymousSendCommandTests.cs @@ -11,6 +11,7 @@ using Bit.Core.Tools.Enums; using Bit.Core.Tools.Models.Data; using Bit.Core.Tools.Repositories; using Bit.Core.Tools.SendFeatures.Commands; +using Bit.Core.Tools.SendFeatures.Commands.Interfaces; using Bit.Core.Tools.Services; using Bit.Test.Common.AutoFixture.Attributes; using Microsoft.Extensions.Logging; @@ -28,7 +29,6 @@ public class NonAnonymousSendCommandTests private readonly ISendRepository _sendRepository; private readonly ISendFileStorageService _sendFileStorageService; private readonly IPushNotificationService _pushNotificationService; - private readonly ISendAuthorizationService _sendAuthorizationService; private readonly ISendValidationService _sendValidationService; private readonly IFeatureService _featureService; private readonly ICurrentContext _currentContext; @@ -42,7 +42,6 @@ public class NonAnonymousSendCommandTests _sendRepository = Substitute.For(); _sendFileStorageService = Substitute.For(); _pushNotificationService = Substitute.For(); - _sendAuthorizationService = Substitute.For(); _featureService = Substitute.For(); _sendValidationService = Substitute.For(); _currentContext = Substitute.For(); @@ -53,7 +52,6 @@ public class NonAnonymousSendCommandTests _sendRepository, _sendFileStorageService, _pushNotificationService, - _sendAuthorizationService, _sendValidationService, _sendCoreHelperService, _logger @@ -1093,4 +1091,329 @@ public class NonAnonymousSendCommandTests Assert.Equal("File received does not match expected file length.", exception.Message); } + + [Fact] + public async Task GetSendFileDownloadUrlAsync_WithTextSend_ThrowsBadRequest() + { + // Arrange + var send = new Send + { + Id = Guid.NewGuid(), + Type = SendType.Text, + UserId = Guid.NewGuid() + }; + var fileId = "somefile123"; + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => + _nonAnonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId)); + + Assert.Equal("Can only get a download URL for a file type of Send", exception.Message); + + // Verify no storage service methods were called + await _sendFileStorageService.DidNotReceive() + .GetSendFileDownloadUrlAsync(Arg.Any(), Arg.Any()); + } + + [Fact] + public async Task GetSendFileDownloadUrlAsync_WithDisabledSend_ReturnsDenied() + { + // Arrange + var fileId = "file123"; + var send = new Send + { + Id = Guid.NewGuid(), + Type = SendType.File, + UserId = Guid.NewGuid(), + Disabled = true, + DeletionDate = DateTime.UtcNow.AddDays(7), + ExpirationDate = null, + AccessCount = 0, + MaxAccessCount = null + }; + + // Act + var (url, result) = await _nonAnonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId); + + // Assert + Assert.Null(url); + Assert.Equal(SendAccessResult.Denied, result); + + // Verify no repository updates occurred + await _sendRepository.DidNotReceive().ReplaceAsync(Arg.Any()); + await _pushNotificationService.DidNotReceive().PushSyncSendUpdateAsync(Arg.Any()); + await _sendFileStorageService.DidNotReceive() + .GetSendFileDownloadUrlAsync(Arg.Any(), Arg.Any()); + } + + [Fact] + public async Task GetSendFileDownloadUrlAsync_WithMaxAccessCountReached_ReturnsDenied() + { + // Arrange + var fileId = "file123"; + var send = new Send + { + Id = Guid.NewGuid(), + Type = SendType.File, + UserId = Guid.NewGuid(), + Disabled = false, + DeletionDate = DateTime.UtcNow.AddDays(7), + ExpirationDate = null, + AccessCount = 5, + MaxAccessCount = 5 + }; + + // Act + var (url, result) = await _nonAnonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId); + + // Assert + Assert.Null(url); + Assert.Equal(SendAccessResult.Denied, result); + + // Verify no repository updates occurred + await _sendRepository.DidNotReceive().ReplaceAsync(Arg.Any()); + await _pushNotificationService.DidNotReceive().PushSyncSendUpdateAsync(Arg.Any()); + await _sendFileStorageService.DidNotReceive() + .GetSendFileDownloadUrlAsync(Arg.Any(), Arg.Any()); + } + + [Fact] + public async Task GetSendFileDownloadUrlAsync_WithExpiredSend_ReturnsDenied() + { + // Arrange + var fileId = "file123"; + var send = new Send + { + Id = Guid.NewGuid(), + Type = SendType.File, + UserId = Guid.NewGuid(), + Disabled = false, + DeletionDate = DateTime.UtcNow.AddDays(7), + ExpirationDate = DateTime.UtcNow.AddDays(-1), // Expired yesterday + AccessCount = 0, + MaxAccessCount = null + }; + + // Act + var (url, result) = await _nonAnonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId); + + // Assert + Assert.Null(url); + Assert.Equal(SendAccessResult.Denied, result); + + // Verify no repository updates occurred + await _sendRepository.DidNotReceive().ReplaceAsync(Arg.Any()); + await _pushNotificationService.DidNotReceive().PushSyncSendUpdateAsync(Arg.Any()); + await _sendFileStorageService.DidNotReceive() + .GetSendFileDownloadUrlAsync(Arg.Any(), Arg.Any()); + } + + [Fact] + public async Task GetSendFileDownloadUrlAsync_WithDeletionDatePassed_ReturnsDenied() + { + // Arrange + var fileId = "file123"; + var send = new Send + { + Id = Guid.NewGuid(), + Type = SendType.File, + UserId = Guid.NewGuid(), + Disabled = false, + DeletionDate = DateTime.UtcNow.AddDays(-1), // Deletion date has passed + ExpirationDate = null, + AccessCount = 0, + MaxAccessCount = null + }; + + // Act + var (url, result) = await _nonAnonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId); + + // Assert + Assert.Null(url); + Assert.Equal(SendAccessResult.Denied, result); + + // Verify no repository updates occurred + await _sendRepository.DidNotReceive().ReplaceAsync(Arg.Any()); + await _pushNotificationService.DidNotReceive().PushSyncSendUpdateAsync(Arg.Any()); + await _sendFileStorageService.DidNotReceive() + .GetSendFileDownloadUrlAsync(Arg.Any(), Arg.Any()); + } + + [Fact] + public async Task GetSendFileDownloadUrlAsync_WithValidSend_ReturnsUrlAndIncrementsAccessCount() + { + // Arrange + var fileId = "file123"; + var expectedUrl = "https://download.example.com/file123"; + var send = new Send + { + Id = Guid.NewGuid(), + Type = SendType.File, + UserId = Guid.NewGuid(), + Disabled = false, + DeletionDate = DateTime.UtcNow.AddDays(7), + ExpirationDate = null, + AccessCount = 3, + MaxAccessCount = 10 + }; + + _sendFileStorageService.GetSendFileDownloadUrlAsync(send, fileId).Returns(expectedUrl); + + // Act + var (url, result) = await _nonAnonymousSendCommand.GetSendFileDownloadUrlAsync(send, fileId); + + // Assert + Assert.Equal(expectedUrl, url); + Assert.Equal(SendAccessResult.Granted, result); + + // Verify access count was incremented + Assert.Equal(4, send.AccessCount); + + // Verify repository was updated + await _sendRepository.Received(1).ReplaceAsync(send); + await _pushNotificationService.Received(1).PushSyncSendUpdateAsync(send); + + // Verify file storage service was called + await _sendFileStorageService.Received(1).GetSendFileDownloadUrlAsync(send, fileId); + } + + [Fact] + public void SendCanBeAccessed_WithDisabledSend_ReturnsFalse() + { + // Arrange + var send = new Send + { + Disabled = true, + DeletionDate = DateTime.UtcNow.AddDays(7), + ExpirationDate = null, + AccessCount = 0, + MaxAccessCount = null + }; + + // Act + var result = INonAnonymousSendCommand.SendCanBeAccessed(send); + + // Assert + Assert.False(result); + } + + [Fact] + public void SendCanBeAccessed_WithMaxAccessCountReached_ReturnsFalse() + { + // Arrange + var send = new Send + { + Disabled = false, + DeletionDate = DateTime.UtcNow.AddDays(7), + ExpirationDate = null, + AccessCount = 10, + MaxAccessCount = 10 + }; + + // Act + var result = INonAnonymousSendCommand.SendCanBeAccessed(send); + + // Assert + Assert.False(result); + } + + [Fact] + public void SendCanBeAccessed_WithExpiredSend_ReturnsFalse() + { + // Arrange + var send = new Send + { + Disabled = false, + DeletionDate = DateTime.UtcNow.AddDays(7), + ExpirationDate = DateTime.UtcNow.AddDays(-1), + AccessCount = 0, + MaxAccessCount = null + }; + + // Act + var result = INonAnonymousSendCommand.SendCanBeAccessed(send); + + // Assert + Assert.False(result); + } + + [Fact] + public void SendCanBeAccessed_WithDeletionDatePassed_ReturnsFalse() + { + // Arrange + var send = new Send + { + Disabled = false, + DeletionDate = DateTime.UtcNow.AddDays(-1), + ExpirationDate = null, + AccessCount = 0, + MaxAccessCount = null + }; + + // Act + var result = INonAnonymousSendCommand.SendCanBeAccessed(send); + + // Assert + Assert.False(result); + } + + [Fact] + public void SendCanBeAccessed_WithValidSend_ReturnsTrue() + { + // Arrange + var send = new Send + { + Disabled = false, + DeletionDate = DateTime.UtcNow.AddDays(7), + ExpirationDate = DateTime.UtcNow.AddDays(7), + AccessCount = 5, + MaxAccessCount = 10 + }; + + // Act + var result = INonAnonymousSendCommand.SendCanBeAccessed(send); + + // Assert + Assert.True(result); + } + + [Fact] + public void SendCanBeAccessed_WithNullMaxAccessCount_ReturnsTrue() + { + // Arrange + var send = new Send + { + Disabled = false, + DeletionDate = DateTime.UtcNow.AddDays(7), + ExpirationDate = null, + AccessCount = 100, + MaxAccessCount = null + }; + + // Act + var result = INonAnonymousSendCommand.SendCanBeAccessed(send); + + // Assert + Assert.True(result); + } + + [Fact] + public void SendCanBeAccessed_WithNullExpirationDate_ReturnsTrue() + { + // Arrange + var send = new Send + { + Disabled = false, + DeletionDate = DateTime.UtcNow.AddDays(7), + ExpirationDate = null, + AccessCount = 0, + MaxAccessCount = 10 + }; + + // Act + var result = INonAnonymousSendCommand.SendCanBeAccessed(send); + + // Assert + Assert.True(result); + } } diff --git a/test/Core.Test/Tools/Services/SendAuthenticationQueryTests.cs b/test/Core.Test/Tools/Services/SendAuthenticationQueryTests.cs index 7901b3c5c0..87880998c3 100644 --- a/test/Core.Test/Tools/Services/SendAuthenticationQueryTests.cs +++ b/test/Core.Test/Tools/Services/SendAuthenticationQueryTests.cs @@ -43,7 +43,7 @@ public class SendAuthenticationQueryTests } [Theory] - [MemberData(nameof(EmailParsingTestCases))] + [MemberData(nameof(EmailsParsingTestCases))] public async Task GetAuthenticationMethod_WithEmails_ParsesEmailsCorrectly(string emailString, string[] expectedEmails) { // Arrange @@ -56,7 +56,7 @@ public class SendAuthenticationQueryTests // Assert var emailOtp = Assert.IsType(result); - Assert.Equal(expectedEmails, emailOtp.Emails); + Assert.Equal(expectedEmails, emailOtp.emails); } [Fact] @@ -64,7 +64,7 @@ public class SendAuthenticationQueryTests { // Arrange var sendId = Guid.NewGuid(); - var send = CreateSend(accessCount: 0, maxAccessCount: 10, emails: "test@example.com", password: "hashedpassword", AuthType.Email); + var send = CreateSend(accessCount: 0, maxAccessCount: 10, emails: "person@company.com", password: "hashedpassword", AuthType.Email); _sendRepository.GetByIdAsync(sendId).Returns(send); // Act @@ -108,18 +108,201 @@ public class SendAuthenticationQueryTests yield return new object[] { null, typeof(NeverAuthenticate) }; yield return new object[] { CreateSend(accessCount: 5, maxAccessCount: 5, emails: null, password: null, AuthType.None), typeof(NeverAuthenticate) }; yield return new object[] { CreateSend(accessCount: 6, maxAccessCount: 5, emails: null, password: null, AuthType.None), typeof(NeverAuthenticate) }; - yield return new object[] { CreateSend(accessCount: 0, maxAccessCount: 10, emails: "test@example.com", password: null, AuthType.Email), typeof(EmailOtp) }; + yield return new object[] { CreateSend(accessCount: 0, maxAccessCount: 10, emails: "person@company.com", password: null, AuthType.Email), typeof(EmailOtp) }; yield return new object[] { CreateSend(accessCount: 0, maxAccessCount: 10, emails: null, password: "hashedpassword", AuthType.Password), typeof(ResourcePassword) }; yield return new object[] { CreateSend(accessCount: 0, maxAccessCount: 10, emails: null, password: null, AuthType.None), typeof(NotAuthenticated) }; } - public static IEnumerable EmailParsingTestCases() + [Fact] + public async Task GetAuthenticationMethod_WithDisabledSend_ReturnsNeverAuthenticate() { - yield return new object[] { "test@example.com", new[] { "test@example.com" } }; - yield return new object[] { "test1@example.com,test2@example.com", new[] { "test1@example.com", "test2@example.com" } }; - yield return new object[] { " test@example.com , other@example.com ", new[] { "test@example.com", "other@example.com" } }; - yield return new object[] { "test@example.com,,other@example.com", new[] { "test@example.com", "other@example.com" } }; - yield return new object[] { " , test@example.com, ,other@example.com, ", new[] { "test@example.com", "other@example.com" } }; + // Arrange + var sendId = Guid.NewGuid(); + var send = new Send + { + Id = sendId, + AccessCount = 0, + MaxAccessCount = 10, + Emails = "person@company.com", + Password = null, + AuthType = AuthType.Email, + Disabled = true, + DeletionDate = DateTime.UtcNow.AddDays(7), + ExpirationDate = null + }; + _sendRepository.GetByIdAsync(sendId).Returns(send); + + // Act + var result = await _sendAuthenticationQuery.GetAuthenticationMethod(sendId); + + // Assert + Assert.IsType(result); + } + + [Fact] + public async Task GetAuthenticationMethod_WithExpiredSend_ReturnsNeverAuthenticate() + { + // Arrange + var sendId = Guid.NewGuid(); + var send = new Send + { + Id = sendId, + AccessCount = 0, + MaxAccessCount = 10, + Emails = "person@company.com", + Password = null, + AuthType = AuthType.Email, + Disabled = false, + DeletionDate = DateTime.UtcNow.AddDays(7), + ExpirationDate = DateTime.UtcNow.AddDays(-1) // Expired yesterday + }; + _sendRepository.GetByIdAsync(sendId).Returns(send); + + // Act + var result = await _sendAuthenticationQuery.GetAuthenticationMethod(sendId); + + // Assert + Assert.IsType(result); + } + + [Fact] + public async Task GetAuthenticationMethod_WithDeletionDatePassed_ReturnsNeverAuthenticate() + { + // Arrange + var sendId = Guid.NewGuid(); + var send = new Send + { + Id = sendId, + AccessCount = 0, + MaxAccessCount = 10, + Emails = "person@company.com", + Password = null, + AuthType = AuthType.Email, + Disabled = false, + DeletionDate = DateTime.UtcNow.AddDays(-1), // Should have been deleted yesterday + ExpirationDate = null + }; + _sendRepository.GetByIdAsync(sendId).Returns(send); + + // Act + var result = await _sendAuthenticationQuery.GetAuthenticationMethod(sendId); + + // Assert + Assert.IsType(result); + } + + [Fact] + public async Task GetAuthenticationMethod_WithDeletionDateEqualToNow_ReturnsNeverAuthenticate() + { + // Arrange + var sendId = Guid.NewGuid(); + var now = DateTime.UtcNow; + var send = new Send + { + Id = sendId, + AccessCount = 0, + MaxAccessCount = 10, + Emails = "person@company.com", + Password = null, + AuthType = AuthType.Email, + Disabled = false, + DeletionDate = now, // DeletionDate <= DateTime.UtcNow + ExpirationDate = null + }; + _sendRepository.GetByIdAsync(sendId).Returns(send); + + // Act + var result = await _sendAuthenticationQuery.GetAuthenticationMethod(sendId); + + // Assert + Assert.IsType(result); + } + + [Fact] + public async Task GetAuthenticationMethod_WithAccessCountEqualToMaxAccessCount_ReturnsNeverAuthenticate() + { + // Arrange + var sendId = Guid.NewGuid(); + var send = new Send + { + Id = sendId, + AccessCount = 5, + MaxAccessCount = 5, + Emails = "person@company.com", + Password = null, + AuthType = AuthType.Email, + Disabled = false, + DeletionDate = DateTime.UtcNow.AddDays(7), + ExpirationDate = null + }; + _sendRepository.GetByIdAsync(sendId).Returns(send); + + // Act + var result = await _sendAuthenticationQuery.GetAuthenticationMethod(sendId); + + // Assert + Assert.IsType(result); + } + + [Fact] + public async Task GetAuthenticationMethod_WithNullMaxAccessCount_DoesNotRestrictAccess() + { + // Arrange + var sendId = Guid.NewGuid(); + var send = new Send + { + Id = sendId, + AccessCount = 1000, + MaxAccessCount = null, // No limit + Emails = "person@company.com", + Password = null, + AuthType = AuthType.Email, + Disabled = false, + DeletionDate = DateTime.UtcNow.AddDays(7), + ExpirationDate = null + }; + _sendRepository.GetByIdAsync(sendId).Returns(send); + + // Act + var result = await _sendAuthenticationQuery.GetAuthenticationMethod(sendId); + + // Assert + Assert.IsType(result); + } + + [Fact] + public async Task GetAuthenticationMethod_WithNullExpirationDate_DoesNotExpire() + { + // Arrange + var sendId = Guid.NewGuid(); + var send = new Send + { + Id = sendId, + AccessCount = 0, + MaxAccessCount = 10, + Emails = "person@company.com", + Password = null, + AuthType = AuthType.Email, + Disabled = false, + DeletionDate = DateTime.UtcNow.AddDays(7), + ExpirationDate = null // No expiration + }; + _sendRepository.GetByIdAsync(sendId).Returns(send); + + // Act + var result = await _sendAuthenticationQuery.GetAuthenticationMethod(sendId); + + // Assert + Assert.IsType(result); + } + + public static IEnumerable EmailsParsingTestCases() + { + yield return new object[] { "person@company.com", new[] { "person@company.com" } }; + yield return new object[] { "person1@company.com,person2@company.com", new[] { "person1@company.com", "person2@company.com" } }; + yield return new object[] { " person1@company.com , person2@company.com ", new[] { "person1@company.com", "person2@company.com" } }; + yield return new object[] { "person1@company.com,,person2@company.com", new[] { "person1@company.com", "person2@company.com" } }; + yield return new object[] { " , person1@company.com, ,person2@company.com, ", new[] { "person1@company.com", "person2@company.com" } }; } private static Send CreateSend(int accessCount, int? maxAccessCount, string? emails, string? password, AuthType? authType) @@ -131,7 +314,10 @@ public class SendAuthenticationQueryTests MaxAccessCount = maxAccessCount, Emails = emails, Password = password, - AuthType = authType + AuthType = authType, + Disabled = false, + DeletionDate = DateTime.UtcNow.AddDays(7), + ExpirationDate = null }; } } diff --git a/test/Core.Test/Tools/Services/SendValidationServiceTests.cs b/test/Core.Test/Tools/Services/SendValidationServiceTests.cs new file mode 100644 index 0000000000..8adce1a29f --- /dev/null +++ b/test/Core.Test/Tools/Services/SendValidationServiceTests.cs @@ -0,0 +1,120 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Pricing.Premium; +using Bit.Core.Entities; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Tools.Entities; +using Bit.Core.Tools.Enums; +using Bit.Core.Tools.Services; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.Tools.Services; + +[SutProviderCustomize] +public class SendValidationServiceTests +{ + [Theory, BitAutoData] + public async Task StorageRemainingForSendAsync_OrgGrantedPremiumUser_UsesPricingService( + SutProvider sutProvider, + Send send, + User user) + { + // Arrange + send.UserId = user.Id; + send.OrganizationId = null; + send.Type = SendType.File; + user.Premium = false; + user.Storage = 1024L * 1024L * 1024L; // 1 GB used + user.EmailVerified = true; + + sutProvider.GetDependency().SelfHosted = false; + sutProvider.GetDependency().GetByIdAsync(user.Id).Returns(user); + sutProvider.GetDependency().CanAccessPremium(user).Returns(true); + + var premiumPlan = new Plan + { + Storage = new Purchasable { Provided = 5 } + }; + sutProvider.GetDependency().GetAvailablePremiumPlan().Returns(premiumPlan); + + // Act + var result = await sutProvider.Sut.StorageRemainingForSendAsync(send); + + // Assert + await sutProvider.GetDependency().Received(1).GetAvailablePremiumPlan(); + Assert.True(result > 0); + } + + [Theory, BitAutoData] + public async Task StorageRemainingForSendAsync_IndividualPremium_DoesNotCallPricingService( + SutProvider sutProvider, + Send send, + User user) + { + // Arrange + send.UserId = user.Id; + send.OrganizationId = null; + send.Type = SendType.File; + user.Premium = true; + user.MaxStorageGb = 10; + user.EmailVerified = true; + + sutProvider.GetDependency().GetByIdAsync(user.Id).Returns(user); + sutProvider.GetDependency().CanAccessPremium(user).Returns(true); + + // Act + var result = await sutProvider.Sut.StorageRemainingForSendAsync(send); + + // Assert - should NOT call pricing service for individual premium users + await sutProvider.GetDependency().DidNotReceive().GetAvailablePremiumPlan(); + } + + [Theory, BitAutoData] + public async Task StorageRemainingForSendAsync_SelfHosted_DoesNotCallPricingService( + SutProvider sutProvider, + Send send, + User user) + { + // Arrange + send.UserId = user.Id; + send.OrganizationId = null; + send.Type = SendType.File; + user.Premium = false; + user.EmailVerified = true; + + sutProvider.GetDependency().SelfHosted = true; + sutProvider.GetDependency().GetByIdAsync(user.Id).Returns(user); + sutProvider.GetDependency().CanAccessPremium(user).Returns(true); + + // Act + var result = await sutProvider.Sut.StorageRemainingForSendAsync(send); + + // Assert - should NOT call pricing service for self-hosted + await sutProvider.GetDependency().DidNotReceive().GetAvailablePremiumPlan(); + } + + [Theory, BitAutoData] + public async Task StorageRemainingForSendAsync_OrgSend_DoesNotCallPricingService( + SutProvider sutProvider, + Send send, + Organization org) + { + // Arrange + send.UserId = null; + send.OrganizationId = org.Id; + send.Type = SendType.File; + org.MaxStorageGb = 100; + + sutProvider.GetDependency().GetByIdAsync(org.Id).Returns(org); + + // Act + var result = await sutProvider.Sut.StorageRemainingForSendAsync(send); + + // Assert - should NOT call pricing service for org sends + await sutProvider.GetDependency().DidNotReceive().GetAvailablePremiumPlan(); + } +} diff --git a/test/Core.Test/Utilities/DomainNameAttributeTests.cs b/test/Core.Test/Utilities/DomainNameAttributeTests.cs new file mode 100644 index 0000000000..3f3190c9a1 --- /dev/null +++ b/test/Core.Test/Utilities/DomainNameAttributeTests.cs @@ -0,0 +1,84 @@ +using Bit.Core.Utilities; +using Xunit; + +namespace Bit.Core.Test.Utilities; + +public class DomainNameValidatorAttributeTests +{ + [Theory] + [InlineData("example.com")] // basic domain + [InlineData("sub.example.com")] // subdomain + [InlineData("sub.sub2.example.com")] // multiple subdomains + [InlineData("example-dash.com")] // domain with dash + [InlineData("123example.com")] // domain starting with number + [InlineData("example123.com")] // domain with numbers + [InlineData("e.com")] // short domain + [InlineData("very-long-subdomain-name.example.com")] // long subdomain + [InlineData("wörldé.com")] // unicode domain (IDN) + public void IsValid_ReturnsTrueWhenValid(string domainName) + { + var sut = new DomainNameValidatorAttribute(); + + var actual = sut.IsValid(domainName); + + Assert.True(actual); + } + + [Theory] + [InlineData("")] // XSS attempt + [InlineData("example.com
+ - +
-

- © 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa +

+ © {{ CurrentYear }} Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa Barbara, CA, USA

Always confirm you are on a trusted Bitwarden domain before logging in:
- bitwarden.com | - Learn why we include this + bitwarden.com | + Learn why we include this