From 51d90cce3de3bc11bf0b71db91df720ecf622207 Mon Sep 17 00:00:00 2001 From: mkincaid-bw Date: Thu, 15 Jan 2026 13:43:23 -0800 Subject: [PATCH 01/12] Add Entity Framework migration validation to verify_migrations script (#6817) * Add Entity Framework migration validation to verify_migrations script Enhances dev/verify_migrations.ps1 to validate EF migration files in addition to SQL migrations. The script now validates migrations in util/MySqlMigrations, util/PostgresMigrations, and util/SqliteMigrations directories. Validation includes: - Correct naming format (YYYYMMDDHHMMSS_Description.cs) - Both .cs and .Designer.cs files exist as pairs - Chronological ordering of timestamps - Excludes DatabaseContextModelSnapshot.cs files The script provides comprehensive reporting for all migration types with a summary showing which validations passed or failed. Co-Authored-By: Claude Sonnet 4.5 * Fix: Validate all EF migration files instead of silently ignoring malformed names Previously, migration files that didn't match the expected pattern were silently ignored during validation. This could allow incorrectly named files to slip through. Now the script explicitly tracks and reports any migration files that don't match the required YYYYMMDDHHMMSS_Description.cs format, ensuring all new migration files are properly validated. Addresses feedback from PR review to prevent malformed migration files from being overlooked. Co-Authored-By: Claude Sonnet 4.5 --------- Co-authored-by: Claude Sonnet 4.5 --- dev/verify_migrations.ps1 | 320 ++++++++++++++++++++++++++++++++------ 1 file changed, 270 insertions(+), 50 deletions(-) diff --git a/dev/verify_migrations.ps1 b/dev/verify_migrations.ps1 index ad0d34cef1..ce1754e684 100644 --- a/dev/verify_migrations.ps1 +++ b/dev/verify_migrations.ps1 @@ -5,12 +5,19 @@ Validates that new database migration files follow naming conventions and chronological order. .DESCRIPTION - This script validates migration files in util/Migrator/DbScripts/ to ensure: + This script validates migration files to ensure: + + For SQL migrations in util/Migrator/DbScripts/: 1. New migrations follow the naming format: YYYY-MM-DD_NN_Description.sql 2. New migrations are chronologically ordered (filename sorts after existing migrations) 3. Dates use leading zeros (e.g., 2025-01-05, not 2025-1-5) 4. A 2-digit sequence number is included (e.g., _00, _01) + For Entity Framework migrations in util/MySqlMigrations, util/PostgresMigrations, util/SqliteMigrations: + 1. New migrations follow the naming format: YYYYMMDDHHMMSS_Description.cs + 2. Each migration has both .cs and .Designer.cs files + 3. New migrations are chronologically ordered (timestamp sorts after existing migrations) + .PARAMETER BaseRef The base git reference to compare against (e.g., 'main', 'HEAD~1') @@ -58,75 +65,288 @@ $currentMigrations = git ls-tree -r --name-only $CurrentRef -- "$migrationPath/" # Find added migrations $addedMigrations = $currentMigrations | Where-Object { $_ -notin $baseMigrations } +$sqlValidationFailed = $false + if ($addedMigrations.Count -eq 0) { - Write-Host "No new migration files added." - exit 0 + Write-Host "No new SQL migration files added." + Write-Host "" +} +else { + Write-Host "New SQL migration files detected:" + $addedMigrations | ForEach-Object { Write-Host " $_" } + Write-Host "" + + # Get the last migration from base reference + if ($baseMigrations.Count -eq 0) { + Write-Host "No previous SQL migrations found (initial commit?). Skipping chronological validation." + Write-Host "" + } + else { + $lastBaseMigration = Split-Path -Leaf ($baseMigrations | Select-Object -Last 1) + Write-Host "Last SQL migration in base reference: $lastBaseMigration" + Write-Host "" + + # Required format regex: YYYY-MM-DD_NN_Description.sql + $formatRegex = '^[0-9]{4}-[0-9]{2}-[0-9]{2}_[0-9]{2}_.+\.sql$' + + foreach ($migration in $addedMigrations) { + $migrationName = Split-Path -Leaf $migration + + # Validate NEW migration filename format + if ($migrationName -notmatch $formatRegex) { + Write-Host "ERROR: Migration '$migrationName' does not match required format" + Write-Host "Required format: YYYY-MM-DD_NN_Description.sql" + Write-Host " - YYYY: 4-digit year" + Write-Host " - MM: 2-digit month with leading zero (01-12)" + Write-Host " - DD: 2-digit day with leading zero (01-31)" + Write-Host " - NN: 2-digit sequence number (00, 01, 02, etc.)" + Write-Host "Example: 2025-01-15_00_MyMigration.sql" + $sqlValidationFailed = $true + continue + } + + # Compare migration name with last base migration (using ordinal string comparison) + if ([string]::CompareOrdinal($migrationName, $lastBaseMigration) -lt 0) { + Write-Host "ERROR: New migration '$migrationName' is not chronologically after '$lastBaseMigration'" + $sqlValidationFailed = $true + } + else { + Write-Host "OK: '$migrationName' is chronologically after '$lastBaseMigration'" + } + } + + Write-Host "" + } + + if ($sqlValidationFailed) { + Write-Host "FAILED: One or more SQL migrations are incorrectly named or not in chronological order" + Write-Host "" + Write-Host "All new SQL migration files must:" + Write-Host " 1. Follow the naming format: YYYY-MM-DD_NN_Description.sql" + Write-Host " 2. Use leading zeros in dates (e.g., 2025-01-05, not 2025-1-5)" + Write-Host " 3. Include a 2-digit sequence number (e.g., _00, _01)" + Write-Host " 4. Have a filename that sorts after the last migration in base" + Write-Host "" + Write-Host "To fix this issue:" + Write-Host " 1. Locate your migration file(s) in util/Migrator/DbScripts/" + Write-Host " 2. Rename to follow format: YYYY-MM-DD_NN_Description.sql" + Write-Host " 3. Ensure the date is after $lastBaseMigration" + Write-Host "" + Write-Host "Example: 2025-01-15_00_AddNewFeature.sql" + } + else { + Write-Host "SUCCESS: All new SQL migrations are correctly named and in chronological order" + } + + Write-Host "" } -Write-Host "New migration files detected:" -$addedMigrations | ForEach-Object { Write-Host " $_" } +# =========================================================================================== +# Validate Entity Framework Migrations +# =========================================================================================== + +Write-Host "===================================================================" +Write-Host "Validating Entity Framework Migrations" +Write-Host "===================================================================" Write-Host "" -# Get the last migration from base reference -if ($baseMigrations.Count -eq 0) { - Write-Host "No previous migrations found (initial commit?). Skipping validation." - exit 0 -} +$efMigrationPaths = @( + @{ Path = "util/MySqlMigrations/Migrations"; Name = "MySQL" }, + @{ Path = "util/PostgresMigrations/Migrations"; Name = "Postgres" }, + @{ Path = "util/SqliteMigrations/Migrations"; Name = "SQLite" } +) -$lastBaseMigration = Split-Path -Leaf ($baseMigrations | Select-Object -Last 1) -Write-Host "Last migration in base reference: $lastBaseMigration" -Write-Host "" +$efValidationFailed = $false -# Required format regex: YYYY-MM-DD_NN_Description.sql -$formatRegex = '^[0-9]{4}-[0-9]{2}-[0-9]{2}_[0-9]{2}_.+\.sql$' +foreach ($migrationPathInfo in $efMigrationPaths) { + $efPath = $migrationPathInfo.Path + $dbName = $migrationPathInfo.Name -$validationFailed = $false + Write-Host "-------------------------------------------------------------------" + Write-Host "Checking $dbName EF migrations in $efPath" + Write-Host "-------------------------------------------------------------------" + Write-Host "" -foreach ($migration in $addedMigrations) { - $migrationName = Split-Path -Leaf $migration + # Get list of migrations from base reference + try { + $baseMigrations = git ls-tree -r --name-only $BaseRef -- "$efPath/" 2>$null | Where-Object { $_ -like "*.cs" -and $_ -notlike "*DatabaseContextModelSnapshot.cs" } | Sort-Object + if ($LASTEXITCODE -ne 0) { + Write-Host "Warning: Could not retrieve $dbName migrations from base reference '$BaseRef'" + $baseMigrations = @() + } + } + catch { + Write-Host "Warning: Could not retrieve $dbName migrations from base reference '$BaseRef'" + $baseMigrations = @() + } - # Validate NEW migration filename format - if ($migrationName -notmatch $formatRegex) { - Write-Host "ERROR: Migration '$migrationName' does not match required format" - Write-Host "Required format: YYYY-MM-DD_NN_Description.sql" - Write-Host " - YYYY: 4-digit year" - Write-Host " - MM: 2-digit month with leading zero (01-12)" - Write-Host " - DD: 2-digit day with leading zero (01-31)" - Write-Host " - NN: 2-digit sequence number (00, 01, 02, etc.)" - Write-Host "Example: 2025-01-15_00_MyMigration.sql" - $validationFailed = $true + # Get list of migrations from current reference + $currentMigrations = git ls-tree -r --name-only $CurrentRef -- "$efPath/" | Where-Object { $_ -like "*.cs" -and $_ -notlike "*DatabaseContextModelSnapshot.cs" } | Sort-Object + + # Find added migrations + $addedMigrations = $currentMigrations | Where-Object { $_ -notin $baseMigrations } + + if ($addedMigrations.Count -eq 0) { + Write-Host "No new $dbName EF migration files added." + Write-Host "" continue } - # Compare migration name with last base migration (using ordinal string comparison) - if ([string]::CompareOrdinal($migrationName, $lastBaseMigration) -lt 0) { - Write-Host "ERROR: New migration '$migrationName' is not chronologically after '$lastBaseMigration'" - $validationFailed = $true + Write-Host "New $dbName EF migration files detected:" + $addedMigrations | ForEach-Object { Write-Host " $_" } + Write-Host "" + + # Get the last migration from base reference + if ($baseMigrations.Count -eq 0) { + Write-Host "No previous $dbName migrations found. Skipping chronological validation." + Write-Host "" } else { - Write-Host "OK: '$migrationName' is chronologically after '$lastBaseMigration'" + $lastBaseMigration = Split-Path -Leaf ($baseMigrations | Select-Object -Last 1) + Write-Host "Last $dbName migration in base reference: $lastBaseMigration" + Write-Host "" } + + # Required format regex: YYYYMMDDHHMMSS_Description.cs or YYYYMMDDHHMMSS_Description.Designer.cs + $efFormatRegex = '^[0-9]{14}_.+\.cs$' + + # Group migrations by base name (without .Designer.cs suffix) + $migrationGroups = @{} + $unmatchedFiles = @() + + foreach ($migration in $addedMigrations) { + $migrationName = Split-Path -Leaf $migration + + # Extract base name (remove .Designer.cs or .cs) + if ($migrationName -match '^([0-9]{14}_.+?)(?:\.Designer)?\.cs$') { + $baseName = $matches[1] + if (-not $migrationGroups.ContainsKey($baseName)) { + $migrationGroups[$baseName] = @() + } + $migrationGroups[$baseName] += $migrationName + } + else { + # Track files that don't match the expected pattern + $unmatchedFiles += $migrationName + } + } + + # Flag any files that don't match the expected pattern + if ($unmatchedFiles.Count -gt 0) { + Write-Host "ERROR: The following migration files do not match the required format:" + foreach ($unmatchedFile in $unmatchedFiles) { + Write-Host " - $unmatchedFile" + } + Write-Host "" + Write-Host "Required format: YYYYMMDDHHMMSS_Description.cs or YYYYMMDDHHMMSS_Description.Designer.cs" + Write-Host " - YYYYMMDDHHMMSS: 14-digit timestamp (Year, Month, Day, Hour, Minute, Second)" + Write-Host " - Description: Descriptive name using PascalCase" + Write-Host "Example: 20250115120000_AddNewFeature.cs and 20250115120000_AddNewFeature.Designer.cs" + Write-Host "" + $efValidationFailed = $true + } + + foreach ($baseName in $migrationGroups.Keys | Sort-Object) { + $files = $migrationGroups[$baseName] + + # Validate format + $mainFile = "$baseName.cs" + $designerFile = "$baseName.Designer.cs" + + if ($mainFile -notmatch $efFormatRegex) { + Write-Host "ERROR: Migration '$mainFile' does not match required format" + Write-Host "Required format: YYYYMMDDHHMMSS_Description.cs" + Write-Host " - YYYYMMDDHHMMSS: 14-digit timestamp (Year, Month, Day, Hour, Minute, Second)" + Write-Host "Example: 20250115120000_AddNewFeature.cs" + $efValidationFailed = $true + continue + } + + # Check that both .cs and .Designer.cs files exist + $hasCsFile = $files -contains $mainFile + $hasDesignerFile = $files -contains $designerFile + + if (-not $hasCsFile) { + Write-Host "ERROR: Missing main migration file: $mainFile" + $efValidationFailed = $true + } + + if (-not $hasDesignerFile) { + Write-Host "ERROR: Missing designer file: $designerFile" + Write-Host "Each EF migration must have both a .cs and .Designer.cs file" + $efValidationFailed = $true + } + + if (-not $hasCsFile -or -not $hasDesignerFile) { + continue + } + + # Compare migration timestamp with last base migration (using ordinal string comparison) + if ($baseMigrations.Count -gt 0) { + if ([string]::CompareOrdinal($mainFile, $lastBaseMigration) -lt 0) { + Write-Host "ERROR: New migration '$mainFile' is not chronologically after '$lastBaseMigration'" + $efValidationFailed = $true + } + else { + Write-Host "OK: '$mainFile' is chronologically after '$lastBaseMigration'" + } + } + else { + Write-Host "OK: '$mainFile' (no previous migrations to compare)" + } + } + + Write-Host "" +} + +if ($efValidationFailed) { + Write-Host "FAILED: One or more EF migrations are incorrectly named or not in chronological order" + Write-Host "" + Write-Host "All new EF migration files must:" + Write-Host " 1. Follow the naming format: YYYYMMDDHHMMSS_Description.cs" + Write-Host " 2. Include both .cs and .Designer.cs files" + Write-Host " 3. Have a timestamp that sorts after the last migration in base" + Write-Host "" + Write-Host "To fix this issue:" + Write-Host " 1. Locate your migration file(s) in the respective Migrations directory" + Write-Host " 2. Ensure both .cs and .Designer.cs files exist" + Write-Host " 3. Rename to follow format: YYYYMMDDHHMMSS_Description.cs" + Write-Host " 4. Ensure the timestamp is after the last migration" + Write-Host "" + Write-Host "Example: 20250115120000_AddNewFeature.cs and 20250115120000_AddNewFeature.Designer.cs" +} +else { + Write-Host "SUCCESS: All new EF migrations are correctly named and in chronological order" } Write-Host "" +Write-Host "===================================================================" +Write-Host "Validation Summary" +Write-Host "===================================================================" + +if ($sqlValidationFailed -or $efValidationFailed) { + if ($sqlValidationFailed) { + Write-Host "❌ SQL migrations validation FAILED" + } + else { + Write-Host "✓ SQL migrations validation PASSED" + } + + if ($efValidationFailed) { + Write-Host "❌ EF migrations validation FAILED" + } + else { + Write-Host "✓ EF migrations validation PASSED" + } -if ($validationFailed) { - Write-Host "FAILED: One or more migrations are incorrectly named or not in chronological order" Write-Host "" - Write-Host "All new migration files must:" - Write-Host " 1. Follow the naming format: YYYY-MM-DD_NN_Description.sql" - Write-Host " 2. Use leading zeros in dates (e.g., 2025-01-05, not 2025-1-5)" - Write-Host " 3. Include a 2-digit sequence number (e.g., _00, _01)" - Write-Host " 4. Have a filename that sorts after the last migration in base" - Write-Host "" - Write-Host "To fix this issue:" - Write-Host " 1. Locate your migration file(s) in util/Migrator/DbScripts/" - Write-Host " 2. Rename to follow format: YYYY-MM-DD_NN_Description.sql" - Write-Host " 3. Ensure the date is after $lastBaseMigration" - Write-Host "" - Write-Host "Example: 2025-01-15_00_AddNewFeature.sql" + Write-Host "OVERALL RESULT: FAILED" exit 1 } - -Write-Host "SUCCESS: All new migrations are correctly named and in chronological order" -exit 0 +else { + Write-Host "✓ SQL migrations validation PASSED" + Write-Host "✓ EF migrations validation PASSED" + Write-Host "" + Write-Host "OVERALL RESULT: SUCCESS" + exit 0 +} From ebb0712e335385b25a191e3fa4574db9e8caf410 Mon Sep 17 00:00:00 2001 From: Thomas Rittson <31796059+eliykat@users.noreply.github.com> Date: Fri, 16 Jan 2026 08:49:25 +1000 Subject: [PATCH 02/12] [PM-28555] Add idempotent sproc to create My Items collections (#6801) * Add sproc to create multiple default collections. SqlBulkCopy implementation is overkill for most cases. This provides a lighter weight sproc implementation for smaller data sets. * DRY up collection arrangement * DRY up tests because bulk and non-bulk share same behavior * use EF native AddRange instead of bulk insert, because we expect smaller data sizes on self-host --- .../Collections/CollectionUtils.cs | 53 ++++++++++++ ...maticallyConfirmOrganizationUserCommand.cs | 19 +---- .../ConfirmOrganizationUserCommand.cs | 22 ++--- ...rganizationDataOwnershipPolicyValidator.cs | 5 +- .../Repositories/ICollectionRepository.cs | 17 +++- src/Infrastructure.Dapper/DapperHelpers.cs | 15 ++++ .../Repositories/CollectionRepository.cs | 82 ++++++++++--------- .../Repositories/CollectionRepository.cs | 49 ++--------- .../Repositories/DatabaseContext.cs | 2 - .../Collection_CreateDefaultCollections.sql | 69 ++++++++++++++++ .../AutomaticallyConfirmUsersCommandTests.cs | 21 ++--- .../ConfirmOrganizationUserCommandTests.cs | 29 ++----- ...zationDataOwnershipPolicyValidatorTests.cs | 24 +++--- .../CreateDefaultCollectionsBulkTests.cs | 53 ++++++++++++ ...=> CreateDefaultCollectionsSharedTests.cs} | 62 +++++++------- .../CreateDefaultCollectionsTests.cs | 52 ++++++++++++ ...00_Collection_CreateDefaultCollections.sql | 70 ++++++++++++++++ 17 files changed, 449 insertions(+), 195 deletions(-) create mode 100644 src/Core/AdminConsole/OrganizationFeatures/Collections/CollectionUtils.cs create mode 100644 src/Sql/dbo/AdminConsole/Stored Procedures/Collection_CreateDefaultCollections.sql create mode 100644 test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsBulkTests.cs rename test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/{UpsertDefaultCollectionsTests.cs => CreateDefaultCollectionsSharedTests.cs} (69%) create mode 100644 test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsTests.cs create mode 100644 util/Migrator/DbScripts/2026-01-13_00_Collection_CreateDefaultCollections.sql diff --git a/src/Core/AdminConsole/OrganizationFeatures/Collections/CollectionUtils.cs b/src/Core/AdminConsole/OrganizationFeatures/Collections/CollectionUtils.cs new file mode 100644 index 0000000000..116992146f --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Collections/CollectionUtils.cs @@ -0,0 +1,53 @@ +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Utilities; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Collections; + +public static class CollectionUtils +{ + /// + /// Arranges Collection and CollectionUser objects to create default user collections. + /// + /// The organization ID. + /// The IDs for organization users who need default collections. + /// The encrypted string to use as the default collection name. + /// A tuple containing the collections and collection users. + public static (ICollection collections, ICollection collectionUsers) + BuildDefaultUserCollections(Guid organizationId, IEnumerable organizationUserIds, + string defaultCollectionName) + { + var now = DateTime.UtcNow; + + var collectionUsers = new List(); + var collections = new List(); + + foreach (var orgUserId in organizationUserIds) + { + var collectionId = CoreHelpers.GenerateComb(); + + collections.Add(new Collection + { + Id = collectionId, + OrganizationId = organizationId, + Name = defaultCollectionName, + CreationDate = now, + RevisionDate = now, + Type = CollectionType.DefaultUserCollection, + DefaultUserCollectionEmail = null + + }); + + collectionUsers.Add(new CollectionUser + { + CollectionId = collectionId, + OrganizationUserId = orgUserId, + ReadOnly = false, + HidePasswords = false, + Manage = true, + }); + } + + return (collections, collectionUsers); + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUserCommand.cs index 1b488677ae..0292381857 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUserCommand.cs @@ -4,9 +4,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.OrganizationConfirmation; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; -using Bit.Core.Entities; using Bit.Core.Enums; -using Bit.Core.Models.Data; using Bit.Core.Platform.Push; using Bit.Core.Repositories; using Bit.Core.Services; @@ -83,19 +81,10 @@ public class AutomaticallyConfirmOrganizationUserCommand(IOrganizationUserReposi return; } - await collectionRepository.CreateAsync( - new Collection - { - OrganizationId = request.Organization!.Id, - Name = request.DefaultUserCollectionName, - Type = CollectionType.DefaultUserCollection - }, - groups: null, - [new CollectionAccessSelection - { - Id = request.OrganizationUser!.Id, - Manage = true - }]); + await collectionRepository.CreateDefaultCollectionsAsync( + request.Organization!.Id, + [request.OrganizationUser!.Id], + request.DefaultUserCollectionName); } catch (Exception ex) { diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs index 0b82ac7ea4..02f3346ba6 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs @@ -14,7 +14,6 @@ using Bit.Core.Billing.Enums; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; -using Bit.Core.Models.Data; using Bit.Core.Platform.Push; using Bit.Core.Repositories; using Bit.Core.Services; @@ -294,21 +293,10 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand return; } - var defaultCollection = new Collection - { - OrganizationId = organizationUser.OrganizationId, - Name = defaultUserCollectionName, - Type = CollectionType.DefaultUserCollection - }; - var collectionUser = new CollectionAccessSelection - { - Id = organizationUser.Id, - ReadOnly = false, - HidePasswords = false, - Manage = true - }; - - await _collectionRepository.CreateAsync(defaultCollection, groups: null, users: [collectionUser]); + await _collectionRepository.CreateDefaultCollectionsAsync( + organizationUser.OrganizationId, + [organizationUser.Id], + defaultUserCollectionName); } /// @@ -339,7 +327,7 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand return; } - await _collectionRepository.UpsertDefaultCollectionsAsync(organizationId, eligibleOrganizationUserIds, defaultUserCollectionName); + await _collectionRepository.CreateDefaultCollectionsAsync(organizationId, eligibleOrganizationUserIds, defaultUserCollectionName); } /// diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidator.cs index 7a47baa65a..104a5751ff 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidator.cs @@ -57,14 +57,15 @@ public class OrganizationDataOwnershipPolicyValidator( var userOrgIds = requirements .Select(requirement => requirement.GetDefaultCollectionRequestOnPolicyEnable(policyUpdate.OrganizationId)) .Where(request => request.ShouldCreateDefaultCollection) - .Select(request => request.OrganizationUserId); + .Select(request => request.OrganizationUserId) + .ToList(); if (!userOrgIds.Any()) { return; } - await collectionRepository.UpsertDefaultCollectionsAsync( + await collectionRepository.CreateDefaultCollectionsBulkAsync( policyUpdate.OrganizationId, userOrgIds, defaultCollectionName); diff --git a/src/Core/Repositories/ICollectionRepository.cs b/src/Core/Repositories/ICollectionRepository.cs index f86147ca7d..3f3b71d2d5 100644 --- a/src/Core/Repositories/ICollectionRepository.cs +++ b/src/Core/Repositories/ICollectionRepository.cs @@ -64,11 +64,22 @@ public interface ICollectionRepository : IRepository IEnumerable users, IEnumerable groups); /// - /// Creates default user collections for the specified organization users if they do not already have one. + /// Creates default user collections for the specified organization users. + /// Filters internally to only create collections for users who don't already have one. /// /// The Organization ID. /// The Organization User IDs to create default collections for. /// The encrypted string to use as the default collection name. - /// - Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName); + Task CreateDefaultCollectionsAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName); + + /// + /// Creates default user collections for the specified organization users using bulk insert operations. + /// Use this if you need to create collections for > ~1k users. + /// Filters internally to only create collections for users who don't already have one. + /// + /// The Organization ID. + /// The Organization User IDs to create default collections for. + /// The encrypted string to use as the default collection name. + Task CreateDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName); + } diff --git a/src/Infrastructure.Dapper/DapperHelpers.cs b/src/Infrastructure.Dapper/DapperHelpers.cs index 9a119e1e32..4384a6f752 100644 --- a/src/Infrastructure.Dapper/DapperHelpers.cs +++ b/src/Infrastructure.Dapper/DapperHelpers.cs @@ -160,6 +160,21 @@ public static class DapperHelpers return ids.ToArrayTVP("GuidId"); } + public static DataTable ToTwoGuidIdArrayTVP(this IEnumerable<(Guid id1, Guid id2)> values) + { + var table = new DataTable(); + table.SetTypeName("[dbo].[TwoGuidIdArray]"); + table.Columns.Add("Id1", typeof(Guid)); + table.Columns.Add("Id2", typeof(Guid)); + + foreach (var value in values) + { + table.Rows.Add(value.id1, value.id2); + } + + return table; + } + public static DataTable ToArrayTVP(this IEnumerable values, string columnName) { var table = new DataTable(); diff --git a/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs b/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs index 9985b41d56..1531703427 100644 --- a/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs @@ -1,6 +1,7 @@ using System.Data; using System.Diagnostics.CodeAnalysis; using System.Text.Json; +using Bit.Core.AdminConsole.OrganizationFeatures.Collections; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Data; @@ -360,7 +361,45 @@ public class CollectionRepository : Repository, ICollectionRep } } - public async Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName) + public async Task CreateDefaultCollectionsAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName) + { + organizationUserIds = organizationUserIds.ToList(); + if (!organizationUserIds.Any()) + { + return; + } + + var organizationUserCollectionIds = organizationUserIds + .Select(ou => (ou, CoreHelpers.GenerateComb())) + .ToTwoGuidIdArrayTVP(); + + await using var connection = new SqlConnection(ConnectionString); + await connection.OpenAsync(); + await using var transaction = connection.BeginTransaction(); + + try + { + await connection.ExecuteAsync( + "[dbo].[Collection_CreateDefaultCollections]", + new + { + OrganizationId = organizationId, + DefaultCollectionName = defaultCollectionName, + OrganizationUserCollectionIds = organizationUserCollectionIds + }, + commandType: CommandType.StoredProcedure, + transaction: transaction); + + await transaction.CommitAsync(); + } + catch + { + await transaction.RollbackAsync(); + throw; + } + } + + public async Task CreateDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName) { organizationUserIds = organizationUserIds.ToList(); if (!organizationUserIds.Any()) @@ -377,7 +416,8 @@ public class CollectionRepository : Repository, ICollectionRep var missingDefaultCollectionUserIds = organizationUserIds.Except(orgUserIdWithDefaultCollection); - var (collectionUsers, collections) = BuildDefaultCollectionForUsers(organizationId, missingDefaultCollectionUserIds, defaultCollectionName); + var (collections, collectionUsers) = + CollectionUtils.BuildDefaultUserCollections(organizationId, missingDefaultCollectionUserIds, defaultCollectionName); if (!collectionUsers.Any() || !collections.Any()) { @@ -387,11 +427,11 @@ public class CollectionRepository : Repository, ICollectionRep await BulkResourceCreationService.CreateCollectionsAsync(connection, transaction, collections); await BulkResourceCreationService.CreateCollectionsUsersAsync(connection, transaction, collectionUsers); - transaction.Commit(); + await transaction.CommitAsync(); } catch { - transaction.Rollback(); + await transaction.RollbackAsync(); throw; } } @@ -421,40 +461,6 @@ public class CollectionRepository : Repository, ICollectionRep return organizationUserIds.ToHashSet(); } - private (List collectionUser, List collection) BuildDefaultCollectionForUsers(Guid organizationId, IEnumerable missingDefaultCollectionUserIds, string defaultCollectionName) - { - var collectionUsers = new List(); - var collections = new List(); - - foreach (var orgUserId in missingDefaultCollectionUserIds) - { - var collectionId = CoreHelpers.GenerateComb(); - - collections.Add(new Collection - { - Id = collectionId, - OrganizationId = organizationId, - Name = defaultCollectionName, - CreationDate = DateTime.UtcNow, - RevisionDate = DateTime.UtcNow, - Type = CollectionType.DefaultUserCollection, - DefaultUserCollectionEmail = null - - }); - - collectionUsers.Add(new CollectionUser - { - CollectionId = collectionId, - OrganizationUserId = orgUserId, - ReadOnly = false, - HidePasswords = false, - Manage = true, - }); - } - - return (collectionUsers, collections); - } - public class CollectionWithGroupsAndUsers : Collection { public CollectionWithGroupsAndUsers() { } diff --git a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs index 5aa156d1f8..74150246b1 100644 --- a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs @@ -1,11 +1,10 @@ using AutoMapper; +using Bit.Core.AdminConsole.OrganizationFeatures.Collections; using Bit.Core.Enums; using Bit.Core.Models.Data; using Bit.Core.Repositories; -using Bit.Core.Utilities; using Bit.Infrastructure.EntityFramework.Models; using Bit.Infrastructure.EntityFramework.Repositories.Queries; -using LinqToDB.EntityFrameworkCore; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; @@ -794,7 +793,7 @@ public class CollectionRepository : Repository organizationUserIds, string defaultCollectionName) + public async Task CreateDefaultCollectionsAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName) { organizationUserIds = organizationUserIds.ToList(); if (!organizationUserIds.Any()) @@ -808,15 +807,15 @@ public class CollectionRepository : Repository>(collections)); + await dbContext.CollectionUsers.AddRangeAsync(Mapper.Map>(collectionUsers)); await dbContext.SaveChangesAsync(); } @@ -844,37 +843,7 @@ public class CollectionRepository : Repository collectionUser, List collection) BuildDefaultCollectionForUsers(Guid organizationId, IEnumerable missingDefaultCollectionUserIds, string defaultCollectionName) - { - var collectionUsers = new List(); - var collections = new List(); - - foreach (var orgUserId in missingDefaultCollectionUserIds) - { - var collectionId = CoreHelpers.GenerateComb(); - - collections.Add(new Collection - { - Id = collectionId, - OrganizationId = organizationId, - Name = defaultCollectionName, - CreationDate = DateTime.UtcNow, - RevisionDate = DateTime.UtcNow, - Type = CollectionType.DefaultUserCollection, - DefaultUserCollectionEmail = null - - }); - - collectionUsers.Add(new CollectionUser - { - CollectionId = collectionId, - OrganizationUserId = orgUserId, - ReadOnly = false, - HidePasswords = false, - Manage = true, - }); - } - - return (collectionUsers, collections); - } + public Task CreateDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable organizationUserIds, + string defaultCollectionName) => + CreateDefaultCollectionsAsync(organizationId, organizationUserIds, defaultCollectionName); } diff --git a/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs b/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs index 7b67a63912..a0ee0376c0 100644 --- a/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs +++ b/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs @@ -17,8 +17,6 @@ using Microsoft.EntityFrameworkCore.Storage.ValueConversion; using DP = Microsoft.AspNetCore.DataProtection; -#nullable enable - namespace Bit.Infrastructure.EntityFramework.Repositories; public class DatabaseContext : DbContext diff --git a/src/Sql/dbo/AdminConsole/Stored Procedures/Collection_CreateDefaultCollections.sql b/src/Sql/dbo/AdminConsole/Stored Procedures/Collection_CreateDefaultCollections.sql new file mode 100644 index 0000000000..4e671bd1e4 --- /dev/null +++ b/src/Sql/dbo/AdminConsole/Stored Procedures/Collection_CreateDefaultCollections.sql @@ -0,0 +1,69 @@ +-- Creates default user collections for organization users +-- Filters out existing default collections at database level +CREATE PROCEDURE [dbo].[Collection_CreateDefaultCollections] + @OrganizationId UNIQUEIDENTIFIER, + @DefaultCollectionName VARCHAR(MAX), + @OrganizationUserCollectionIds AS [dbo].[TwoGuidIdArray] READONLY -- OrganizationUserId, CollectionId +AS +BEGIN + SET NOCOUNT ON + + DECLARE @Now DATETIME2(7) = GETUTCDATE() + + -- Filter to only users who don't have default collections + SELECT ids.Id1, ids.Id2 + INTO #FilteredIds + FROM @OrganizationUserCollectionIds ids + WHERE NOT EXISTS ( + SELECT 1 + FROM [dbo].[CollectionUser] cu + INNER JOIN [dbo].[Collection] c ON c.Id = cu.CollectionId + WHERE c.OrganizationId = @OrganizationId + AND c.[Type] = 1 -- CollectionType.DefaultUserCollection + AND cu.OrganizationUserId = ids.Id1 + ); + + -- Insert collections only for users who don't have default collections yet + INSERT INTO [dbo].[Collection] + ( + [Id], + [OrganizationId], + [Name], + [CreationDate], + [RevisionDate], + [Type], + [ExternalId], + [DefaultUserCollectionEmail] + ) + SELECT + ids.Id2, -- CollectionId + @OrganizationId, + @DefaultCollectionName, + @Now, + @Now, + 1, -- CollectionType.DefaultUserCollection + NULL, + NULL + FROM + #FilteredIds ids; + + -- Insert collection user mappings + INSERT INTO [dbo].[CollectionUser] + ( + [CollectionId], + [OrganizationUserId], + [ReadOnly], + [HidePasswords], + [Manage] + ) + SELECT + ids.Id2, -- CollectionId + ids.Id1, -- OrganizationUserId + 0, -- ReadOnly = false + 0, -- HidePasswords = false + 1 -- Manage = true + FROM + #FilteredIds ids; + + DROP TABLE #FilteredIds; +END diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmUsersCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmUsersCommandTests.cs index 180750a9d0..252fb89c87 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmUsersCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmUsersCommandTests.cs @@ -10,7 +10,6 @@ using Bit.Core.AdminConsole.Utilities.v2; using Bit.Core.AdminConsole.Utilities.v2.Validation; using Bit.Core.Entities; using Bit.Core.Enums; -using Bit.Core.Models.Data; using Bit.Core.Platform.Push; using Bit.Core.Repositories; using Bit.Core.Services; @@ -204,14 +203,10 @@ public class AutomaticallyConfirmUsersCommandTests await sutProvider.GetDependency() .Received(1) - .CreateAsync( - Arg.Is(c => - c.OrganizationId == organization.Id && - c.Name == defaultCollectionName && - c.Type == CollectionType.DefaultUserCollection), - Arg.Is>(groups => groups == null), - Arg.Is>(access => - access.FirstOrDefault(x => x.Id == organizationUser.Id && x.Manage) != null)); + .CreateDefaultCollectionsAsync( + organization.Id, + Arg.Is>(ids => ids.Single() == organizationUser.Id), + defaultCollectionName); } [Theory] @@ -253,9 +248,7 @@ public class AutomaticallyConfirmUsersCommandTests await sutProvider.GetDependency() .DidNotReceive() - .CreateAsync(Arg.Any(), - Arg.Any>(), - Arg.Any>()); + .CreateDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } [Theory] @@ -291,9 +284,7 @@ public class AutomaticallyConfirmUsersCommandTests var collectionException = new Exception("Collection creation failed"); sutProvider.GetDependency() - .CreateAsync(Arg.Any(), - Arg.Any>(), - Arg.Any>()) + .CreateDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()) .ThrowsAsync(collectionException); // Act diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs index 65359b8304..6643f26eb5 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs @@ -13,7 +13,6 @@ using Bit.Core.Billing.Enums; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; -using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Platform.Push; using Bit.Core.Repositories; @@ -493,15 +492,10 @@ public class ConfirmOrganizationUserCommandTests await sutProvider.GetDependency() .Received(1) - .CreateAsync( - Arg.Is(c => - c.Name == collectionName && - c.OrganizationId == organization.Id && - c.Type == CollectionType.DefaultUserCollection), - Arg.Any>(), - Arg.Is>(cu => - cu.Single().Id == orgUser.Id && - cu.Single().Manage)); + .CreateDefaultCollectionsAsync( + organization.Id, + Arg.Is>(ids => ids.Single() == orgUser.Id), + collectionName); } [Theory, BitAutoData] @@ -522,7 +516,7 @@ public class ConfirmOrganizationUserCommandTests await sutProvider.GetDependency() .DidNotReceive() - .UpsertDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + .CreateDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } [Theory, BitAutoData] @@ -539,24 +533,15 @@ public class ConfirmOrganizationUserCommandTests sutProvider.GetDependency().GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser }); sutProvider.GetDependency().GetManyAsync(default).ReturnsForAnyArgs(new[] { user }); - var policyDetails = new PolicyDetails - { - OrganizationId = org.Id, - OrganizationUserId = orgUser.Id, - IsProvider = false, - OrganizationUserStatus = orgUser.Status, - OrganizationUserType = orgUser.Type, - PolicyType = PolicyType.OrganizationDataOwnership - }; sutProvider.GetDependency() .GetAsync(orgUser.UserId!.Value) - .Returns(new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Disabled, [policyDetails])); + .Returns(new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Disabled, [])); await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, collectionName); await sutProvider.GetDependency() .DidNotReceive() - .UpsertDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + .CreateDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } [Theory, BitAutoData] diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidatorTests.cs index 93cbde89ec..dd2f1d76e8 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidatorTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidatorTests.cs @@ -38,7 +38,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceive() - .UpsertDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + .CreateDefaultCollectionsBulkAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } [Theory, BitAutoData] @@ -60,7 +60,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceive() - .UpsertDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + .CreateDefaultCollectionsBulkAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } [Theory, BitAutoData] @@ -86,7 +86,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await collectionRepository .DidNotReceive() - .UpsertDefaultCollectionsAsync( + .CreateDefaultCollectionsBulkAsync( Arg.Any(), Arg.Any>(), Arg.Any()); @@ -172,10 +172,10 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Act await sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState); - // Assert + // Assert - Should call with all user IDs (repository does internal filtering) await collectionRepository .Received(1) - .UpsertDefaultCollectionsAsync( + .CreateDefaultCollectionsBulkAsync( policyUpdate.OrganizationId, Arg.Is>(ids => ids.Count() == 3), _defaultUserCollectionName); @@ -210,7 +210,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceive() - .UpsertDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + .CreateDefaultCollectionsBulkAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } private static IPolicyRepository ArrangePolicyRepository(IEnumerable policyDetails) @@ -251,7 +251,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() - .UpsertDefaultCollectionsAsync(default, default, default); + .CreateDefaultCollectionsBulkAsync(default, default, default); } [Theory, BitAutoData] @@ -273,7 +273,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() - .UpsertDefaultCollectionsAsync(default, default, default); + .CreateDefaultCollectionsBulkAsync(default, default, default); } [Theory, BitAutoData] @@ -299,7 +299,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await collectionRepository .DidNotReceiveWithAnyArgs() - .UpsertDefaultCollectionsAsync( + .CreateDefaultCollectionsBulkAsync( default, default, default); @@ -336,10 +336,10 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Act await sut.ExecutePostUpsertSideEffectAsync(policyRequest, postUpdatedPolicy, previousPolicyState); - // Assert + // Assert - Should call with all user IDs (repository does internal filtering) await collectionRepository .Received(1) - .UpsertDefaultCollectionsAsync( + .CreateDefaultCollectionsBulkAsync( policyUpdate.OrganizationId, Arg.Is>(ids => ids.Count() == 3), _defaultUserCollectionName); @@ -367,6 +367,6 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() - .UpsertDefaultCollectionsAsync(default, default, default); + .CreateDefaultCollectionsBulkAsync(default, default, default); } } diff --git a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsBulkTests.cs b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsBulkTests.cs new file mode 100644 index 0000000000..712ad7d62e --- /dev/null +++ b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsBulkTests.cs @@ -0,0 +1,53 @@ +using Bit.Core.Repositories; +using Xunit; + +namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.CollectionRepository; + + +public class CreateDefaultCollectionsBulkAsyncTests +{ + [Theory, DatabaseData] + public async Task CreateDefaultCollectionsBulkAsync_CreatesDefaultCollections_Success( + IOrganizationRepository organizationRepository, + IUserRepository userRepository, + IOrganizationUserRepository organizationUserRepository, + ICollectionRepository collectionRepository) + { + await CreateDefaultCollectionsSharedTests.CreatesDefaultCollections_Success( + collectionRepository.CreateDefaultCollectionsBulkAsync, + organizationRepository, + userRepository, + organizationUserRepository, + collectionRepository); + } + + [Theory, DatabaseData] + public async Task CreateDefaultCollectionsBulkAsync_CreatesForNewUsersOnly_AndIgnoresExistingUsers( + IOrganizationRepository organizationRepository, + IUserRepository userRepository, + IOrganizationUserRepository organizationUserRepository, + ICollectionRepository collectionRepository) + { + await CreateDefaultCollectionsSharedTests.CreatesForNewUsersOnly_AndIgnoresExistingUsers( + collectionRepository.CreateDefaultCollectionsBulkAsync, + organizationRepository, + userRepository, + organizationUserRepository, + collectionRepository); + } + + [Theory, DatabaseData] + public async Task CreateDefaultCollectionsBulkAsync_IgnoresAllExistingUsers( + IOrganizationRepository organizationRepository, + IUserRepository userRepository, + IOrganizationUserRepository organizationUserRepository, + ICollectionRepository collectionRepository) + { + await CreateDefaultCollectionsSharedTests.IgnoresAllExistingUsers( + collectionRepository.CreateDefaultCollectionsBulkAsync, + organizationRepository, + userRepository, + organizationUserRepository, + collectionRepository); + } +} diff --git a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/UpsertDefaultCollectionsTests.cs b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsSharedTests.cs similarity index 69% rename from test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/UpsertDefaultCollectionsTests.cs rename to test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsSharedTests.cs index 64dffa473f..0fb4a5b446 100644 --- a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/UpsertDefaultCollectionsTests.cs +++ b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsSharedTests.cs @@ -6,10 +6,14 @@ using Xunit; namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.CollectionRepository; -public class UpsertDefaultCollectionsTests +/// +/// Shared tests for CreateDefaultCollections methods - both bulk and non-bulk implementations, +/// as they share the same behavior. Both test suites call the tests in this class. +/// +public static class CreateDefaultCollectionsSharedTests { - [Theory, DatabaseData] - public async Task UpsertDefaultCollectionsAsync_ShouldCreateDefaultCollection_WhenUsersDoNotHaveDefaultCollection( + public static async Task CreatesDefaultCollections_Success( + Func, string, Task> createDefaultCollectionsFunc, IOrganizationRepository organizationRepository, IUserRepository userRepository, IOrganizationUserRepository organizationUserRepository, @@ -21,14 +25,13 @@ public class UpsertDefaultCollectionsTests var resultOrganizationUsers = await Task.WhenAll( CreateUserForOrgAsync(userRepository, organizationUserRepository, organization), CreateUserForOrgAsync(userRepository, organizationUserRepository, organization) - ); + ); - - var affectedOrgUserIds = resultOrganizationUsers.Select(organizationUser => organizationUser.Id); + var affectedOrgUserIds = resultOrganizationUsers.Select(organizationUser => organizationUser.Id).ToList(); var defaultCollectionName = $"default-name-{organization.Id}"; // Act - await collectionRepository.UpsertDefaultCollectionsAsync(organization.Id, affectedOrgUserIds, defaultCollectionName); + await createDefaultCollectionsFunc(organization.Id, affectedOrgUserIds, defaultCollectionName); // Assert await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, resultOrganizationUsers, organization.Id); @@ -36,8 +39,8 @@ public class UpsertDefaultCollectionsTests await CleanupAsync(organizationRepository, userRepository, organization, resultOrganizationUsers); } - [Theory, DatabaseData] - public async Task UpsertDefaultCollectionsAsync_ShouldUpsertCreateDefaultCollection_ForUsersWithAndWithoutDefaultCollectionsExist( + public static async Task CreatesForNewUsersOnly_AndIgnoresExistingUsers( + Func, string, Task> createDefaultCollectionsFunc, IOrganizationRepository organizationRepository, IUserRepository userRepository, IOrganizationUserRepository organizationUserRepository, @@ -51,31 +54,30 @@ public class UpsertDefaultCollectionsTests CreateUserForOrgAsync(userRepository, organizationUserRepository, organization) ); - var arrangedOrgUserIds = arrangedOrganizationUsers.Select(organizationUser => organizationUser.Id); + var arrangedOrgUserIds = arrangedOrganizationUsers.Select(organizationUser => organizationUser.Id).ToList(); var defaultCollectionName = $"default-name-{organization.Id}"; + await CreateUsersWithExistingDefaultCollectionsAsync(createDefaultCollectionsFunc, collectionRepository, organization.Id, arrangedOrgUserIds, defaultCollectionName, arrangedOrganizationUsers); - await CreateUsersWithExistingDefaultCollectionsAsync(collectionRepository, organization.Id, arrangedOrgUserIds, defaultCollectionName, arrangedOrganizationUsers); - - var newOrganizationUsers = new List() + var newOrganizationUsers = new List { await CreateUserForOrgAsync(userRepository, organizationUserRepository, organization) }; - var affectedOrgUsers = newOrganizationUsers.Concat(arrangedOrganizationUsers); - var affectedOrgUserIds = affectedOrgUsers.Select(organizationUser => organizationUser.Id); + var affectedOrgUsers = newOrganizationUsers.Concat(arrangedOrganizationUsers).ToList(); + var affectedOrgUserIds = affectedOrgUsers.Select(organizationUser => organizationUser.Id).ToList(); // Act - await collectionRepository.UpsertDefaultCollectionsAsync(organization.Id, affectedOrgUserIds, defaultCollectionName); + await createDefaultCollectionsFunc(organization.Id, affectedOrgUserIds, defaultCollectionName); // Assert - await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, arrangedOrganizationUsers, organization.Id); + await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, affectedOrgUsers, organization.Id); await CleanupAsync(organizationRepository, userRepository, organization, affectedOrgUsers); } - [Theory, DatabaseData] - public async Task UpsertDefaultCollectionsAsync_ShouldNotCreateDefaultCollection_WhenUsersAlreadyHaveOne( + public static async Task IgnoresAllExistingUsers( + Func, string, Task> createDefaultCollectionsFunc, IOrganizationRepository organizationRepository, IUserRepository userRepository, IOrganizationUserRepository organizationUserRepository, @@ -89,26 +91,29 @@ public class UpsertDefaultCollectionsTests CreateUserForOrgAsync(userRepository, organizationUserRepository, organization) ); - var affectedOrgUserIds = resultOrganizationUsers.Select(organizationUser => organizationUser.Id); + var affectedOrgUserIds = resultOrganizationUsers.Select(organizationUser => organizationUser.Id).ToList(); var defaultCollectionName = $"default-name-{organization.Id}"; + await CreateUsersWithExistingDefaultCollectionsAsync(createDefaultCollectionsFunc, collectionRepository, organization.Id, affectedOrgUserIds, defaultCollectionName, resultOrganizationUsers); - await CreateUsersWithExistingDefaultCollectionsAsync(collectionRepository, organization.Id, affectedOrgUserIds, defaultCollectionName, resultOrganizationUsers); + // Act - Try to create again, should silently filter and not create duplicates + await createDefaultCollectionsFunc(organization.Id, affectedOrgUserIds, defaultCollectionName); - // Act - await collectionRepository.UpsertDefaultCollectionsAsync(organization.Id, affectedOrgUserIds, defaultCollectionName); - - // Assert + // Assert - Original collections should remain unchanged, still only one per user await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, resultOrganizationUsers, organization.Id); await CleanupAsync(organizationRepository, userRepository, organization, resultOrganizationUsers); } - private static async Task CreateUsersWithExistingDefaultCollectionsAsync(ICollectionRepository collectionRepository, - Guid organizationId, IEnumerable affectedOrgUserIds, string defaultCollectionName, + private static async Task CreateUsersWithExistingDefaultCollectionsAsync( + Func, string, Task> createDefaultCollectionsFunc, + ICollectionRepository collectionRepository, + Guid organizationId, + IEnumerable affectedOrgUserIds, + string defaultCollectionName, OrganizationUser[] resultOrganizationUsers) { - await collectionRepository.UpsertDefaultCollectionsAsync(organizationId, affectedOrgUserIds, defaultCollectionName); + await createDefaultCollectionsFunc(organizationId, affectedOrgUserIds, defaultCollectionName); await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, resultOrganizationUsers, organizationId); } @@ -131,7 +136,6 @@ public class UpsertDefaultCollectionsTests private static async Task CreateUserForOrgAsync(IUserRepository userRepository, IOrganizationUserRepository organizationUserRepository, Organization organization) { - var user = await userRepository.CreateTestUserAsync(); var orgUser = await organizationUserRepository.CreateTestOrganizationUserAsync(organization, user); diff --git a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsTests.cs b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsTests.cs new file mode 100644 index 0000000000..bd894e9ca5 --- /dev/null +++ b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsTests.cs @@ -0,0 +1,52 @@ +using Bit.Core.Repositories; +using Xunit; + +namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.CollectionRepository; + +public class CreateDefaultCollectionsAsyncTests +{ + [Theory, DatabaseData] + public async Task CreateDefaultCollectionsAsync_CreatesDefaultCollections_Success( + IOrganizationRepository organizationRepository, + IUserRepository userRepository, + IOrganizationUserRepository organizationUserRepository, + ICollectionRepository collectionRepository) + { + await CreateDefaultCollectionsSharedTests.CreatesDefaultCollections_Success( + collectionRepository.CreateDefaultCollectionsAsync, + organizationRepository, + userRepository, + organizationUserRepository, + collectionRepository); + } + + [Theory, DatabaseData] + public async Task CreateDefaultCollectionsAsync_CreatesForNewUsersOnly_AndIgnoresExistingUsers( + IOrganizationRepository organizationRepository, + IUserRepository userRepository, + IOrganizationUserRepository organizationUserRepository, + ICollectionRepository collectionRepository) + { + await CreateDefaultCollectionsSharedTests.CreatesForNewUsersOnly_AndIgnoresExistingUsers( + collectionRepository.CreateDefaultCollectionsAsync, + organizationRepository, + userRepository, + organizationUserRepository, + collectionRepository); + } + + [Theory, DatabaseData] + public async Task CreateDefaultCollectionsAsync_IgnoresAllExistingUsers( + IOrganizationRepository organizationRepository, + IUserRepository userRepository, + IOrganizationUserRepository organizationUserRepository, + ICollectionRepository collectionRepository) + { + await CreateDefaultCollectionsSharedTests.IgnoresAllExistingUsers( + collectionRepository.CreateDefaultCollectionsAsync, + organizationRepository, + userRepository, + organizationUserRepository, + collectionRepository); + } +} diff --git a/util/Migrator/DbScripts/2026-01-13_00_Collection_CreateDefaultCollections.sql b/util/Migrator/DbScripts/2026-01-13_00_Collection_CreateDefaultCollections.sql new file mode 100644 index 0000000000..c7935db5e8 --- /dev/null +++ b/util/Migrator/DbScripts/2026-01-13_00_Collection_CreateDefaultCollections.sql @@ -0,0 +1,70 @@ +-- Creates default user collections for organization users +-- Filters out existing default collections at database level +CREATE OR ALTER PROCEDURE [dbo].[Collection_CreateDefaultCollections] + @OrganizationId UNIQUEIDENTIFIER, + @DefaultCollectionName VARCHAR(MAX), + @OrganizationUserCollectionIds AS [dbo].[TwoGuidIdArray] READONLY -- OrganizationUserId, CollectionId +AS +BEGIN + SET NOCOUNT ON + + DECLARE @Now DATETIME2(7) = GETUTCDATE() + + -- Filter to only users who don't have default collections + SELECT ids.Id1, ids.Id2 + INTO #FilteredIds + FROM @OrganizationUserCollectionIds ids + WHERE NOT EXISTS ( + SELECT 1 + FROM [dbo].[CollectionUser] cu + INNER JOIN [dbo].[Collection] c ON c.Id = cu.CollectionId + WHERE c.OrganizationId = @OrganizationId + AND c.[Type] = 1 -- CollectionType.DefaultUserCollection + AND cu.OrganizationUserId = ids.Id1 + ); + + -- Insert collections only for users who don't have default collections yet + INSERT INTO [dbo].[Collection] + ( + [Id], + [OrganizationId], + [Name], + [CreationDate], + [RevisionDate], + [Type], + [ExternalId], + [DefaultUserCollectionEmail] + ) + SELECT + ids.Id2, -- CollectionId + @OrganizationId, + @DefaultCollectionName, + @Now, + @Now, + 1, -- CollectionType.DefaultUserCollection + NULL, + NULL + FROM + #FilteredIds ids; + + -- Insert collection user mappings + INSERT INTO [dbo].[CollectionUser] + ( + [CollectionId], + [OrganizationUserId], + [ReadOnly], + [HidePasswords], + [Manage] + ) + SELECT + ids.Id2, -- CollectionId + ids.Id1, -- OrganizationUserId + 0, -- ReadOnly = false + 0, -- HidePasswords = false + 1 -- Manage = true + FROM + #FilteredIds ids; + + DROP TABLE #FilteredIds; +END +GO From aa33a67aeeeea602dbe96a483e1f69a94744fce4 Mon Sep 17 00:00:00 2001 From: Justin Baur <19896123+justindbaur@users.noreply.github.com> Date: Fri, 16 Jan 2026 10:33:17 -0500 Subject: [PATCH 03/12] [PM-30858] Fix excessive logs (#6860) * Add tests showing issue & workaround - `AddSerilogFileLogging_LegacyConfig_InfoLogs_DoNotFillUpFile` fails - `AddSerilogFileLogging_LegacyConfig_WithLevelCustomization_InfoLogs_DoNotFillUpFile` fails - `AddSerilogFileLogging_NewConfig_InfoLogs_DoNotFillUpFile` fails - `AddSerilogFileLogging_NewConfig_WithLevelCustomization_InfoLogs_DoNotFillUpFile` works * Allow customization of LogLevel with legacy path format config * Lower default logging levels * Delete tests now that log levels have been customized --- .../src/Scim/appsettings.Production.json | 6 +- src/Admin/appsettings.Production.json | 6 +- src/Api/appsettings.Production.json | 6 +- src/Core/Utilities/LoggerFactoryExtensions.cs | 38 +++++++----- src/Events/appsettings.Production.json | 6 +- .../appsettings.Production.json | 6 +- src/Icons/appsettings.Production.json | 6 +- src/Identity/appsettings.Production.json | 6 +- src/Notifications/appsettings.Production.json | 6 +- .../Utilities/LoggerFactoryExtensionsTests.cs | 59 ++++++++++++++++++- 10 files changed, 96 insertions(+), 49 deletions(-) diff --git a/bitwarden_license/src/Scim/appsettings.Production.json b/bitwarden_license/src/Scim/appsettings.Production.json index d9efbcda12..a6578c08dc 100644 --- a/bitwarden_license/src/Scim/appsettings.Production.json +++ b/bitwarden_license/src/Scim/appsettings.Production.json @@ -23,11 +23,9 @@ } }, "Logging": { - "IncludeScopes": false, "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" + "Default": "Information", + "Microsoft.AspNetCore": "Warning" }, "Console": { "IncludeScopes": true, diff --git a/src/Admin/appsettings.Production.json b/src/Admin/appsettings.Production.json index 9f797f3111..1d852abfed 100644 --- a/src/Admin/appsettings.Production.json +++ b/src/Admin/appsettings.Production.json @@ -20,11 +20,9 @@ } }, "Logging": { - "IncludeScopes": false, "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" + "Default": "Information", + "Microsoft.AspNetCore": "Warning" }, "Console": { "IncludeScopes": true, diff --git a/src/Api/appsettings.Production.json b/src/Api/appsettings.Production.json index d9efbcda12..a6578c08dc 100644 --- a/src/Api/appsettings.Production.json +++ b/src/Api/appsettings.Production.json @@ -23,11 +23,9 @@ } }, "Logging": { - "IncludeScopes": false, "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" + "Default": "Information", + "Microsoft.AspNetCore": "Warning" }, "Console": { "IncludeScopes": true, diff --git a/src/Core/Utilities/LoggerFactoryExtensions.cs b/src/Core/Utilities/LoggerFactoryExtensions.cs index b950e30d5d..f3330f0792 100644 --- a/src/Core/Utilities/LoggerFactoryExtensions.cs +++ b/src/Core/Utilities/LoggerFactoryExtensions.cs @@ -1,4 +1,5 @@ -using Microsoft.AspNetCore.Hosting; +using System.Globalization; +using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; @@ -8,7 +9,7 @@ namespace Bit.Core.Utilities; public static class LoggerFactoryExtensions { /// - /// + /// /// /// /// @@ -21,10 +22,12 @@ public static class LoggerFactoryExtensions return; } + IConfiguration loggingConfiguration; + // If they have begun using the new settings location, use that if (!string.IsNullOrEmpty(context.Configuration["Logging:PathFormat"])) { - logging.AddFile(context.Configuration.GetSection("Logging")); + loggingConfiguration = context.Configuration.GetSection("Logging"); } else { @@ -40,28 +43,35 @@ public static class LoggerFactoryExtensions var projectName = loggingOptions.ProjectName ?? context.HostingEnvironment.ApplicationName; + string pathFormat; + if (loggingOptions.LogRollBySizeLimit.HasValue) { - var pathFormat = loggingOptions.LogDirectoryByProject + pathFormat = loggingOptions.LogDirectoryByProject ? Path.Combine(loggingOptions.LogDirectory, projectName, "log.txt") : Path.Combine(loggingOptions.LogDirectory, $"{projectName.ToLowerInvariant()}.log"); - - logging.AddFile( - pathFormat: pathFormat, - fileSizeLimitBytes: loggingOptions.LogRollBySizeLimit.Value - ); } else { - var pathFormat = loggingOptions.LogDirectoryByProject + pathFormat = loggingOptions.LogDirectoryByProject ? Path.Combine(loggingOptions.LogDirectory, projectName, "{Date}.txt") : Path.Combine(loggingOptions.LogDirectory, $"{projectName.ToLowerInvariant()}_{{Date}}.log"); - - logging.AddFile( - pathFormat: pathFormat - ); } + + // We want to rely on Serilog using the configuration section to have customization of the log levels + // so we make a custom configuration source for them based on the legacy values and allow overrides from + // the new location. + loggingConfiguration = new ConfigurationBuilder() + .AddInMemoryCollection(new Dictionary + { + {"PathFormat", pathFormat}, + {"FileSizeLimitBytes", loggingOptions.LogRollBySizeLimit?.ToString(CultureInfo.InvariantCulture)} + }) + .AddConfiguration(context.Configuration.GetSection("Logging")) + .Build(); } + + logging.AddFile(loggingConfiguration); }); } diff --git a/src/Events/appsettings.Production.json b/src/Events/appsettings.Production.json index 010f02f8cd..9a10621264 100644 --- a/src/Events/appsettings.Production.json +++ b/src/Events/appsettings.Production.json @@ -17,11 +17,9 @@ } }, "Logging": { - "IncludeScopes": false, "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" + "Default": "Information", + "Microsoft.AspNetCore": "Warning" }, "Console": { "IncludeScopes": true, diff --git a/src/EventsProcessor/appsettings.Production.json b/src/EventsProcessor/appsettings.Production.json index 1cce4a9ed3..d57bf98b55 100644 --- a/src/EventsProcessor/appsettings.Production.json +++ b/src/EventsProcessor/appsettings.Production.json @@ -1,10 +1,8 @@ { "Logging": { - "IncludeScopes": false, "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" + "Default": "Information", + "Microsoft.AspNetCore": "Warning" }, "Console": { "IncludeScopes": true, diff --git a/src/Icons/appsettings.Production.json b/src/Icons/appsettings.Production.json index 828e8c61cc..19d21f7260 100644 --- a/src/Icons/appsettings.Production.json +++ b/src/Icons/appsettings.Production.json @@ -17,11 +17,9 @@ } }, "Logging": { - "IncludeScopes": false, "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" + "Default": "Information", + "Microsoft.AspNetCore": "Warning" }, "Console": { "IncludeScopes": true, diff --git a/src/Identity/appsettings.Production.json b/src/Identity/appsettings.Production.json index 4897a7d8b1..14471b5fb6 100644 --- a/src/Identity/appsettings.Production.json +++ b/src/Identity/appsettings.Production.json @@ -20,11 +20,9 @@ } }, "Logging": { - "IncludeScopes": false, "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" + "Default": "Information", + "Microsoft.AspNetCore": "Warning" }, "Console": { "IncludeScopes": true, diff --git a/src/Notifications/appsettings.Production.json b/src/Notifications/appsettings.Production.json index 010f02f8cd..735c70e481 100644 --- a/src/Notifications/appsettings.Production.json +++ b/src/Notifications/appsettings.Production.json @@ -17,11 +17,9 @@ } }, "Logging": { - "IncludeScopes": false, "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" + "Default": "Information", + "Microsoft": "Warning" }, "Console": { "IncludeScopes": true, diff --git a/test/Core.Test/Utilities/LoggerFactoryExtensionsTests.cs b/test/Core.Test/Utilities/LoggerFactoryExtensionsTests.cs index 81311cb802..ffeb3fa2e7 100644 --- a/test/Core.Test/Utilities/LoggerFactoryExtensionsTests.cs +++ b/test/Core.Test/Utilities/LoggerFactoryExtensionsTests.cs @@ -74,8 +74,7 @@ public class LoggerFactoryExtensionsTests logger.LogWarning("This is a test"); - // Writing to the file is buffered, give it a little time to flush - await Task.Delay(5); + await provider.DisposeAsync(); var logFile = Assert.Single(tempDir.EnumerateFiles("Logs/*.log")); @@ -90,13 +89,67 @@ public class LoggerFactoryExtensionsTests logFileContents ); } + + [Fact] + public async Task AddSerilogFileLogging_LegacyConfig_WithLevelCustomization_InfoLogs_DoNotFillUpFile() + { + await AssertSmallFileAsync((tempDir, config) => + { + config["GlobalSettings:LogDirectory"] = tempDir; + config["Logging:LogLevel:Microsoft.AspNetCore"] = "Warning"; + }); + } + + [Fact] + public async Task AddSerilogFileLogging_NewConfig_WithLevelCustomization_InfoLogs_DoNotFillUpFile() + { + await AssertSmallFileAsync((tempDir, config) => + { + config["Logging:PathFormat"] = Path.Combine(tempDir, "log.txt"); + config["Logging:LogLevel:Microsoft.AspNetCore"] = "Warning"; + }); + } + + private static async Task AssertSmallFileAsync(Action> configure) + { + using var tempDir = new TempDirectory(); + var config = new Dictionary(); + + configure(tempDir.Directory, config); + + var provider = GetServiceProvider(config, "Production"); + + var loggerFactory = provider.GetRequiredService(); + var microsoftLogger = loggerFactory.CreateLogger("Microsoft.AspNetCore.Testing"); + + for (var i = 0; i < 100; i++) + { + microsoftLogger.LogInformation("Tons of useless information"); + } + + var otherLogger = loggerFactory.CreateLogger("Bitwarden"); + + for (var i = 0; i < 5; i++) + { + otherLogger.LogInformation("Mildly more useful information but not as frequent."); + } + + await provider.DisposeAsync(); + + var logFiles = Directory.EnumerateFiles(tempDir.Directory, "*.txt", SearchOption.AllDirectories); + var logFile = Assert.Single(logFiles); + + using var fr = File.OpenRead(logFile); + Assert.InRange(fr.Length, 0, 1024); + } + private static IEnumerable GetProviders(Dictionary initialData, string environment = "Production") { var provider = GetServiceProvider(initialData, environment); return provider.GetServices(); } - private static IServiceProvider GetServiceProvider(Dictionary initialData, string environment) + private static ServiceProvider GetServiceProvider(Dictionary initialData, string environment) { var config = new ConfigurationBuilder() .AddInMemoryCollection(initialData) From 8d30fbcc8abe424e604073e721ceaeb365c4fba1 Mon Sep 17 00:00:00 2001 From: Stephon Brown Date: Fri, 16 Jan 2026 18:13:57 -0500 Subject: [PATCH 04/12] Billing/pm 30882/defect pm coupon removed on upgrade (#6863) * fix(billing): update coupon check logic * tests(billing): update tests and add plan check test --- .../SubscriptionUpdatedHandler.cs | 11 ++- .../SubscriptionUpdatedHandlerTests.cs | 89 +++++++++++++++++++ 2 files changed, 98 insertions(+), 2 deletions(-) diff --git a/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs b/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs index c10368d8c0..9e20bd3191 100644 --- a/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs +++ b/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs @@ -275,17 +275,24 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler .PreviousAttributes .ToObject() as Subscription; + // Get all plan IDs that include Secrets Manager support to check if the organization has secret manager in the + // previous and/or current subscriptions. + var planIdsOfPlansWithSecretManager = (await _pricingClient.ListPlans()) + .Where(orgPlan => orgPlan.SupportsSecretsManager && orgPlan.SecretsManager.StripeSeatPlanId != null) + .Select(orgPlan => orgPlan.SecretsManager.StripeSeatPlanId) + .ToHashSet(); + // This being false doesn't necessarily mean that the organization doesn't subscribe to Secrets Manager. // If there are changes to any subscription item, Stripe sends every item in the subscription, both // changed and unchanged. var previousSubscriptionHasSecretsManager = previousSubscription?.Items is not null && previousSubscription.Items.Any( - previousSubscriptionItem => previousSubscriptionItem.Plan.Id == plan.SecretsManager.StripeSeatPlanId); + previousSubscriptionItem => planIdsOfPlansWithSecretManager.Contains(previousSubscriptionItem.Plan.Id)); var currentSubscriptionHasSecretsManager = subscription.Items.Any( - currentSubscriptionItem => currentSubscriptionItem.Plan.Id == plan.SecretsManager.StripeSeatPlanId); + currentSubscriptionItem => planIdsOfPlansWithSecretManager.Contains(currentSubscriptionItem.Plan.Id)); if (!previousSubscriptionHasSecretsManager || currentSubscriptionHasSecretsManager) { diff --git a/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs b/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs index 182f09e163..2259d846b7 100644 --- a/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs +++ b/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs @@ -11,6 +11,7 @@ using Bit.Core.Billing.Pricing; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; +using Bit.Core.Test.Billing.Mocks; using Bit.Core.Test.Billing.Mocks.Plans; using Microsoft.Extensions.Logging; using Newtonsoft.Json.Linq; @@ -654,6 +655,8 @@ public class SubscriptionUpdatedHandlerTests var plan = new Enterprise2023Plan(true); _pricingClient.GetPlanOrThrow(organization.PlanType) .Returns(plan); + _pricingClient.ListPlans() + .Returns(MockPlans.Plans); var parsedEvent = new Event { @@ -693,6 +696,92 @@ public class SubscriptionUpdatedHandlerTests await _stripeFacade.Received(1).DeleteCustomerDiscount(subscription.CustomerId); await _stripeFacade.Received(1).DeleteSubscriptionDiscount(subscription.Id); } + [Fact] + public async Task + HandleAsync_WhenUpgradingPlan_AndPreviousPlanHasSecretsManagerTrial_AndCurrentPlanHasSecretsManagerTrial_DoesNotRemovePasswordManagerCoupon() + { + // Arrange + var organizationId = Guid.NewGuid(); + var subscription = new Subscription + { + Id = "sub_123", + Status = StripeSubscriptionStatus.Active, + CustomerId = "cus_123", + Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = DateTime.UtcNow.AddDays(10), + Plan = new Plan { Id = "2023-enterprise-org-seat-annually" } + }, + new SubscriptionItem + { + CurrentPeriodEnd = DateTime.UtcNow.AddDays(10), + Plan = new Plan { Id = "secrets-manager-enterprise-seat-annually" } + } + ] + }, + Customer = new Customer + { + Balance = 0, + Discount = new Discount { Coupon = new Coupon { Id = "sm-standalone" } } + }, + Discounts = [new Discount { Coupon = new Coupon { Id = "sm-standalone" } }], + Metadata = new Dictionary { { "organizationId", organizationId.ToString() } } + }; + + // Note: The organization plan is still the previous plan because the subscription is updated before the organization is updated + var organization = new Organization { Id = organizationId, PlanType = PlanType.TeamsAnnually2023 }; + + var plan = new Teams2023Plan(true); + _pricingClient.GetPlanOrThrow(organization.PlanType) + .Returns(plan); + _pricingClient.ListPlans() + .Returns(MockPlans.Plans); + + var parsedEvent = new Event + { + Data = new EventData + { + Object = subscription, + PreviousAttributes = JObject.FromObject(new + { + items = new + { + data = new[] + { + new { plan = new { id = "secrets-manager-teams-seat-annually" } }, + } + }, + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Plan = new Stripe.Plan { Id = "secrets-manager-teams-seat-annually" } }, + ] + } + }) + } + }; + + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) + .Returns(subscription); + + _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) + .Returns(Tuple.Create(organizationId, null, null)); + + _organizationRepository.GetByIdAsync(organizationId) + .Returns(organization); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _stripeFacade.DidNotReceive().DeleteCustomerDiscount(subscription.CustomerId); + await _stripeFacade.DidNotReceive().DeleteSubscriptionDiscount(subscription.Id); + } [Theory] [MemberData(nameof(GetNonActiveSubscriptions))] From ad19efcff7a4dacb3d538689479867dd780e14eb Mon Sep 17 00:00:00 2001 From: Thomas Rittson <31796059+eliykat@users.noreply.github.com> Date: Sat, 17 Jan 2026 10:47:21 +1000 Subject: [PATCH 05/12] [PM-22236] Fix invited accounts stuck in intermediate claimed status (#6810) * Exclude invited users from claimed domain checks. These users should be excluded by the JOIN on UserId, but it's a known issue that some invited users have this FK set. --- .../Repositories/IOrganizationRepository.cs | 4 +- .../Repositories/OrganizationRepository.cs | 3 +- ...erReadByClaimedOrganizationDomainsQuery.cs | 2 + ...dByOrganizationIdWithClaimedDomains_V2.sql | 3 +- ...anization_ReadByClaimedUserEmailDomain.sql | 3 +- .../GetByVerifiedUserEmailDomainAsyncTests.cs | 335 ++++++++++++++++++ .../OrganizationRepositoryTests.cs | 267 +------------- ...rganizationWithClaimedDomainsAsyncTests.cs | 197 ++++++++++ .../OrganizationUserRepositoryTests.cs | 194 ---------- ...0_ExcludeInvitedUsersFromClaimedDomain.sql | 24 ++ ...cludeInvitedUsersFromClaimedDomains_V2.sql | 29 ++ 11 files changed, 606 insertions(+), 455 deletions(-) create mode 100644 test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationRepository/GetByVerifiedUserEmailDomainAsyncTests.cs create mode 100644 test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationUserRepository/GetManyByOrganizationWithClaimedDomainsAsyncTests.cs create mode 100644 util/Migrator/DbScripts/2026-01-14_00_ExcludeInvitedUsersFromClaimedDomain.sql create mode 100644 util/Migrator/DbScripts/2026-01-14_01_ExcludeInvitedUsersFromClaimedDomains_V2.sql diff --git a/src/Core/AdminConsole/Repositories/IOrganizationRepository.cs b/src/Core/AdminConsole/Repositories/IOrganizationRepository.cs index da7a77000b..d79923fdd1 100644 --- a/src/Core/AdminConsole/Repositories/IOrganizationRepository.cs +++ b/src/Core/AdminConsole/Repositories/IOrganizationRepository.cs @@ -21,7 +21,9 @@ public interface IOrganizationRepository : IRepository Task> GetOwnerEmailAddressesById(Guid organizationId); /// - /// Gets the organizations that have a verified domain matching the user's email domain. + /// Gets the organizations that have claimed the user's account. Currently, only one organization may claim a user. + /// This requires that the organization has claimed the user's domain and the user is an organization member. + /// It excludes invited members. /// Task> GetByVerifiedUserEmailDomainAsync(Guid userId); diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs index 88410facf5..93c8cd304c 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs @@ -325,7 +325,8 @@ public class OrganizationRepository : Repository od.OrganizationId == _organizationId && od.VerifiedDate != null && diff --git a/src/Sql/dbo/Stored Procedures/OrganizationUser_ReadByOrganizationIdWithClaimedDomains_V2.sql b/src/Sql/dbo/Stored Procedures/OrganizationUser_ReadByOrganizationIdWithClaimedDomains_V2.sql index 64f3d81e08..4f781d2cc9 100644 --- a/src/Sql/dbo/Stored Procedures/OrganizationUser_ReadByOrganizationIdWithClaimedDomains_V2.sql +++ b/src/Sql/dbo/Stored Procedures/OrganizationUser_ReadByOrganizationIdWithClaimedDomains_V2.sql @@ -8,13 +8,14 @@ BEGIN SELECT * FROM [dbo].[OrganizationUserView] WHERE [OrganizationId] = @OrganizationId + AND [Status] != 0 -- Exclude invited users ), UserDomains AS ( SELECT U.[Id], U.[EmailDomain] FROM [dbo].[UserEmailDomainView] U WHERE EXISTS ( SELECT 1 - FROM [dbo].[OrganizationDomainView] OD + FROM [dbo].[OrganizationDomainView] OD WHERE OD.[OrganizationId] = @OrganizationId AND OD.[VerifiedDate] IS NOT NULL AND OD.[DomainName] = U.[EmailDomain] diff --git a/src/Sql/dbo/Stored Procedures/Organization_ReadByClaimedUserEmailDomain.sql b/src/Sql/dbo/Stored Procedures/Organization_ReadByClaimedUserEmailDomain.sql index 583f548c8b..ee14c2c52a 100644 --- a/src/Sql/dbo/Stored Procedures/Organization_ReadByClaimedUserEmailDomain.sql +++ b/src/Sql/dbo/Stored Procedures/Organization_ReadByClaimedUserEmailDomain.sql @@ -6,7 +6,7 @@ BEGIN WITH CTE_User AS ( SELECT - U.*, + U.[Id], SUBSTRING(U.Email, CHARINDEX('@', U.Email) + 1, LEN(U.Email)) AS EmailDomain FROM dbo.[UserView] U WHERE U.[Id] = @UserId @@ -19,4 +19,5 @@ BEGIN WHERE OD.[VerifiedDate] IS NOT NULL AND CU.EmailDomain = OD.[DomainName] AND O.[Enabled] = 1 + AND OU.[Status] != 0 -- Exclude invited users END diff --git a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationRepository/GetByVerifiedUserEmailDomainAsyncTests.cs b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationRepository/GetByVerifiedUserEmailDomainAsyncTests.cs new file mode 100644 index 0000000000..6dd7aafca4 --- /dev/null +++ b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationRepository/GetByVerifiedUserEmailDomainAsyncTests.cs @@ -0,0 +1,335 @@ +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Repositories; +using Xunit; + +namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.OrganizationRepository; + +public class GetByVerifiedUserEmailDomainAsyncTests +{ + [Theory, DatabaseData] + public async Task GetByClaimedUserDomainAsync_WithVerifiedDomain_Success( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationDomainRepository organizationDomainRepository) + { + var id = Guid.NewGuid(); + var domainName = $"{id}.example.com"; + + var user1 = await userRepository.CreateAsync(new User + { + Name = "Test User 1", + Email = $"test+{id}@{domainName}", + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var user2 = await userRepository.CreateAsync(new User + { + Name = "Test User 2", + Email = $"test+{id}@x-{domainName}", // Different domain + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var user3 = await userRepository.CreateAsync(new User + { + Name = "Test User 2", + Email = $"test+{id}@{domainName}.example.com", // Different domain + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var organization = await organizationRepository.CreateTestOrganizationAsync(); + + var organizationDomain = new OrganizationDomain + { + OrganizationId = organization.Id, + DomainName = domainName, + Txt = "btw+12345", + }; + organizationDomain.SetVerifiedDate(); + organizationDomain.SetNextRunDate(12); + organizationDomain.SetJobRunCount(); + await organizationDomainRepository.CreateAsync(organizationDomain); + + await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, user1); + await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, user2); + await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, user3); + + var user1Response = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user1.Id); + var user2Response = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user2.Id); + var user3Response = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user3.Id); + + Assert.NotEmpty(user1Response); + Assert.Equal(organization.Id, user1Response.First().Id); + Assert.Empty(user2Response); + Assert.Empty(user3Response); + } + + [Theory, DatabaseData] + public async Task GetByVerifiedUserEmailDomainAsync_WithUnverifiedDomains_ReturnsEmpty( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationDomainRepository organizationDomainRepository) + { + var id = Guid.NewGuid(); + var domainName = $"{id}.example.com"; + + var user = await userRepository.CreateAsync(new User + { + Name = "Test User", + Email = $"test+{id}@{domainName}", + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var organization = await organizationRepository.CreateTestOrganizationAsync(); + + var organizationDomain = new OrganizationDomain + { + OrganizationId = organization.Id, + DomainName = domainName, + Txt = "btw+12345", + }; + organizationDomain.SetNextRunDate(12); + organizationDomain.SetJobRunCount(); + await organizationDomainRepository.CreateAsync(organizationDomain); + + await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, user); + + var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user.Id); + + Assert.Empty(result); + } + + [Theory, DatabaseData] + public async Task GetByVerifiedUserEmailDomainAsync_WithMultipleVerifiedDomains_ReturnsAllMatchingOrganizations( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationDomainRepository organizationDomainRepository) + { + var id = Guid.NewGuid(); + var domainName = $"{id}.example.com"; + + var user = await userRepository.CreateAsync(new User + { + Name = "Test User", + Email = $"test+{id}@{domainName}", + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var organization1 = await organizationRepository.CreateTestOrganizationAsync(); + var organization2 = await organizationRepository.CreateTestOrganizationAsync(); + + var organizationDomain1 = new OrganizationDomain + { + OrganizationId = organization1.Id, + DomainName = domainName, + Txt = "btw+12345", + }; + organizationDomain1.SetNextRunDate(12); + organizationDomain1.SetJobRunCount(); + organizationDomain1.SetVerifiedDate(); + await organizationDomainRepository.CreateAsync(organizationDomain1); + + var organizationDomain2 = new OrganizationDomain + { + OrganizationId = organization2.Id, + DomainName = domainName, + Txt = "btw+67890", + }; + organizationDomain2.SetNextRunDate(12); + organizationDomain2.SetJobRunCount(); + organizationDomain2.SetVerifiedDate(); + await organizationDomainRepository.CreateAsync(organizationDomain2); + + await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization1, user); + await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization2, user); + + var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user.Id); + + Assert.Equal(2, result.Count); + Assert.Contains(result, org => org.Id == organization1.Id); + Assert.Contains(result, org => org.Id == organization2.Id); + } + + [Theory, DatabaseData] + public async Task GetByVerifiedUserEmailDomainAsync_WithNonExistentUser_ReturnsEmpty( + IOrganizationRepository organizationRepository) + { + var nonExistentUserId = Guid.NewGuid(); + + var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(nonExistentUserId); + + Assert.Empty(result); + } + + /// + /// Tests an edge case where some invited users are created linked to a UserId. + /// This is defective behavior, but will take longer to fix - for now, we are defensive and expressly + /// exclude such users from the results without relying on the inner join only. + /// Invited-revoked users linked to a UserId remain intentionally unhandled for now as they have not caused + /// any issues to date and we want to minimize edge cases. + /// We will fix the underlying issue going forward: https://bitwarden.atlassian.net/browse/PM-22405 + /// + [Theory, DatabaseData] + public async Task GetByVerifiedUserEmailDomainAsync_WithInvitedUserWithUserId_ReturnsEmpty( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationDomainRepository organizationDomainRepository) + { + var id = Guid.NewGuid(); + var domainName = $"{id}.example.com"; + + var user = await userRepository.CreateAsync(new User + { + Name = "Test User", + Email = $"test+{id}@{domainName}", + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var organization = await organizationRepository.CreateTestOrganizationAsync(); + + var organizationDomain = new OrganizationDomain + { + OrganizationId = organization.Id, + DomainName = domainName, + Txt = "btw+12345", + }; + organizationDomain.SetVerifiedDate(); + organizationDomain.SetNextRunDate(12); + organizationDomain.SetJobRunCount(); + await organizationDomainRepository.CreateAsync(organizationDomain); + + // Create invited user with matching email domain but UserId set (edge case) + await organizationUserRepository.CreateAsync(new OrganizationUser + { + OrganizationId = organization.Id, + UserId = user.Id, + Email = user.Email, + Status = OrganizationUserStatusType.Invited, + }); + + var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user.Id); + + // Invited users should be excluded even if they have UserId set + Assert.Empty(result); + } + + [Theory, DatabaseData] + public async Task GetByVerifiedUserEmailDomainAsync_WithAcceptedUser_ReturnsOrganization( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationDomainRepository organizationDomainRepository) + { + var id = Guid.NewGuid(); + var domainName = $"{id}.example.com"; + + var user = await userRepository.CreateAsync(new User + { + Name = "Test User", + Email = $"test+{id}@{domainName}", + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var organization = await organizationRepository.CreateTestOrganizationAsync(); + + var organizationDomain = new OrganizationDomain + { + OrganizationId = organization.Id, + DomainName = domainName, + Txt = "btw+12345", + }; + organizationDomain.SetVerifiedDate(); + organizationDomain.SetNextRunDate(12); + organizationDomain.SetJobRunCount(); + await organizationDomainRepository.CreateAsync(organizationDomain); + + await organizationUserRepository.CreateAcceptedTestOrganizationUserAsync(organization, user); + + var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user.Id); + + Assert.NotEmpty(result); + Assert.Equal(organization.Id, result.First().Id); + } + + [Theory, DatabaseData] + public async Task GetByVerifiedUserEmailDomainAsync_WithRevokedUser_ReturnsOrganization( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationDomainRepository organizationDomainRepository) + { + var id = Guid.NewGuid(); + var domainName = $"{id}.example.com"; + + var user = await userRepository.CreateAsync(new User + { + Name = "Test User", + Email = $"test+{id}@{domainName}", + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var organization = await organizationRepository.CreateTestOrganizationAsync(); + + var organizationDomain = new OrganizationDomain + { + OrganizationId = organization.Id, + DomainName = domainName, + Txt = "btw+12345", + }; + organizationDomain.SetVerifiedDate(); + organizationDomain.SetNextRunDate(12); + organizationDomain.SetJobRunCount(); + await organizationDomainRepository.CreateAsync(organizationDomain); + + await organizationUserRepository.CreateRevokedTestOrganizationUserAsync(organization, user); + + var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user.Id); + + Assert.NotEmpty(result); + Assert.Equal(organization.Id, result.First().Id); + } +} diff --git a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationRepositoryTests.cs b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationRepositoryTests.cs index 67e2c1910b..52b1e7484b 100644 --- a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationRepositoryTests.cs +++ b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationRepositoryTests.cs @@ -8,254 +8,7 @@ namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories; public class OrganizationRepositoryTests { - [DatabaseTheory, DatabaseData] - public async Task GetByClaimedUserDomainAsync_WithVerifiedDomain_Success( - IUserRepository userRepository, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IOrganizationDomainRepository organizationDomainRepository) - { - var id = Guid.NewGuid(); - var domainName = $"{id}.example.com"; - - var user1 = await userRepository.CreateAsync(new User - { - Name = "Test User 1", - Email = $"test+{id}@{domainName}", - ApiKey = "TEST", - SecurityStamp = "stamp", - Kdf = KdfType.PBKDF2_SHA256, - KdfIterations = 1, - KdfMemory = 2, - KdfParallelism = 3 - }); - - var user2 = await userRepository.CreateAsync(new User - { - Name = "Test User 2", - Email = $"test+{id}@x-{domainName}", // Different domain - ApiKey = "TEST", - SecurityStamp = "stamp", - Kdf = KdfType.PBKDF2_SHA256, - KdfIterations = 1, - KdfMemory = 2, - KdfParallelism = 3 - }); - - var user3 = await userRepository.CreateAsync(new User - { - Name = "Test User 2", - Email = $"test+{id}@{domainName}.example.com", // Different domain - ApiKey = "TEST", - SecurityStamp = "stamp", - Kdf = KdfType.PBKDF2_SHA256, - KdfIterations = 1, - KdfMemory = 2, - KdfParallelism = 3 - }); - - var organization = await organizationRepository.CreateAsync(new Organization - { - Name = $"Test Org {id}", - BillingEmail = user1.Email, // TODO: EF does not enforce this being NOT NULL - Plan = "Test", // TODO: EF does not enforce this being NOT NULL - PrivateKey = "privatekey", - }); - - var organizationDomain = new OrganizationDomain - { - OrganizationId = organization.Id, - DomainName = domainName, - Txt = "btw+12345", - }; - organizationDomain.SetVerifiedDate(); - organizationDomain.SetNextRunDate(12); - organizationDomain.SetJobRunCount(); - await organizationDomainRepository.CreateAsync(organizationDomain); - - await organizationUserRepository.CreateAsync(new OrganizationUser - { - OrganizationId = organization.Id, - UserId = user1.Id, - Status = OrganizationUserStatusType.Confirmed, - ResetPasswordKey = "resetpasswordkey1", - }); - - await organizationUserRepository.CreateAsync(new OrganizationUser - { - OrganizationId = organization.Id, - UserId = user2.Id, - Status = OrganizationUserStatusType.Confirmed, - ResetPasswordKey = "resetpasswordkey1", - }); - - await organizationUserRepository.CreateAsync(new OrganizationUser - { - OrganizationId = organization.Id, - UserId = user3.Id, - Status = OrganizationUserStatusType.Confirmed, - ResetPasswordKey = "resetpasswordkey1", - }); - - var user1Response = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user1.Id); - var user2Response = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user2.Id); - var user3Response = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user3.Id); - - Assert.NotEmpty(user1Response); - Assert.Equal(organization.Id, user1Response.First().Id); - Assert.Empty(user2Response); - Assert.Empty(user3Response); - } - - [DatabaseTheory, DatabaseData] - public async Task GetByVerifiedUserEmailDomainAsync_WithUnverifiedDomains_ReturnsEmpty( - IUserRepository userRepository, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IOrganizationDomainRepository organizationDomainRepository) - { - var id = Guid.NewGuid(); - var domainName = $"{id}.example.com"; - - var user = await userRepository.CreateAsync(new User - { - Name = "Test User", - Email = $"test+{id}@{domainName}", - ApiKey = "TEST", - SecurityStamp = "stamp", - Kdf = KdfType.PBKDF2_SHA256, - KdfIterations = 1, - KdfMemory = 2, - KdfParallelism = 3 - }); - - var organization = await organizationRepository.CreateAsync(new Organization - { - Name = $"Test Org {id}", - BillingEmail = user.Email, - Plan = "Test", - PrivateKey = "privatekey", - }); - - var organizationDomain = new OrganizationDomain - { - OrganizationId = organization.Id, - DomainName = domainName, - Txt = "btw+12345", - }; - organizationDomain.SetNextRunDate(12); - organizationDomain.SetJobRunCount(); - await organizationDomainRepository.CreateAsync(organizationDomain); - - await organizationUserRepository.CreateAsync(new OrganizationUser - { - OrganizationId = organization.Id, - UserId = user.Id, - Status = OrganizationUserStatusType.Confirmed, - ResetPasswordKey = "resetpasswordkey", - }); - - var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user.Id); - - Assert.Empty(result); - } - - [DatabaseTheory, DatabaseData] - public async Task GetByVerifiedUserEmailDomainAsync_WithMultipleVerifiedDomains_ReturnsAllMatchingOrganizations( - IUserRepository userRepository, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IOrganizationDomainRepository organizationDomainRepository) - { - var id = Guid.NewGuid(); - var domainName = $"{id}.example.com"; - - var user = await userRepository.CreateAsync(new User - { - Name = "Test User", - Email = $"test+{id}@{domainName}", - ApiKey = "TEST", - SecurityStamp = "stamp", - Kdf = KdfType.PBKDF2_SHA256, - KdfIterations = 1, - KdfMemory = 2, - KdfParallelism = 3 - }); - - var organization1 = await organizationRepository.CreateAsync(new Organization - { - Name = $"Test Org 1 {id}", - BillingEmail = user.Email, - Plan = "Test", - PrivateKey = "privatekey1", - }); - - var organization2 = await organizationRepository.CreateAsync(new Organization - { - Name = $"Test Org 2 {id}", - BillingEmail = user.Email, - Plan = "Test", - PrivateKey = "privatekey2", - }); - - var organizationDomain1 = new OrganizationDomain - { - OrganizationId = organization1.Id, - DomainName = domainName, - Txt = "btw+12345", - }; - organizationDomain1.SetNextRunDate(12); - organizationDomain1.SetJobRunCount(); - organizationDomain1.SetVerifiedDate(); - await organizationDomainRepository.CreateAsync(organizationDomain1); - - var organizationDomain2 = new OrganizationDomain - { - OrganizationId = organization2.Id, - DomainName = domainName, - Txt = "btw+67890", - }; - organizationDomain2.SetNextRunDate(12); - organizationDomain2.SetJobRunCount(); - organizationDomain2.SetVerifiedDate(); - await organizationDomainRepository.CreateAsync(organizationDomain2); - - await organizationUserRepository.CreateAsync(new OrganizationUser - { - OrganizationId = organization1.Id, - UserId = user.Id, - Status = OrganizationUserStatusType.Confirmed, - ResetPasswordKey = "resetpasswordkey1", - }); - - await organizationUserRepository.CreateAsync(new OrganizationUser - { - OrganizationId = organization2.Id, - UserId = user.Id, - Status = OrganizationUserStatusType.Confirmed, - ResetPasswordKey = "resetpasswordkey2", - }); - - var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user.Id); - - Assert.Equal(2, result.Count); - Assert.Contains(result, org => org.Id == organization1.Id); - Assert.Contains(result, org => org.Id == organization2.Id); - } - - [DatabaseTheory, DatabaseData] - public async Task GetByVerifiedUserEmailDomainAsync_WithNonExistentUser_ReturnsEmpty( - IOrganizationRepository organizationRepository) - { - var nonExistentUserId = Guid.NewGuid(); - - var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(nonExistentUserId); - - Assert.Empty(result); - } - - - [DatabaseTheory, DatabaseData] + [Theory, DatabaseData] public async Task GetManyByIdsAsync_ExistingOrganizations_ReturnsOrganizations(IOrganizationRepository organizationRepository) { var email = "test@email.com"; @@ -287,7 +40,7 @@ public class OrganizationRepositoryTests await organizationRepository.DeleteAsync(organization2); } - [DatabaseTheory, DatabaseData] + [Theory, DatabaseData] public async Task GetOccupiedSeatCountByOrganizationIdAsync_WithUsersAndSponsorships_ReturnsCorrectCounts( IUserRepository userRepository, IOrganizationRepository organizationRepository, @@ -356,7 +109,7 @@ public class OrganizationRepositoryTests Assert.Equal(4, result.Total); // Total occupied seats } - [DatabaseTheory, DatabaseData] + [Theory, DatabaseData] public async Task GetOccupiedSeatCountByOrganizationIdAsync_WithNoUsersOrSponsorships_ReturnsZero( IOrganizationRepository organizationRepository) { @@ -372,7 +125,7 @@ public class OrganizationRepositoryTests Assert.Equal(0, result.Total); } - [DatabaseTheory, DatabaseData] + [Theory, DatabaseData] public async Task GetOccupiedSeatCountByOrganizationIdAsync_WithOnlyRevokedUsers_ReturnsZero( IUserRepository userRepository, IOrganizationRepository organizationRepository, @@ -399,7 +152,7 @@ public class OrganizationRepositoryTests Assert.Equal(0, result.Total); } - [DatabaseTheory, DatabaseData] + [Theory, DatabaseData] public async Task GetOccupiedSeatCountByOrganizationIdAsync_WithOnlyExpiredSponsorships_ReturnsZero( IOrganizationRepository organizationRepository, IOrganizationSponsorshipRepository organizationSponsorshipRepository) @@ -424,7 +177,7 @@ public class OrganizationRepositoryTests Assert.Equal(0, result.Total); } - [DatabaseTheory, DatabaseData] + [Theory, DatabaseData] public async Task IncrementSeatCountAsync_IncrementsSeatCount(IOrganizationRepository organizationRepository) { var organization = await organizationRepository.CreateTestOrganizationAsync(); @@ -438,7 +191,7 @@ public class OrganizationRepositoryTests Assert.Equal(8, result.Seats); } - [DatabaseData, DatabaseTheory] + [DatabaseData, Theory] public async Task IncrementSeatCountAsync_GivenOrganizationHasNotChangedSeatCountBefore_WhenUpdatingOrgSeats_ThenSubscriptionUpdateIsSaved( IOrganizationRepository sutRepository) { @@ -462,7 +215,7 @@ public class OrganizationRepositoryTests await sutRepository.DeleteAsync(organization); } - [DatabaseData, DatabaseTheory] + [DatabaseData, Theory] public async Task IncrementSeatCountAsync_GivenOrganizationHasChangedSeatCountBeforeAndRecordExists_WhenUpdatingOrgSeats_ThenSubscriptionUpdateIsSaved( IOrganizationRepository sutRepository) { @@ -487,7 +240,7 @@ public class OrganizationRepositoryTests await sutRepository.DeleteAsync(organization); } - [DatabaseData, DatabaseTheory] + [DatabaseData, Theory] public async Task GetOrganizationsForSubscriptionSyncAsync_GivenOrganizationHasChangedSeatCount_WhenGettingOrgsToUpdate_ThenReturnsOrgSubscriptionUpdate( IOrganizationRepository sutRepository) { @@ -510,7 +263,7 @@ public class OrganizationRepositoryTests await sutRepository.DeleteAsync(organization); } - [DatabaseData, DatabaseTheory] + [DatabaseData, Theory] public async Task UpdateSuccessfulOrganizationSyncStatusAsync_GivenOrganizationHasChangedSeatCount_WhenUpdatingStatus_ThenSuccessfullyUpdatesOrgSoItDoesntSync( IOrganizationRepository sutRepository) { diff --git a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationUserRepository/GetManyByOrganizationWithClaimedDomainsAsyncTests.cs b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationUserRepository/GetManyByOrganizationWithClaimedDomainsAsyncTests.cs new file mode 100644 index 0000000000..6fa395751b --- /dev/null +++ b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationUserRepository/GetManyByOrganizationWithClaimedDomainsAsyncTests.cs @@ -0,0 +1,197 @@ +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Repositories; +using Xunit; + +namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.OrganizationUserRepository; + +public class GetManyByOrganizationWithClaimedDomainsAsyncTests +{ + [Theory, DatabaseData] + public async Task WithVerifiedDomain_WithOneMatchingEmailDomain_ReturnsSingle( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationDomainRepository organizationDomainRepository) + { + var id = Guid.NewGuid(); + var domainName = $"{id}.example.com"; + + var user1 = await userRepository.CreateAsync(new User + { + Name = "Test User 1", + Email = $"test+{id}@{domainName}", + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var user2 = await userRepository.CreateAsync(new User + { + Name = "Test User 2", + Email = $"test+{id}@x-{domainName}", // Different domain + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var user3 = await userRepository.CreateAsync(new User + { + Name = "Test User 3", + Email = $"test+{id}@{domainName}.example.com", // Different domain + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var organization = await organizationRepository.CreateTestOrganizationAsync(); + + var organizationDomain = new OrganizationDomain + { + OrganizationId = organization.Id, + DomainName = domainName, + Txt = "btw+12345", + }; + organizationDomain.SetVerifiedDate(); + organizationDomain.SetNextRunDate(12); + organizationDomain.SetJobRunCount(); + await organizationDomainRepository.CreateAsync(organizationDomain); + + var orgUser1 = await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, user1); + await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, user2); + await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, user3); + + var result = await organizationUserRepository.GetManyByOrganizationWithClaimedDomainsAsync(organization.Id); + + Assert.NotNull(result); + Assert.Single(result); + Assert.Equal(orgUser1.Id, result.Single().Id); + } + + [Theory, DatabaseData] + public async Task WithNoVerifiedDomain_ReturnsEmpty( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationDomainRepository organizationDomainRepository) + { + var id = Guid.NewGuid(); + var domainName = $"{id}.example.com"; + + var user = await userRepository.CreateAsync(new User + { + Name = "Test User", + Email = $"test+{id}@{domainName}", + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var organization = await organizationRepository.CreateTestOrganizationAsync(); + + // Create domain but do NOT verify it + var organizationDomain = new OrganizationDomain + { + OrganizationId = organization.Id, + DomainName = domainName, + Txt = "btw+12345", + }; + organizationDomain.SetNextRunDate(12); + // Note: NOT calling SetVerifiedDate() + await organizationDomainRepository.CreateAsync(organizationDomain); + + await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, user); + + var result = await organizationUserRepository.GetManyByOrganizationWithClaimedDomainsAsync(organization.Id); + + Assert.NotNull(result); + Assert.Empty(result); + } + + /// + /// Tests an edge case where some invited users are created linked to a UserId. + /// This is defective behavior, but will take longer to fix - for now, we are defensive and expressly + /// exclude such users from the results without relying on the inner join only. + /// Invited-revoked users linked to a UserId remain intentionally unhandled for now as they have not caused + /// any issues to date and we want to minimize edge cases. + /// We will fix the underlying issue going forward: https://bitwarden.atlassian.net/browse/PM-22405 + /// + [Theory, DatabaseData] + public async Task WithVerifiedDomain_ExcludesInvitedUsers( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationDomainRepository organizationDomainRepository) + { + var id = Guid.NewGuid(); + var domainName = $"{id}.example.com"; + + var invitedUser = await userRepository.CreateAsync(new User + { + Name = "Invited User", + Email = $"invited+{id}@{domainName}", + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var confirmedUser = await userRepository.CreateAsync(new User + { + Name = "Confirmed User", + Email = $"confirmed+{id}@{domainName}", + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var organization = await organizationRepository.CreateTestOrganizationAsync(); + + var organizationDomain = new OrganizationDomain + { + OrganizationId = organization.Id, + DomainName = domainName, + Txt = "btw+12345", + }; + organizationDomain.SetVerifiedDate(); + organizationDomain.SetNextRunDate(12); + organizationDomain.SetJobRunCount(); + await organizationDomainRepository.CreateAsync(organizationDomain); + + // Create invited user with UserId set (edge case - should be excluded even with UserId linked) + var invitedOrgUser = await organizationUserRepository.CreateAsync(new OrganizationUser + { + OrganizationId = organization.Id, + UserId = invitedUser.Id, // Edge case: invited user with UserId set + Email = invitedUser.Email, + Status = OrganizationUserStatusType.Invited, + Type = OrganizationUserType.User + }); + + // Create confirmed user linked by UserId only (no Email field set) + var confirmedOrgUser = await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, confirmedUser); + + var result = await organizationUserRepository.GetManyByOrganizationWithClaimedDomainsAsync(organization.Id); + + Assert.NotNull(result); + var claimedUser = Assert.Single(result); + Assert.Equal(confirmedOrgUser.Id, claimedUser.Id); + } +} diff --git a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationUserRepository/OrganizationUserRepositoryTests.cs b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationUserRepository/OrganizationUserRepositoryTests.cs index 1c433d0e6e..b77406abf5 100644 --- a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationUserRepository/OrganizationUserRepositoryTests.cs +++ b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationUserRepository/OrganizationUserRepositoryTests.cs @@ -599,136 +599,6 @@ public class OrganizationUserRepositoryTests Assert.Null(orgWithoutSsoDetails.SsoConfig); } - [DatabaseTheory, DatabaseData] - public async Task GetManyByOrganizationWithClaimedDomainsAsync_WithVerifiedDomain_WithOneMatchingEmailDomain_ReturnsSingle( - IUserRepository userRepository, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IOrganizationDomainRepository organizationDomainRepository) - { - var id = Guid.NewGuid(); - var domainName = $"{id}.example.com"; - - var user1 = await userRepository.CreateAsync(new User - { - Name = "Test User 1", - Email = $"test+{id}@{domainName}", - ApiKey = "TEST", - SecurityStamp = "stamp", - Kdf = KdfType.PBKDF2_SHA256, - KdfIterations = 1, - KdfMemory = 2, - KdfParallelism = 3 - }); - - var user2 = await userRepository.CreateAsync(new User - { - Name = "Test User 2", - Email = $"test+{id}@x-{domainName}", // Different domain - ApiKey = "TEST", - SecurityStamp = "stamp", - Kdf = KdfType.PBKDF2_SHA256, - KdfIterations = 1, - KdfMemory = 2, - KdfParallelism = 3 - }); - - var user3 = await userRepository.CreateAsync(new User - { - Name = "Test User 2", - Email = $"test+{id}@{domainName}.example.com", // Different domain - ApiKey = "TEST", - SecurityStamp = "stamp", - Kdf = KdfType.PBKDF2_SHA256, - KdfIterations = 1, - KdfMemory = 2, - KdfParallelism = 3 - }); - - var organization = await organizationRepository.CreateAsync(new Organization - { - Name = $"Test Org {id}", - BillingEmail = user1.Email, // TODO: EF does not enforce this being NOT NULL - Plan = "Test", // TODO: EF does not enforce this being NOT NULL - PrivateKey = "privatekey", - UsePolicies = false, - UseSso = false, - UseKeyConnector = false, - UseScim = false, - UseGroups = false, - UseDirectory = false, - UseEvents = false, - UseTotp = false, - Use2fa = false, - UseApi = false, - UseResetPassword = false, - UseSecretsManager = false, - SelfHost = false, - UsersGetPremium = false, - UseCustomPermissions = false, - Enabled = true, - UsePasswordManager = false, - LimitCollectionCreation = false, - LimitCollectionDeletion = false, - LimitItemDeletion = false, - AllowAdminAccessToAllCollectionItems = false, - UseRiskInsights = false, - UseAdminSponsoredFamilies = false, - UsePhishingBlocker = false, - UseDisableSmAdsForUsers = false, - }); - - var organizationDomain = new OrganizationDomain - { - OrganizationId = organization.Id, - DomainName = domainName, - Txt = "btw+12345", - }; - organizationDomain.SetVerifiedDate(); - organizationDomain.SetNextRunDate(12); - organizationDomain.SetJobRunCount(); - await organizationDomainRepository.CreateAsync(organizationDomain); - - var orgUser1 = await organizationUserRepository.CreateAsync(new OrganizationUser - { - Id = CoreHelpers.GenerateComb(), - OrganizationId = organization.Id, - UserId = user1.Id, - Status = OrganizationUserStatusType.Confirmed, - Type = OrganizationUserType.Owner, - ResetPasswordKey = "resetpasswordkey1", - AccessSecretsManager = false - }); - - await organizationUserRepository.CreateAsync(new OrganizationUser - { - Id = CoreHelpers.GenerateComb(), - OrganizationId = organization.Id, - UserId = user2.Id, - Status = OrganizationUserStatusType.Confirmed, - Type = OrganizationUserType.User, - ResetPasswordKey = "resetpasswordkey1", - AccessSecretsManager = false - }); - - await organizationUserRepository.CreateAsync(new OrganizationUser - { - Id = CoreHelpers.GenerateComb(), - OrganizationId = organization.Id, - UserId = user3.Id, - Status = OrganizationUserStatusType.Confirmed, - Type = OrganizationUserType.User, - ResetPasswordKey = "resetpasswordkey1", - AccessSecretsManager = false - }); - - var responseModel = await organizationUserRepository.GetManyByOrganizationWithClaimedDomainsAsync(organization.Id); - - Assert.NotNull(responseModel); - Assert.Single(responseModel); - Assert.Equal(orgUser1.Id, responseModel.Single().Id); - } - [DatabaseTheory, DatabaseData] public async Task CreateManyAsync_NoId_Works(IOrganizationRepository organizationRepository, IUserRepository userRepository, @@ -1237,70 +1107,6 @@ public class OrganizationUserRepositoryTests Assert.DoesNotContain(user1Result.Collections, c => c.Id == defaultUserCollection.Id); } - [DatabaseTheory, DatabaseData] - public async Task GetManyByOrganizationWithClaimedDomainsAsync_WithNoVerifiedDomain_ReturnsEmpty( - IUserRepository userRepository, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IOrganizationDomainRepository organizationDomainRepository) - { - var id = Guid.NewGuid(); - var domainName = $"{id}.example.com"; - var requestTime = DateTime.UtcNow; - - var user1 = await userRepository.CreateAsync(new User - { - Id = CoreHelpers.GenerateComb(), - Name = "Test User 1", - Email = $"test+{id}@{domainName}", - ApiKey = "TEST", - SecurityStamp = "stamp", - CreationDate = requestTime, - RevisionDate = requestTime, - AccountRevisionDate = requestTime - }); - - var organization = await organizationRepository.CreateAsync(new Organization - { - Id = CoreHelpers.GenerateComb(), - Name = $"Test Org {id}", - BillingEmail = user1.Email, - Plan = "Test", - Enabled = true, - CreationDate = requestTime, - RevisionDate = requestTime - }); - - // Create domain but do NOT verify it - var organizationDomain = new OrganizationDomain - { - Id = CoreHelpers.GenerateComb(), - OrganizationId = organization.Id, - DomainName = domainName, - Txt = "btw+12345", - CreationDate = requestTime - }; - organizationDomain.SetNextRunDate(12); - // Note: NOT calling SetVerifiedDate() - await organizationDomainRepository.CreateAsync(organizationDomain); - - await organizationUserRepository.CreateAsync(new OrganizationUser - { - Id = CoreHelpers.GenerateComb(), - OrganizationId = organization.Id, - UserId = user1.Id, - Status = OrganizationUserStatusType.Confirmed, - Type = OrganizationUserType.Owner, - CreationDate = requestTime, - RevisionDate = requestTime - }); - - var responseModel = await organizationUserRepository.GetManyByOrganizationWithClaimedDomainsAsync(organization.Id); - - Assert.NotNull(responseModel); - Assert.Empty(responseModel); - } - [DatabaseTheory, DatabaseData] public async Task DeleteAsync_WithNullEmail_DoesNotSetDefaultUserCollectionEmail(IUserRepository userRepository, ICollectionRepository collectionRepository, diff --git a/util/Migrator/DbScripts/2026-01-14_00_ExcludeInvitedUsersFromClaimedDomain.sql b/util/Migrator/DbScripts/2026-01-14_00_ExcludeInvitedUsersFromClaimedDomain.sql new file mode 100644 index 0000000000..788fa02b7c --- /dev/null +++ b/util/Migrator/DbScripts/2026-01-14_00_ExcludeInvitedUsersFromClaimedDomain.sql @@ -0,0 +1,24 @@ +CREATE OR ALTER PROCEDURE [dbo].[Organization_ReadByClaimedUserEmailDomain] + @UserId UNIQUEIDENTIFIER +AS +BEGIN + SET NOCOUNT ON; + + WITH CTE_User AS ( + SELECT + U.[Id], + SUBSTRING(U.Email, CHARINDEX('@', U.Email) + 1, LEN(U.Email)) AS EmailDomain + FROM dbo.[UserView] U + WHERE U.[Id] = @UserId + ) + SELECT O.* + FROM CTE_User CU + INNER JOIN dbo.[OrganizationUserView] OU ON CU.[Id] = OU.[UserId] + INNER JOIN dbo.[OrganizationView] O ON OU.[OrganizationId] = O.[Id] + INNER JOIN dbo.[OrganizationDomainView] OD ON OU.[OrganizationId] = OD.[OrganizationId] + WHERE OD.[VerifiedDate] IS NOT NULL + AND CU.EmailDomain = OD.[DomainName] + AND O.[Enabled] = 1 + AND OU.[Status] != 0 -- Exclude invited users +END +GO diff --git a/util/Migrator/DbScripts/2026-01-14_01_ExcludeInvitedUsersFromClaimedDomains_V2.sql b/util/Migrator/DbScripts/2026-01-14_01_ExcludeInvitedUsersFromClaimedDomains_V2.sql new file mode 100644 index 0000000000..b7be5fd0e0 --- /dev/null +++ b/util/Migrator/DbScripts/2026-01-14_01_ExcludeInvitedUsersFromClaimedDomains_V2.sql @@ -0,0 +1,29 @@ +CREATE OR ALTER PROCEDURE [dbo].[OrganizationUser_ReadByOrganizationIdWithClaimedDomains_V2] + @OrganizationId UNIQUEIDENTIFIER +AS +BEGIN + SET NOCOUNT ON; + + WITH OrgUsers AS ( + SELECT * + FROM [dbo].[OrganizationUserView] + WHERE [OrganizationId] = @OrganizationId + AND [Status] != 0 -- Exclude invited users + ), + UserDomains AS ( + SELECT U.[Id], U.[EmailDomain] + FROM [dbo].[UserEmailDomainView] U + WHERE EXISTS ( + SELECT 1 + FROM [dbo].[OrganizationDomainView] OD + WHERE OD.[OrganizationId] = @OrganizationId + AND OD.[VerifiedDate] IS NOT NULL + AND OD.[DomainName] = U.[EmailDomain] + ) + ) + SELECT OU.* + FROM OrgUsers OU + JOIN UserDomains UD ON OU.[UserId] = UD.[Id] + OPTION (RECOMPILE); +END +GO From c37412bacbdbbe07027ac26cda750ba4ec9be736 Mon Sep 17 00:00:00 2001 From: Todd Martin <106564991+trmartin4@users.noreply.github.com> Date: Tue, 20 Jan 2026 10:03:33 -0500 Subject: [PATCH 06/12] chore(flags): Remove pm-1632-redirect-on-sso-required feature flag * Remove feature flag. * Update test title. * Fixed some test failures. * Fixed tests * Removed method that's no longer used. * Removed unneeded directive. --- src/Core/Constants.cs | 1 - .../RequestValidators/BaseRequestValidator.cs | 91 +--- .../CustomTokenRequestValidator.cs | 11 - .../ResourceOwnerPasswordValidator.cs | 8 - .../WebAuthnGrantValidator.cs | 8 - .../BaseRequestValidatorTests.cs | 396 ++++++------------ .../BaseRequestValidatorTestWrapper.cs | 9 - 7 files changed, 144 insertions(+), 380 deletions(-) diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 6f42778b6b..10c68ddc42 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -162,7 +162,6 @@ public static class FeatureFlagKeys public const string MjmlWelcomeEmailTemplates = "pm-21741-mjml-welcome-email"; public const string OrganizationConfirmationEmail = "pm-28402-update-confirmed-to-org-email-template"; public const string MarketingInitiatedPremiumFlow = "pm-26140-marketing-initiated-premium-flow"; - public const string RedirectOnSsoRequired = "pm-1632-redirect-on-sso-required"; public const string PrefetchPasswordPrelogin = "pm-23801-prefetch-password-prelogin"; public const string PM27086_UpdateAuthenticationApisForInputPassword = "pm-27086-update-authentication-apis-for-input-password"; diff --git a/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs index e07446d49f..289feebdb2 100644 --- a/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs @@ -4,7 +4,6 @@ using System.Security.Claims; using Bit.Core; -using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.Entities; @@ -233,56 +232,14 @@ public abstract class BaseRequestValidator where T : class private async Task ValidateSsoAsync(T context, ValidatedTokenRequest request, CustomValidatorRequestContext validatorContext) { - // TODO: Clean up Feature Flag: Remove this if block: PM-28281 - if (!_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired)) + var ssoValid = await _ssoRequestValidator.ValidateAsync(validatorContext.User, request, validatorContext); + if (ssoValid) { - validatorContext.SsoRequired = await RequireSsoLoginAsync(validatorContext.User, request.GrantType); - if (!validatorContext.SsoRequired) - { - return true; - } - - // Users without SSO requirement requesting 2FA recovery will be fast-forwarded through login and are - // presented with their 2FA management area as a reminder to re-evaluate their 2FA posture after recovery and - // review their new recovery token if desired. - // SSO users cannot be assumed to be authenticated, and must prove authentication with their IdP after recovery. - // As described in validation order determination, if TwoFactorRequired, the 2FA validation scheme will have been - // evaluated, and recovery will have been performed if requested. - // We will send a descriptive message in these cases so clients can give the appropriate feedback and redirect - // to /login. - if (validatorContext.TwoFactorRequired && - validatorContext.TwoFactorRecoveryRequested) - { - SetSsoResult(context, - new Dictionary - { - { - "ErrorModel", - new ErrorResponseModel( - "Two-factor recovery has been performed. SSO authentication is required.") - } - }); - return false; - } - - SetSsoResult(context, - new Dictionary - { - { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") } - }); - return false; + return true; } - else - { - var ssoValid = await _ssoRequestValidator.ValidateAsync(validatorContext.User, request, validatorContext); - if (ssoValid) - { - return true; - } - SetValidationErrorResult(context, validatorContext); - return ssoValid; - } + SetValidationErrorResult(context, validatorContext); + return ssoValid; } /// @@ -521,9 +478,6 @@ public abstract class BaseRequestValidator where T : class [Obsolete("Consider using SetValidationErrorResult instead.")] protected abstract void SetTwoFactorResult(T context, Dictionary customResponse); - [Obsolete("Consider using SetValidationErrorResult instead.")] - protected abstract void SetSsoResult(T context, Dictionary customResponse); - [Obsolete("Consider using SetValidationErrorResult instead.")] protected abstract void SetErrorResult(T context, Dictionary customResponse); @@ -540,41 +494,6 @@ public abstract class BaseRequestValidator where T : class protected abstract ClaimsPrincipal GetSubject(T context); - /// - /// Check if the user is required to authenticate via SSO. If the user requires SSO, but they are - /// logging in using an API Key (client_credentials) then they are allowed to bypass the SSO requirement. - /// If the GrantType is authorization_code or client_credentials we know the user is trying to login - /// using the SSO flow so they are allowed to continue. - /// - /// user trying to login - /// magic string identifying the grant type requested - /// true if sso required; false if not required or already in process - [Obsolete( - "This method is deprecated and will be removed in future versions, PM-28281. Please use the SsoRequestValidator scheme instead.")] - private async Task RequireSsoLoginAsync(User user, string grantType) - { - if (grantType == "authorization_code" || grantType == "client_credentials") - { - // Already using SSO to authenticate, or logging-in via api key to skip SSO requirement - // allow to authenticate successfully - return false; - } - - // Check if user belongs to any organization with an active SSO policy - var ssoRequired = _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements) - ? (await PolicyRequirementQuery.GetAsync(user.Id)) - .SsoRequired - : await PolicyService.AnyPoliciesApplicableToUserAsync( - user.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed); - if (ssoRequired) - { - return true; - } - - // Default - SSO is not required - return false; - } - private async Task ResetFailedAuthDetailsAsync(User user) { // Early escape if db hit not necessary diff --git a/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs index 38a4813ecd..2412c52308 100644 --- a/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs @@ -194,17 +194,6 @@ public class CustomTokenRequestValidator : BaseRequestValidator customResponse) - { - Debug.Assert(context.Result is not null); - context.Result.Error = "invalid_grant"; - context.Result.ErrorDescription = "Sso authentication required."; - context.Result.IsError = true; - context.Result.CustomResponse = customResponse; - } - [Obsolete("Consider using SetGrantValidationErrorResult instead.")] protected override void SetErrorResult(CustomTokenRequestValidationContext context, Dictionary customResponse) diff --git a/src/Identity/IdentityServer/RequestValidators/ResourceOwnerPasswordValidator.cs b/src/Identity/IdentityServer/RequestValidators/ResourceOwnerPasswordValidator.cs index ea2c021f63..8bfddf24f3 100644 --- a/src/Identity/IdentityServer/RequestValidators/ResourceOwnerPasswordValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/ResourceOwnerPasswordValidator.cs @@ -152,14 +152,6 @@ public class ResourceOwnerPasswordValidator : BaseRequestValidator customResponse) - { - context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, "Sso authentication required.", - customResponse); - } - [Obsolete("Consider using SetGrantValidationErrorResult instead.")] protected override void SetErrorResult(ResourceOwnerPasswordValidationContext context, Dictionary customResponse) diff --git a/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs b/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs index e4cd60827e..1563831b81 100644 --- a/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs @@ -142,14 +142,6 @@ public class WebAuthnGrantValidator : BaseRequestValidator customResponse) - { - context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, "Sso authentication required.", - customResponse); - } - [Obsolete("Consider using SetValidationErrorResult instead.")] protected override void SetErrorResult(ExtensionGrantValidationContext context, Dictionary customResponse) { diff --git a/test/Identity.Test/IdentityServer/BaseRequestValidatorTests.cs b/test/Identity.Test/IdentityServer/BaseRequestValidatorTests.cs index 677382b138..4b6f681096 100644 --- a/test/Identity.Test/IdentityServer/BaseRequestValidatorTests.cs +++ b/test/Identity.Test/IdentityServer/BaseRequestValidatorTests.cs @@ -18,6 +18,7 @@ using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; using Bit.Identity.IdentityServer; +using Bit.Identity.IdentityServer.RequestValidationConstants; using Bit.Identity.IdentityServer.RequestValidators; using Bit.Identity.Test.Wrappers; using Bit.Test.Common.AutoFixture.Attributes; @@ -130,7 +131,7 @@ public class BaseRequestValidatorTests var logs = _logger.Collector.GetSnapshot(true); Assert.Contains(logs, l => l.Level == LogLevel.Warning && l.Message == "Failed login attempt. Is2FARequest: False IpAddress: "); - var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; + var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel]; Assert.Equal("Username or password is incorrect. Try again.", errorResponse.Message); } @@ -161,7 +162,11 @@ public class BaseRequestValidatorTests .ValidateRequestDeviceAsync(tokenRequest, requestContext) .Returns(Task.FromResult(false)); - // 5 -> not legacy user + // 5 -> SSO not required + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); + + // 6 -> not legacy user _userService.IsLegacyUser(Arg.Any()) .Returns(false); @@ -203,6 +208,11 @@ public class BaseRequestValidatorTests _userService.IsLegacyUser(Arg.Any()) .Returns(false); + // 6 -> SSO validation passes + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); + + // 7 -> setup user account keys _userAccountKeysQuery.Run(Arg.Any()).Returns(new UserAccountKeysData { PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData( @@ -262,6 +272,11 @@ public class BaseRequestValidatorTests _userService.IsLegacyUser(Arg.Any()) .Returns(false); + // 6 -> SSO validation passes + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); + + // 7 -> setup user account keys _userAccountKeysQuery.Run(Arg.Any()).Returns(new UserAccountKeysData { PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData( @@ -326,6 +341,9 @@ public class BaseRequestValidatorTests { "TwoFactorProviders2", new Dictionary { { "Email", null } } } })); + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); + // Act await _sut.ValidateAsync(context); @@ -368,6 +386,10 @@ public class BaseRequestValidatorTests .VerifyTwoFactorAsync(user, null, TwoFactorProviderType.Email, "invalid_token") .Returns(Task.FromResult(false)); + // 5 -> set up SSO required verification to succeed + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); + // Act await _sut.ValidateAsync(context); @@ -396,21 +418,25 @@ public class BaseRequestValidatorTests // 1 -> initial validation passes _sut.isValid = true; - // 2 -> set up 2FA as required + // 2 -> set up SSO required verification to succeed + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); + + // 3 -> set up 2FA as required _twoFactorAuthenticationValidator .RequiresTwoFactorAsync(Arg.Any(), tokenRequest) .Returns(Task.FromResult(new Tuple(true, null))); - // 3 -> provide invalid remember token (remember token expired) + // 4 -> provide invalid remember token (remember token expired) tokenRequest.Raw["TwoFactorToken"] = "expired_remember_token"; tokenRequest.Raw["TwoFactorProvider"] = "5"; // Remember provider - // 4 -> set up remember token verification to fail + // 5 -> set up remember token verification to fail _twoFactorAuthenticationValidator .VerifyTwoFactorAsync(user, null, TwoFactorProviderType.Remember, "expired_remember_token") .Returns(Task.FromResult(false)); - // 5 -> set up dummy BuildTwoFactorResultAsync + // 6 -> set up dummy BuildTwoFactorResultAsync var twoFactorResultDict = new Dictionary { { "TwoFactorProviders", new[] { "0", "1" } }, @@ -446,6 +472,19 @@ public class BaseRequestValidatorTests GrantValidationResult grantResult) { // Arrange + + // SsoRequestValidator sets custom response + requestContext.ValidationErrorResult = new ValidationResult + { + IsError = true, + Error = SsoConstants.RequestErrors.SsoRequired, + ErrorDescription = SsoConstants.RequestErrors.SsoRequiredDescription + }; + requestContext.CustomResponse = new Dictionary + { + { CustomResponseConstants.ResponseKeys.ErrorModel, new ErrorResponseModel(SsoConstants.RequestErrors.SsoRequiredDescription) }, + }; + var context = CreateContext(tokenRequest, requestContext, grantResult); _sut.isValid = true; @@ -454,13 +493,17 @@ public class BaseRequestValidatorTests Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) .Returns(Task.FromResult(true)); + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(false)); + // Act await _sut.ValidateAsync(context); // Assert Assert.True(context.GrantResult.IsError); - var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; - Assert.Equal("SSO authentication is required.", errorResponse.Message); + Assert.NotNull(context.GrantResult.CustomResponse); + var errorResponse = (ErrorResponseModel)context.CustomValidatorRequestContext.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel]; + Assert.Equal(SsoConstants.RequestErrors.SsoRequiredDescription, errorResponse.Message); } // Test grantTypes with RequireSsoPolicyRequirement when feature flag is enabled @@ -477,6 +520,20 @@ public class BaseRequestValidatorTests { // Arrange _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true); + + // SsoRequestValidator sets custom response with organization identifier + requestContext.ValidationErrorResult = new ValidationResult + { + IsError = true, + Error = SsoConstants.RequestErrors.SsoRequired, + ErrorDescription = SsoConstants.RequestErrors.SsoRequiredDescription + }; + requestContext.CustomResponse = new Dictionary + { + { CustomResponseConstants.ResponseKeys.ErrorModel, new ErrorResponseModel(SsoConstants.RequestErrors.SsoRequiredDescription) }, + { CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier, "test-org-identifier" } + }; + var context = CreateContext(tokenRequest, requestContext, grantResult); _sut.isValid = true; @@ -485,6 +542,10 @@ public class BaseRequestValidatorTests var requirement = new RequireSsoPolicyRequirement { SsoRequired = true }; _policyRequirementQuery.GetAsync(Arg.Any()).Returns(requirement); + // Mock the SSO validator to return false + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(false)); + // Act await _sut.ValidateAsync(context); @@ -492,8 +553,9 @@ public class BaseRequestValidatorTests await _policyService.DidNotReceive().AnyPoliciesApplicableToUserAsync( Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed); Assert.True(context.GrantResult.IsError); - var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; - Assert.Equal("SSO authentication is required.", errorResponse.Message); + Assert.NotNull(context.GrantResult.CustomResponse); + var errorResponse = (ErrorResponseModel)context.CustomValidatorRequestContext.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel]; + Assert.Equal(SsoConstants.RequestErrors.SsoRequiredDescription, errorResponse.Message); } [Theory] @@ -519,6 +581,10 @@ public class BaseRequestValidatorTests var requirement = new RequireSsoPolicyRequirement { SsoRequired = false }; _policyRequirementQuery.GetAsync(Arg.Any()).Returns(requirement); + // SSO validation passes + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); + _twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest) .Returns(Task.FromResult(new Tuple(false, null))); _deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext) @@ -561,6 +627,11 @@ public class BaseRequestValidatorTests _policyService.AnyPoliciesApplicableToUserAsync( Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) .Returns(Task.FromResult(false)); + + // SSO validation passes + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); + _twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest) .Returns(Task.FromResult(new Tuple(false, null))); _deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext) @@ -603,6 +674,10 @@ public class BaseRequestValidatorTests context.ValidatedTokenRequest.GrantType = grantType; + // SSO validation passes + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); + _twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest) .Returns(Task.FromResult(new Tuple(false, null))); _deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext) @@ -652,13 +727,15 @@ public class BaseRequestValidatorTests .Returns(Task.FromResult(new Tuple(false, null))); _deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext) .Returns(Task.FromResult(true)); + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); // Act await _sut.ValidateAsync(context); // Assert Assert.True(context.GrantResult.IsError); - var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; + var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel]; var expectedMessage = "Legacy encryption without a userkey is no longer supported. To recover your account, please contact support"; Assert.Equal(expectedMessage, errorResponse.Message); @@ -694,6 +771,10 @@ public class BaseRequestValidatorTests var context = CreateContext(tokenRequest, requestContext, grantResult); _sut.isValid = true; + // SSO validation passes + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); + _twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest) .Returns(Task.FromResult(new Tuple(false, null))); _deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext) @@ -760,6 +841,8 @@ public class BaseRequestValidatorTests .Returns(Task.FromResult(new Tuple(false, null))); _deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext) .Returns(Task.FromResult(true)); + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); // Act await _sut.ValidateAsync(context); @@ -833,6 +916,8 @@ public class BaseRequestValidatorTests .Returns(Task.FromResult(new Tuple(false, null))); _deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext) .Returns(Task.FromResult(true)); + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); // Act await _sut.ValidateAsync(context); @@ -877,6 +962,8 @@ public class BaseRequestValidatorTests .Returns(Task.FromResult(new Tuple(false, null))); _deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext) .Returns(Task.FromResult(true)); + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); // Act await _sut.ValidateAsync(context); @@ -921,6 +1008,8 @@ public class BaseRequestValidatorTests .Returns(Task.FromResult(new Tuple(false, null))); _deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext) .Returns(Task.FromResult(true)); + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); // Act await _sut.ValidateAsync(context); @@ -950,6 +1039,19 @@ public class BaseRequestValidatorTests GrantValidationResult grantResult) { // Arrange + + // SsoRequestValidator sets custom response + requestContext.ValidationErrorResult = new ValidationResult + { + IsError = true, + Error = SsoConstants.RequestErrors.SsoRequired, + ErrorDescription = SsoConstants.RequestErrors.SsoRequiredDescription + }; + requestContext.CustomResponse = new Dictionary + { + { CustomResponseConstants.ResponseKeys.ErrorModel, new ErrorResponseModel(SsoConstants.RequestErrors.SsoRequiredDescription) }, + }; + var context = CreateContext(tokenRequest, requestContext, grantResult); var user = requestContext.User; @@ -984,12 +1086,12 @@ public class BaseRequestValidatorTests // Assert Assert.True(context.GrantResult.IsError, "Authentication should fail - SSO required after recovery"); - - var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; + Assert.NotNull(context.GrantResult.CustomResponse); + var errorResponse = (ErrorResponseModel)context.CustomValidatorRequestContext.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel]; // Recovery succeeds, then SSO blocks with descriptive message Assert.Equal( - "Two-factor recovery has been performed. SSO authentication is required.", + SsoConstants.RequestErrors.SsoRequiredDescription, errorResponse.Message); // Verify recovery was marked @@ -1050,7 +1152,7 @@ public class BaseRequestValidatorTests // Assert Assert.True(context.GrantResult.IsError, "Authentication should fail - invalid recovery code"); - var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; + var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel]; // 2FA is checked first (due to recovery code request), fails with 2FA error Assert.Equal( @@ -1132,7 +1234,11 @@ public class BaseRequestValidatorTests _userService.IsLegacyUser(Arg.Any()) .Returns(false); - // 8. Setup user account keys for successful login response + // 8. SSO is not required + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); + + // 9. Setup user account keys for successful login response _userAccountKeysQuery.Run(Arg.Any()).Returns(new UserAccountKeysData { PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData( @@ -1161,179 +1267,18 @@ public class BaseRequestValidatorTests } /// - /// Tests that when RedirectOnSsoRequired is DISABLED, the legacy SSO validation path is used. - /// This validates the deprecated RequireSsoLoginAsync method is called and SSO requirement - /// is checked using the old PolicyService.AnyPoliciesApplicableToUserAsync approach. + /// Tests that when SSO validation returns a custom response, (e.g., with organization identifier), + /// that custom response is properly propagated to the result. /// [Theory] [BitAutoData] - public async Task ValidateAsync_RedirectOnSsoRequired_Disabled_UsesLegacySsoValidation( + public async Task ValidateAsync_SsoRequired_PropagatesCustomResponse( [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - _featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(false); - - var context = CreateContext(tokenRequest, requestContext, grantResult); - _sut.isValid = true; - - tokenRequest.GrantType = OidcConstants.GrantTypes.Password; - - // SSO is required via legacy path - _policyService.AnyPoliciesApplicableToUserAsync( - Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) - .Returns(Task.FromResult(true)); - - // Act - await _sut.ValidateAsync(context); - - // Assert - Assert.True(context.GrantResult.IsError); - var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; - Assert.Equal("SSO authentication is required.", errorResponse.Message); - - // Verify legacy path was used - await _policyService.Received(1).AnyPoliciesApplicableToUserAsync( - requestContext.User.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed); - - // Verify new SsoRequestValidator was NOT called - await _ssoRequestValidator.DidNotReceive().ValidateAsync( - Arg.Any(), Arg.Any(), Arg.Any()); - } - - /// - /// Tests that when RedirectOnSsoRequired is ENABLED, the new ISsoRequestValidator is used - /// instead of the legacy RequireSsoLoginAsync method. - /// - [Theory] - [BitAutoData] - public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_UsesNewSsoRequestValidator( - [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] - CustomValidatorRequestContext requestContext, - GrantValidationResult grantResult) - { - // Arrange - _featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true); - - var context = CreateContext(tokenRequest, requestContext, grantResult); - _sut.isValid = true; - - tokenRequest.GrantType = OidcConstants.GrantTypes.Password; - - // Configure SsoRequestValidator to indicate SSO is required - _ssoRequestValidator.ValidateAsync( - Arg.Any(), - Arg.Any(), - Arg.Any()) - .Returns(Task.FromResult(false)); // false = SSO required - - // Set up the ValidationErrorResult that SsoRequestValidator would set - requestContext.ValidationErrorResult = new ValidationResult - { - IsError = true, - Error = "sso_required", - ErrorDescription = "SSO authentication is required." - }; - requestContext.CustomResponse = new Dictionary - { - { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") } - }; - - // Act - await _sut.ValidateAsync(context); - - // Assert - Assert.True(context.GrantResult.IsError); - - // Verify new SsoRequestValidator was called - await _ssoRequestValidator.Received(1).ValidateAsync( - requestContext.User, - tokenRequest, - requestContext); - - // Verify legacy path was NOT used - await _policyService.DidNotReceive().AnyPoliciesApplicableToUserAsync( - Arg.Any(), Arg.Any(), Arg.Any()); - } - - /// - /// Tests that when RedirectOnSsoRequired is ENABLED and SSO is NOT required, - /// authentication continues successfully through the new validation path. - /// - [Theory] - [BitAutoData] - public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_SsoNotRequired_SuccessfulLogin( - [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] - CustomValidatorRequestContext requestContext, - GrantValidationResult grantResult) - { - // Arrange - _featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true); - - var context = CreateContext(tokenRequest, requestContext, grantResult); - _sut.isValid = true; - - tokenRequest.GrantType = OidcConstants.GrantTypes.Password; - tokenRequest.ClientId = "web"; - - // SsoRequestValidator returns true (SSO not required) - _ssoRequestValidator.ValidateAsync( - Arg.Any(), - Arg.Any(), - Arg.Any()) - .Returns(Task.FromResult(true)); - - // No 2FA required - _twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest) - .Returns(Task.FromResult(new Tuple(false, null))); - - // Device validation passes - _deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext) - .Returns(Task.FromResult(true)); - - // User is not legacy - _userService.IsLegacyUser(Arg.Any()).Returns(false); - - _userAccountKeysQuery.Run(Arg.Any()).Returns(new UserAccountKeysData - { - PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData( - "test-private-key", - "test-public-key" - ) - }); - - // Act - await _sut.ValidateAsync(context); - - // Assert - Assert.False(context.GrantResult.IsError); - await _eventService.Received(1).LogUserEventAsync(requestContext.User.Id, EventType.User_LoggedIn); - - // Verify new validator was used - await _ssoRequestValidator.Received(1).ValidateAsync( - requestContext.User, - tokenRequest, - requestContext); - } - - /// - /// Tests that when RedirectOnSsoRequired is ENABLED and SSO validation returns a custom response - /// (e.g., with organization identifier), that custom response is properly propagated to the result. - /// - [Theory] - [BitAutoData] - public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_PropagatesCustomResponse( - [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] - CustomValidatorRequestContext requestContext, - GrantValidationResult grantResult) - { - // Arrange - _featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true); _sut.isValid = true; tokenRequest.GrantType = OidcConstants.GrantTypes.Password; @@ -1342,13 +1287,13 @@ public class BaseRequestValidatorTests requestContext.ValidationErrorResult = new ValidationResult { IsError = true, - Error = "sso_required", - ErrorDescription = "SSO authentication is required." + Error = SsoConstants.RequestErrors.SsoRequired, + ErrorDescription = SsoConstants.RequestErrors.SsoRequiredDescription }; requestContext.CustomResponse = new Dictionary { - { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") }, - { "SsoOrganizationIdentifier", "test-org-identifier" } + { CustomResponseConstants.ResponseKeys.ErrorModel, new ErrorResponseModel(SsoConstants.RequestErrors.SsoRequiredDescription) }, + { CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier, "test-org-identifier" } }; var context = CreateContext(tokenRequest, requestContext, grantResult); @@ -1365,77 +1310,24 @@ public class BaseRequestValidatorTests // Assert Assert.True(context.GrantResult.IsError); Assert.NotNull(context.GrantResult.CustomResponse); - Assert.Contains("SsoOrganizationIdentifier", context.CustomValidatorRequestContext.CustomResponse); + Assert.Contains(CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier, context.CustomValidatorRequestContext.CustomResponse); Assert.Equal("test-org-identifier", - context.CustomValidatorRequestContext.CustomResponse["SsoOrganizationIdentifier"]); + context.CustomValidatorRequestContext.CustomResponse[CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier]); } /// - /// Tests that when RedirectOnSsoRequired is DISABLED and a user with 2FA recovery completes recovery, - /// but SSO is required, the legacy error message is returned (without the recovery-specific message). - /// - [Theory] - [BitAutoData] - public async Task ValidateAsync_RedirectOnSsoRequired_Disabled_RecoveryWithSso_LegacyMessage( - [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] - CustomValidatorRequestContext requestContext, - GrantValidationResult grantResult) - { - // Arrange - _featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(false); - - var context = CreateContext(tokenRequest, requestContext, grantResult); - _sut.isValid = true; - - // Recovery code scenario - tokenRequest.Raw["TwoFactorProvider"] = ((int)TwoFactorProviderType.RecoveryCode).ToString(); - tokenRequest.Raw["TwoFactorToken"] = "valid-recovery-code"; - - // 2FA with recovery - _twoFactorAuthenticationValidator - .RequiresTwoFactorAsync(requestContext.User, tokenRequest) - .Returns(Task.FromResult(new Tuple(true, null))); - - _twoFactorAuthenticationValidator - .VerifyTwoFactorAsync(requestContext.User, null, TwoFactorProviderType.RecoveryCode, "valid-recovery-code") - .Returns(Task.FromResult(true)); - - // SSO is required (legacy check) - _policyService.AnyPoliciesApplicableToUserAsync( - Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) - .Returns(Task.FromResult(true)); - - // Act - await _sut.ValidateAsync(context); - - // Assert - Assert.True(context.GrantResult.IsError); - var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; - - // Legacy behavior: recovery-specific message IS shown even without RedirectOnSsoRequired - Assert.Equal("Two-factor recovery has been performed. SSO authentication is required.", errorResponse.Message); - - // But legacy validation path was used - await _policyService.Received(1).AnyPoliciesApplicableToUserAsync( - requestContext.User.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed); - } - - /// - /// Tests that when RedirectOnSsoRequired is ENABLED and recovery code is used for SSO-required user, + /// Tests that when a recovery code is used for SSO-required user, /// the SsoRequestValidator provides the recovery-specific error message. /// [Theory] [BitAutoData] - public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_RecoveryWithSso_NewValidatorMessage( + public async Task ValidateAsync_RecoveryWithSso_CorrectValidatorMessage( [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - _featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true); - var context = CreateContext(tokenRequest, requestContext, grantResult); _sut.isValid = true; @@ -1457,14 +1349,14 @@ public class BaseRequestValidatorTests requestContext.ValidationErrorResult = new ValidationResult { IsError = true, - Error = "sso_required", - ErrorDescription = "Two-factor recovery has been performed. SSO authentication is required." + Error = SsoConstants.RequestErrors.SsoRequired, + ErrorDescription = SsoConstants.RequestErrors.SsoTwoFactorRecoveryDescription }; requestContext.CustomResponse = new Dictionary { { - "ErrorModel", - new ErrorResponseModel("Two-factor recovery has been performed. SSO authentication is required.") + CustomResponseConstants.ResponseKeys.ErrorModel, + new ErrorResponseModel(SsoConstants.RequestErrors.SsoTwoFactorRecoveryDescription) } }; @@ -1479,18 +1371,8 @@ public class BaseRequestValidatorTests // Assert Assert.True(context.GrantResult.IsError); - var errorResponse = (ErrorResponseModel)context.CustomValidatorRequestContext.CustomResponse["ErrorModel"]; - Assert.Equal("Two-factor recovery has been performed. SSO authentication is required.", errorResponse.Message); - - // Verify new validator was used - await _ssoRequestValidator.Received(1).ValidateAsync( - requestContext.User, - tokenRequest, - Arg.Is(ctx => ctx.TwoFactorRecoveryRequested)); - - // Verify legacy path was NOT used - await _policyService.DidNotReceive().AnyPoliciesApplicableToUserAsync( - Arg.Any(), Arg.Any(), Arg.Any()); + var errorResponse = (ErrorResponseModel)context.CustomValidatorRequestContext.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel]; + Assert.Equal(SsoConstants.RequestErrors.SsoTwoFactorRecoveryDescription, errorResponse.Message); } private BaseRequestValidationContextFake CreateContext( diff --git a/test/Identity.Test/Wrappers/BaseRequestValidatorTestWrapper.cs b/test/Identity.Test/Wrappers/BaseRequestValidatorTestWrapper.cs index b336e4c3c1..ac27c55466 100644 --- a/test/Identity.Test/Wrappers/BaseRequestValidatorTestWrapper.cs +++ b/test/Identity.Test/Wrappers/BaseRequestValidatorTestWrapper.cs @@ -111,15 +111,6 @@ IBaseRequestValidatorTestWrapper context.GrantResult = new GrantValidationResult(TokenRequestErrors.InvalidGrant, customResponse: customResponse); } - [Obsolete] - protected override void SetSsoResult( - BaseRequestValidationContextFake context, - Dictionary customResponse) - { - context.GrantResult = new GrantValidationResult( - TokenRequestErrors.InvalidGrant, "Sso authentication required.", customResponse); - } - protected override Task SetSuccessResult( BaseRequestValidationContextFake context, User user, From 2e4dd061e313caf96e321da374ea5be61d9865d0 Mon Sep 17 00:00:00 2001 From: Alex Morask <144709477+amorask-bitwarden@users.noreply.github.com> Date: Tue, 20 Jan 2026 09:18:27 -0600 Subject: [PATCH 07/12] [PM-30855] Pay prorated storage adjustment immediately with Braintree for Premium PayPal users (#6850) * fix: Pay prorated storage invoice immediately with Braintree for PayPal users * Run dotnet format --- .../Commands/UpdatePremiumStorageCommand.cs | 47 +++- src/Core/Billing/Services/IStripeAdapter.cs | 1 + .../Services/Implementations/StripeAdapter.cs | 3 + .../UpdatePremiumStorageCommandTests.cs | 230 +++++++++++++++++- 4 files changed, 263 insertions(+), 18 deletions(-) diff --git a/src/Core/Billing/Premium/Commands/UpdatePremiumStorageCommand.cs b/src/Core/Billing/Premium/Commands/UpdatePremiumStorageCommand.cs index 176c77bf57..219f450f1d 100644 --- a/src/Core/Billing/Premium/Commands/UpdatePremiumStorageCommand.cs +++ b/src/Core/Billing/Premium/Commands/UpdatePremiumStorageCommand.cs @@ -2,6 +2,7 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; +using Bit.Core.Billing.Subscriptions.Models; using Bit.Core.Entities; using Bit.Core.Services; using Bit.Core.Utilities; @@ -29,6 +30,7 @@ public interface IUpdatePremiumStorageCommand } public class UpdatePremiumStorageCommand( + IBraintreeService braintreeService, IStripeAdapter stripeAdapter, IUserService userService, IPricingClient pricingClient, @@ -49,7 +51,10 @@ public class UpdatePremiumStorageCommand( // Fetch all premium plans and the user's subscription to find which plan they're on var premiumPlans = await pricingClient.ListPremiumPlans(); - var subscription = await stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId); + var subscription = await stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, new SubscriptionGetOptions + { + Expand = ["customer"] + }); // Find the password manager subscription item (seat, not storage) and match it to a plan var passwordManagerItem = subscription.Items.Data.FirstOrDefault(i => @@ -127,13 +132,41 @@ public class UpdatePremiumStorageCommand( }); } - var subscriptionUpdateOptions = new SubscriptionUpdateOptions - { - Items = subscriptionItemOptions, - ProrationBehavior = ProrationBehavior.AlwaysInvoice - }; + var usingPayPal = subscription.Customer.Metadata.ContainsKey(MetadataKeys.BraintreeCustomerId); - await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, subscriptionUpdateOptions); + if (usingPayPal) + { + var options = new SubscriptionUpdateOptions + { + Items = subscriptionItemOptions, + ProrationBehavior = ProrationBehavior.CreateProrations + }; + + await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, options); + + var draftInvoice = await stripeAdapter.CreateInvoiceAsync(new InvoiceCreateOptions + { + Customer = subscription.CustomerId, + Subscription = subscription.Id, + AutoAdvance = false, + CollectionMethod = CollectionMethod.ChargeAutomatically + }); + + var finalizedInvoice = await stripeAdapter.FinalizeInvoiceAsync(draftInvoice.Id, + new InvoiceFinalizeOptions { AutoAdvance = false, Expand = ["customer"] }); + + await braintreeService.PayInvoice(new UserId(user.Id), finalizedInvoice); + } + else + { + var options = new SubscriptionUpdateOptions + { + Items = subscriptionItemOptions, + ProrationBehavior = ProrationBehavior.AlwaysInvoice + }; + + await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, options); + } // Update the user's max storage user.MaxStorageGb = maxStorageGb; diff --git a/src/Core/Billing/Services/IStripeAdapter.cs b/src/Core/Billing/Services/IStripeAdapter.cs index 5ec732920e..12ea3d5a7c 100644 --- a/src/Core/Billing/Services/IStripeAdapter.cs +++ b/src/Core/Billing/Services/IStripeAdapter.cs @@ -24,6 +24,7 @@ public interface IStripeAdapter Task CancelSubscriptionAsync(string id, SubscriptionCancelOptions options = null); Task GetInvoiceAsync(string id, InvoiceGetOptions options); Task> ListInvoicesAsync(StripeInvoiceListOptions options); + Task CreateInvoiceAsync(InvoiceCreateOptions options); Task CreateInvoicePreviewAsync(InvoiceCreatePreviewOptions options); Task> SearchInvoiceAsync(InvoiceSearchOptions options); Task UpdateInvoiceAsync(string id, InvoiceUpdateOptions options); diff --git a/src/Core/Billing/Services/Implementations/StripeAdapter.cs b/src/Core/Billing/Services/Implementations/StripeAdapter.cs index cdc7645042..5b90500021 100644 --- a/src/Core/Billing/Services/Implementations/StripeAdapter.cs +++ b/src/Core/Billing/Services/Implementations/StripeAdapter.cs @@ -116,6 +116,9 @@ public class StripeAdapter : IStripeAdapter return invoices; } + public Task CreateInvoiceAsync(InvoiceCreateOptions options) => + _invoiceService.CreateAsync(options); + public Task CreateInvoicePreviewAsync(InvoiceCreatePreviewOptions options) => _invoiceService.CreatePreviewAsync(options); diff --git a/test/Core.Test/Billing/Premium/Commands/UpdatePremiumStorageCommandTests.cs b/test/Core.Test/Billing/Premium/Commands/UpdatePremiumStorageCommandTests.cs index 7b9b68c757..cd9b323f9d 100644 --- a/test/Core.Test/Billing/Premium/Commands/UpdatePremiumStorageCommandTests.cs +++ b/test/Core.Test/Billing/Premium/Commands/UpdatePremiumStorageCommandTests.cs @@ -1,6 +1,7 @@ using Bit.Core.Billing.Premium.Commands; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; +using Bit.Core.Billing.Subscriptions.Models; using Bit.Core.Entities; using Bit.Core.Services; using Bit.Test.Common.AutoFixture.Attributes; @@ -8,6 +9,7 @@ using Microsoft.Extensions.Logging; using NSubstitute; using Stripe; using Xunit; +using static Bit.Core.Billing.Constants.StripeConstants; using PremiumPlan = Bit.Core.Billing.Pricing.Premium.Plan; using PremiumPurchasable = Bit.Core.Billing.Pricing.Premium.Purchasable; @@ -15,6 +17,7 @@ namespace Bit.Core.Test.Billing.Premium.Commands; public class UpdatePremiumStorageCommandTests { + private readonly IBraintreeService _braintreeService = Substitute.For(); private readonly IStripeAdapter _stripeAdapter = Substitute.For(); private readonly IUserService _userService = Substitute.For(); private readonly IPricingClient _pricingClient = Substitute.For(); @@ -33,13 +36,14 @@ public class UpdatePremiumStorageCommandTests _pricingClient.ListPremiumPlans().Returns([premiumPlan]); _command = new UpdatePremiumStorageCommand( + _braintreeService, _stripeAdapter, _userService, _pricingClient, Substitute.For>()); } - private Subscription CreateMockSubscription(string subscriptionId, int? storageQuantity = null) + private Subscription CreateMockSubscription(string subscriptionId, int? storageQuantity = null, bool isPayPal = false) { var items = new List { @@ -63,9 +67,17 @@ public class UpdatePremiumStorageCommandTests }); } + var customer = new Customer + { + Id = "cus_123", + Metadata = isPayPal ? new Dictionary { { MetadataKeys.BraintreeCustomerId, "braintree_123" } } : new Dictionary() + }; + return new Subscription { Id = subscriptionId, + CustomerId = "cus_123", + Customer = customer, Items = new StripeList { Data = items @@ -97,7 +109,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123", 4); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, -5); @@ -117,7 +129,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123", 4); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, 100); @@ -154,7 +166,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123", 9); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, 0); @@ -176,7 +188,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123", 4); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, 4); @@ -185,7 +197,7 @@ public class UpdatePremiumStorageCommandTests Assert.True(result.IsT0); // Verify subscription was fetched but NOT updated - await _stripeAdapter.Received(1).GetSubscriptionAsync("sub_123"); + await _stripeAdapter.Received(1).GetSubscriptionAsync("sub_123", Arg.Any()); await _stripeAdapter.DidNotReceive().UpdateSubscriptionAsync(Arg.Any(), Arg.Any()); await _userService.DidNotReceive().SaveUserAsync(Arg.Any()); } @@ -200,7 +212,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123", 4); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, 9); @@ -233,7 +245,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123"); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, 9); @@ -262,7 +274,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123", 9); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, 2); @@ -291,7 +303,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123", 9); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, 0); @@ -320,7 +332,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123", 4); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, 99); @@ -335,4 +347,200 @@ public class UpdatePremiumStorageCommandTests await _userService.Received(1).SaveUserAsync(Arg.Is(u => u.MaxStorageGb == 100)); } + + [Theory, BitAutoData] + public async Task Run_IncreaseStorage_PayPal_Success(User user) + { + // Arrange + user.Premium = true; + user.MaxStorageGb = 5; + user.Storage = 2L * 1024 * 1024 * 1024; + user.GatewaySubscriptionId = "sub_123"; + + var subscription = CreateMockSubscription("sub_123", 4, isPayPal: true); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); + + var draftInvoice = new Invoice { Id = "in_draft" }; + _stripeAdapter.CreateInvoiceAsync(Arg.Any()).Returns(draftInvoice); + + var finalizedInvoice = new Invoice + { + Id = "in_finalized", + Customer = new Customer { Id = "cus_123" } + }; + _stripeAdapter.FinalizeInvoiceAsync("in_draft", Arg.Any()).Returns(finalizedInvoice); + + // Act + var result = await _command.Run(user, 9); + + // Assert + Assert.True(result.IsT0); + + // Verify subscription was updated with CreateProrations + await _stripeAdapter.Received(1).UpdateSubscriptionAsync( + "sub_123", + Arg.Is(opts => + opts.Items.Count == 1 && + opts.Items[0].Id == "si_storage" && + opts.Items[0].Quantity == 9 && + opts.ProrationBehavior == "create_prorations")); + + // Verify draft invoice was created + await _stripeAdapter.Received(1).CreateInvoiceAsync( + Arg.Is(opts => + opts.Customer == "cus_123" && + opts.Subscription == "sub_123" && + opts.AutoAdvance == false && + opts.CollectionMethod == "charge_automatically")); + + // Verify invoice was finalized + await _stripeAdapter.Received(1).FinalizeInvoiceAsync( + "in_draft", + Arg.Is(opts => + opts.AutoAdvance == false && + opts.Expand.Contains("customer"))); + + // Verify Braintree payment was processed + await _braintreeService.Received(1).PayInvoice(Arg.Any(), finalizedInvoice); + + // Verify user was saved + await _userService.Received(1).SaveUserAsync(Arg.Is(u => + u.Id == user.Id && + u.MaxStorageGb == 10)); + } + + [Theory, BitAutoData] + public async Task Run_AddStorageFromZero_PayPal_Success(User user) + { + // Arrange + user.Premium = true; + user.MaxStorageGb = 1; + user.Storage = 500L * 1024 * 1024; + user.GatewaySubscriptionId = "sub_123"; + + var subscription = CreateMockSubscription("sub_123", isPayPal: true); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); + + var draftInvoice = new Invoice { Id = "in_draft" }; + _stripeAdapter.CreateInvoiceAsync(Arg.Any()).Returns(draftInvoice); + + var finalizedInvoice = new Invoice + { + Id = "in_finalized", + Customer = new Customer { Id = "cus_123" } + }; + _stripeAdapter.FinalizeInvoiceAsync("in_draft", Arg.Any()).Returns(finalizedInvoice); + + // Act + var result = await _command.Run(user, 9); + + // Assert + Assert.True(result.IsT0); + + // Verify subscription was updated with new storage item + await _stripeAdapter.Received(1).UpdateSubscriptionAsync( + "sub_123", + Arg.Is(opts => + opts.Items.Count == 1 && + opts.Items[0].Price == "price_storage" && + opts.Items[0].Quantity == 9 && + opts.ProrationBehavior == "create_prorations")); + + // Verify invoice creation and payment flow + await _stripeAdapter.Received(1).CreateInvoiceAsync(Arg.Any()); + await _stripeAdapter.Received(1).FinalizeInvoiceAsync("in_draft", Arg.Any()); + await _braintreeService.Received(1).PayInvoice(Arg.Any(), finalizedInvoice); + + await _userService.Received(1).SaveUserAsync(Arg.Is(u => u.MaxStorageGb == 10)); + } + + [Theory, BitAutoData] + public async Task Run_DecreaseStorage_PayPal_Success(User user) + { + // Arrange + user.Premium = true; + user.MaxStorageGb = 10; + user.Storage = 2L * 1024 * 1024 * 1024; + user.GatewaySubscriptionId = "sub_123"; + + var subscription = CreateMockSubscription("sub_123", 9, isPayPal: true); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); + + var draftInvoice = new Invoice { Id = "in_draft" }; + _stripeAdapter.CreateInvoiceAsync(Arg.Any()).Returns(draftInvoice); + + var finalizedInvoice = new Invoice + { + Id = "in_finalized", + Customer = new Customer { Id = "cus_123" } + }; + _stripeAdapter.FinalizeInvoiceAsync("in_draft", Arg.Any()).Returns(finalizedInvoice); + + // Act + var result = await _command.Run(user, 2); + + // Assert + Assert.True(result.IsT0); + + // Verify subscription was updated + await _stripeAdapter.Received(1).UpdateSubscriptionAsync( + "sub_123", + Arg.Is(opts => + opts.Items.Count == 1 && + opts.Items[0].Id == "si_storage" && + opts.Items[0].Quantity == 2 && + opts.ProrationBehavior == "create_prorations")); + + // Verify invoice creation and payment flow + await _stripeAdapter.Received(1).CreateInvoiceAsync(Arg.Any()); + await _stripeAdapter.Received(1).FinalizeInvoiceAsync("in_draft", Arg.Any()); + await _braintreeService.Received(1).PayInvoice(Arg.Any(), finalizedInvoice); + + await _userService.Received(1).SaveUserAsync(Arg.Is(u => u.MaxStorageGb == 3)); + } + + [Theory, BitAutoData] + public async Task Run_RemoveAllAdditionalStorage_PayPal_Success(User user) + { + // Arrange + user.Premium = true; + user.MaxStorageGb = 10; + user.Storage = 500L * 1024 * 1024; + user.GatewaySubscriptionId = "sub_123"; + + var subscription = CreateMockSubscription("sub_123", 9, isPayPal: true); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); + + var draftInvoice = new Invoice { Id = "in_draft" }; + _stripeAdapter.CreateInvoiceAsync(Arg.Any()).Returns(draftInvoice); + + var finalizedInvoice = new Invoice + { + Id = "in_finalized", + Customer = new Customer { Id = "cus_123" } + }; + _stripeAdapter.FinalizeInvoiceAsync("in_draft", Arg.Any()).Returns(finalizedInvoice); + + // Act + var result = await _command.Run(user, 0); + + // Assert + Assert.True(result.IsT0); + + // Verify subscription item was deleted + await _stripeAdapter.Received(1).UpdateSubscriptionAsync( + "sub_123", + Arg.Is(opts => + opts.Items.Count == 1 && + opts.Items[0].Id == "si_storage" && + opts.Items[0].Deleted == true && + opts.ProrationBehavior == "create_prorations")); + + // Verify invoice creation and payment flow + await _stripeAdapter.Received(1).CreateInvoiceAsync(Arg.Any()); + await _stripeAdapter.Received(1).FinalizeInvoiceAsync("in_draft", Arg.Any()); + await _braintreeService.Received(1).PayInvoice(Arg.Any(), finalizedInvoice); + + await _userService.Received(1).SaveUserAsync(Arg.Is(u => u.MaxStorageGb == 1)); + } } From 439485fc1687692cec131cc413ce4e5a78058cec Mon Sep 17 00:00:00 2001 From: Alex Morask <144709477+amorask-bitwarden@users.noreply.github.com> Date: Tue, 20 Jan 2026 09:29:49 -0600 Subject: [PATCH 08/12] Update renewal email copy (#6862) --- .../Renewals/families-2019-renewal.mjml | 2 +- .../Billing/Renewals/premium-renewal.mjml | 2 +- .../Families2019RenewalMailView.html.hbs | 74 +++++++++---------- .../Families2019RenewalMailView.text.hbs | 2 +- .../Premium/PremiumRenewalMailView.html.hbs | 74 +++++++++---------- .../Premium/PremiumRenewalMailView.text.hbs | 2 +- 6 files changed, 78 insertions(+), 78 deletions(-) diff --git a/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2019-renewal.mjml b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2019-renewal.mjml index 092ae303de..11d82e2039 100644 --- a/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2019-renewal.mjml +++ b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2019-renewal.mjml @@ -19,7 +19,7 @@ As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. - This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax. + This year's renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax. Questions? Contact diff --git a/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/premium-renewal.mjml b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/premium-renewal.mjml index a460442a7c..1fe48c9ba9 100644 --- a/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/premium-renewal.mjml +++ b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/premium-renewal.mjml @@ -18,7 +18,7 @@ As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. - This renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually. + This year's renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually. Questions? Contact diff --git a/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.html.hbs b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.html.hbs index 227613999b..0befde11b5 100644 --- a/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.html.hbs +++ b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.html.hbs @@ -203,7 +203,7 @@
As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. - This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax.
+ This year's renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax. @@ -271,12 +271,12 @@ - + -
+
- +
@@ -364,8 +364,8 @@ - -
- + +
@@ -381,13 +381,13 @@
- +
+ - @@ -404,13 +404,13 @@ -
+ - +
- +
+ - @@ -427,13 +427,13 @@ -
+ - +
- +
+ - @@ -450,13 +450,13 @@ -
+ - +
- +
+ - @@ -473,13 +473,13 @@ -
+ - +
- +
+ - @@ -496,13 +496,13 @@ -
+ - +
- +
+ - @@ -519,13 +519,13 @@ -
+ - +
- +
+ - @@ -546,15 +546,15 @@ diff --git a/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.text.hbs b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.text.hbs index 88d64f9acf..7178548772 100644 --- a/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.text.hbs +++ b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.text.hbs @@ -2,6 +2,6 @@ at {{BaseAnnualRenewalPrice}} + tax. As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. -This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax. +This year's renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax. Questions? Contact support@bitwarden.com diff --git a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.html.hbs b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.html.hbs index a6b2fda0f7..9ce45ef7fe 100644 --- a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.html.hbs +++ b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.html.hbs @@ -202,7 +202,7 @@ @@ -270,12 +270,12 @@
+ - +
-

+

© 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa Barbara, CA, USA

Always confirm you are on a trusted Bitwarden domain before logging in:
- bitwarden.com | - Learn why we include this + bitwarden.com | + Learn why we include this

As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. - This renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually.
+ This year's renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually.
- + -
+
- +
@@ -363,8 +363,8 @@ - -
- + +
@@ -380,13 +380,13 @@
- +
+ - @@ -403,13 +403,13 @@ -
+ - +
- +
+ - @@ -426,13 +426,13 @@ -
+ - +
- +
+ - @@ -449,13 +449,13 @@ -
+ - +
- +
+ - @@ -472,13 +472,13 @@ -
+ - +
- +
+ - @@ -495,13 +495,13 @@ -
+ - +
- +
+ - @@ -518,13 +518,13 @@ -
+ - +
- +
+ - @@ -545,15 +545,15 @@ diff --git a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.text.hbs b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.text.hbs index 41300d0f96..15ad530a07 100644 --- a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.text.hbs +++ b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.text.hbs @@ -1,6 +1,6 @@ Your Bitwarden Premium subscription renews in 15 days. The price is updating to {{BaseMonthlyRenewalPrice}}/month, billed annually. As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. -This renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually. +This year's renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually. Questions? Contact support@bitwarden.com From 7fb2822e05114f61c54ea2f6f0795b7c7425df3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rui=20Tom=C3=A9?= <108268980+r-tome@users.noreply.github.com> Date: Wed, 21 Jan 2026 11:27:24 +0000 Subject: [PATCH 09/12] =?UTF-8?q?[PM-28023]=C2=A0Fix=20restoring=20revoked?= =?UTF-8?q?=20invited=20users=20in=20Free=20Organizations=20(#6861)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix null reference when restoring invited users in Free orgs Add null check before querying for other free org ownership. Invited users don't have a UserId yet, causing NullReferenceException. * Add regression test for restoring revoked invited users with null UserId. --- .../v1/RestoreOrganizationUserCommand.cs | 2 +- .../RestoreOrganizationUserCommandTests.cs | 33 +++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs index ec42c8b402..c5b7314730 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs @@ -93,7 +93,7 @@ public class RestoreOrganizationUserCommand( .twoFactorIsEnabled; } - if (organization.PlanType == PlanType.Free) + if (organization.PlanType == PlanType.Free && organizationUser.UserId.HasValue) { await CheckUserForOtherFreeOrganizationOwnershipAsync(organizationUser); } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/RestoreOrganizationUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/RestoreOrganizationUserCommandTests.cs index 4fa5e92abe..a75345a05d 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/RestoreOrganizationUserCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/RestoreOrganizationUserCommandTests.cs @@ -715,6 +715,39 @@ public class RestoreOrganizationUserCommandTests Arg.Is(x => x != OrganizationUserStatusType.Revoked)); } + [Theory, BitAutoData] + public async Task RestoreUser_InvitedUserInFreeOrganization_Success( + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser, + SutProvider sutProvider) + { + organization.PlanType = PlanType.Free; + organizationUser.UserId = null; + organizationUser.Key = null; + organizationUser.Status = OrganizationUserStatusType.Revoked; + + RestoreUser_Setup(organization, owner, organizationUser, sutProvider); + sutProvider.GetDependency() + .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id).Returns(new OrganizationSeatCounts + { + Sponsored = 0, + Users = 1 + }); + + await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id); + + await sutProvider.GetDependency() + .Received(1) + .RestoreAsync(organizationUser.Id, OrganizationUserStatusType.Invited); + await sutProvider.GetDependency() + .Received(1) + .LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Restored); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .PushSyncOrgKeysAsync(Arg.Any()); + } + [Theory, BitAutoData] public async Task RestoreUsers_Success(Organization organization, [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, From 75a857055e74c452cfdc96cd4c1143439a3d2f8e Mon Sep 17 00:00:00 2001 From: Alex Morask <144709477+amorask-bitwarden@users.noreply.github.com> Date: Wed, 21 Jan 2026 11:52:36 -0600 Subject: [PATCH 10/12] [PM-30697] [PM-30698] Renewal email copy updates (#6875) * feat(families-renewal): Update copy * feat(premium-renewal): Add new var, update copy --- .../Services/Implementations/UpcomingInvoiceHandler.cs | 2 +- .../Mjml/emails/Billing/Renewals/families-2019-renewal.mjml | 4 ++-- .../Mjml/emails/Billing/Renewals/premium-renewal.mjml | 4 ++-- .../Families2019Renewal/Families2019RenewalMailView.html.hbs | 4 ++-- .../Families2019Renewal/Families2019RenewalMailView.text.hbs | 4 ++-- .../Mail/Billing/Renewal/Premium/PremiumRenewalMailView.cs | 2 +- .../Billing/Renewal/Premium/PremiumRenewalMailView.html.hbs | 4 ++-- .../Billing/Renewal/Premium/PremiumRenewalMailView.text.hbs | 4 ++-- test/Billing.Test/Services/UpcomingInvoiceHandlerTests.cs | 4 ++-- 9 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs index 004828dc48..ae2a76a7ce 100644 --- a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs +++ b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs @@ -627,7 +627,7 @@ public class UpcomingInvoiceHandler( { BaseMonthlyRenewalPrice = (premiumPlan.Seat.Price / 12).ToString("C", new CultureInfo("en-US")), DiscountAmount = $"{coupon.PercentOff}%", - DiscountedMonthlyRenewalPrice = (discountedAnnualRenewalPrice / 12).ToString("C", new CultureInfo("en-US")) + DiscountedAnnualRenewalPrice = discountedAnnualRenewalPrice.ToString("C", new CultureInfo("en-US")) } }; diff --git a/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2019-renewal.mjml b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2019-renewal.mjml index 11d82e2039..06f60e7724 100644 --- a/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2019-renewal.mjml +++ b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2019-renewal.mjml @@ -18,8 +18,8 @@ at {{BaseAnnualRenewalPrice}} + tax. - As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. - This year's renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax. + As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this year's renewal. + This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax. Questions? Contact diff --git a/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/premium-renewal.mjml b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/premium-renewal.mjml index 1fe48c9ba9..defec91f0e 100644 --- a/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/premium-renewal.mjml +++ b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/premium-renewal.mjml @@ -17,8 +17,8 @@ Your Bitwarden Premium subscription renews in 15 days. The price is updating to {{BaseMonthlyRenewalPrice}}/month, billed annually. - As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. - This year's renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually. + As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this year's renewal. + This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax. Questions? Contact diff --git a/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.html.hbs b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.html.hbs index 0befde11b5..2d7c9edf35 100644 --- a/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.html.hbs +++ b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.html.hbs @@ -202,8 +202,8 @@ diff --git a/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.text.hbs b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.text.hbs index 7178548772..9f40c88329 100644 --- a/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.text.hbs +++ b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.text.hbs @@ -1,7 +1,7 @@ Your Bitwarden Families subscription renews in 15 days. The price is updating to {{BaseMonthlyRenewalPrice}}/month, billed annually at {{BaseAnnualRenewalPrice}} + tax. -As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. -This year's renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax. +As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this year's renewal. +This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax. Questions? Contact support@bitwarden.com diff --git a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.cs b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.cs index 4006c92a63..0798c7dbc8 100644 --- a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.cs +++ b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.cs @@ -5,7 +5,7 @@ namespace Bit.Core.Models.Mail.Billing.Renewal.Premium; public class PremiumRenewalMailView : BaseMailView { public required string BaseMonthlyRenewalPrice { get; set; } - public required string DiscountedMonthlyRenewalPrice { get; set; } + public required string DiscountedAnnualRenewalPrice { get; set; } public required string DiscountAmount { get; set; } } diff --git a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.html.hbs b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.html.hbs index 9ce45ef7fe..db76520eed 100644 --- a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.html.hbs +++ b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.html.hbs @@ -201,8 +201,8 @@ diff --git a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.text.hbs b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.text.hbs index 15ad530a07..4b79826f71 100644 --- a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.text.hbs +++ b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.text.hbs @@ -1,6 +1,6 @@ Your Bitwarden Premium subscription renews in 15 days. The price is updating to {{BaseMonthlyRenewalPrice}}/month, billed annually. -As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. -This year's renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually. +As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this year's renewal. +This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax. Questions? Contact support@bitwarden.com diff --git a/test/Billing.Test/Services/UpcomingInvoiceHandlerTests.cs b/test/Billing.Test/Services/UpcomingInvoiceHandlerTests.cs index 3b133c7d37..82d6c8acfd 100644 --- a/test/Billing.Test/Services/UpcomingInvoiceHandlerTests.cs +++ b/test/Billing.Test/Services/UpcomingInvoiceHandlerTests.cs @@ -280,7 +280,7 @@ public class UpcomingInvoiceHandlerTests email.ToEmails.Contains("user@example.com") && email.Subject == "Your Bitwarden Premium renewal is updating" && email.View.BaseMonthlyRenewalPrice == (plan.Seat.Price / 12).ToString("C", new CultureInfo("en-US")) && - email.View.DiscountedMonthlyRenewalPrice == (discountedPrice / 12).ToString("C", new CultureInfo("en-US")) && + email.View.DiscountedAnnualRenewalPrice == discountedPrice.ToString("C", new CultureInfo("en-US")) && email.View.DiscountAmount == $"{coupon.PercentOff}%" )); } @@ -2436,7 +2436,7 @@ public class UpcomingInvoiceHandlerTests email.Subject == "Your Bitwarden Premium renewal is updating" && email.View.BaseMonthlyRenewalPrice == (plan.Seat.Price / 12).ToString("C", new CultureInfo("en-US")) && email.View.DiscountAmount == "30%" && - email.View.DiscountedMonthlyRenewalPrice == (expectedDiscountedPrice / 12).ToString("C", new CultureInfo("en-US")) + email.View.DiscountedAnnualRenewalPrice == expectedDiscountedPrice.ToString("C", new CultureInfo("en-US")) )); await _mailService.DidNotReceive().SendInvoiceUpcoming( From b686da18dcdeecd93d2774fa348d16960cc4c959 Mon Sep 17 00:00:00 2001 From: Alex Morask <144709477+amorask-bitwarden@users.noreply.github.com> Date: Thu, 22 Jan 2026 09:01:06 -0600 Subject: [PATCH 11/12] [PM-30626] Fetch provided storage from Pricing Service when determining storage limit (#6845) * Fetch provided storage from Pricing Service * Run dotnet format * Gbubemi's feedback --- .../Services/SendValidationService.cs | 22 ++- .../Services/Implementations/CipherService.cs | 23 ++- .../Services/SendValidationServiceTests.cs | 120 +++++++++++ .../Vault/Services/CipherServiceTests.cs | 186 +++++++++++++++++- 4 files changed, 337 insertions(+), 14 deletions(-) create mode 100644 test/Core.Test/Tools/Services/SendValidationServiceTests.cs diff --git a/src/Core/Tools/SendFeatures/Services/SendValidationService.cs b/src/Core/Tools/SendFeatures/Services/SendValidationService.cs index c545c8b35f..bd987bb396 100644 --- a/src/Core/Tools/SendFeatures/Services/SendValidationService.cs +++ b/src/Core/Tools/SendFeatures/Services/SendValidationService.cs @@ -6,6 +6,7 @@ using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Services; +using Bit.Core.Billing.Pricing; using Bit.Core.Context; using Bit.Core.Exceptions; using Bit.Core.Repositories; @@ -27,6 +28,7 @@ public class SendValidationService : ISendValidationService private readonly GlobalSettings _globalSettings; private readonly ICurrentContext _currentContext; private readonly IPolicyRequirementQuery _policyRequirementQuery; + private readonly IPricingClient _pricingClient; @@ -38,7 +40,7 @@ public class SendValidationService : ISendValidationService IUserService userService, IPolicyRequirementQuery policyRequirementQuery, GlobalSettings globalSettings, - + IPricingClient pricingClient, ICurrentContext currentContext) { _userRepository = userRepository; @@ -48,6 +50,7 @@ public class SendValidationService : ISendValidationService _userService = userService; _policyRequirementQuery = policyRequirementQuery; _globalSettings = globalSettings; + _pricingClient = pricingClient; _currentContext = currentContext; } @@ -123,10 +126,19 @@ public class SendValidationService : ISendValidationService } else { - // Users that get access to file storage/premium from their organization get the default - // 1 GB max storage. - short limit = _globalSettings.SelfHosted ? Constants.SelfHostedMaxStorageGb : (short)1; - storageBytesRemaining = user.StorageBytesRemaining(limit); + // Users that get access to file storage/premium from their organization get storage + // based on the current premium plan from the pricing service + short provided; + if (_globalSettings.SelfHosted) + { + provided = Constants.SelfHostedMaxStorageGb; + } + else + { + var premiumPlan = await _pricingClient.GetAvailablePremiumPlan(); + provided = (short)premiumPlan.Storage.Provided; + } + storageBytesRemaining = user.StorageBytesRemaining(provided); } } else if (send.OrganizationId.HasValue) diff --git a/src/Core/Vault/Services/Implementations/CipherService.cs b/src/Core/Vault/Services/Implementations/CipherService.cs index fa2cfbb209..140399a37a 100644 --- a/src/Core/Vault/Services/Implementations/CipherService.cs +++ b/src/Core/Vault/Services/Implementations/CipherService.cs @@ -7,6 +7,7 @@ using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Services; +using Bit.Core.Billing.Pricing; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Platform.Push; @@ -46,6 +47,7 @@ public class CipherService : ICipherService private readonly IPolicyRequirementQuery _policyRequirementQuery; private readonly IApplicationCacheService _applicationCacheService; private readonly IFeatureService _featureService; + private readonly IPricingClient _pricingClient; public CipherService( ICipherRepository cipherRepository, @@ -65,7 +67,8 @@ public class CipherService : ICipherService IGetCipherPermissionsForUserQuery getCipherPermissionsForUserQuery, IPolicyRequirementQuery policyRequirementQuery, IApplicationCacheService applicationCacheService, - IFeatureService featureService) + IFeatureService featureService, + IPricingClient pricingClient) { _cipherRepository = cipherRepository; _folderRepository = folderRepository; @@ -85,6 +88,7 @@ public class CipherService : ICipherService _policyRequirementQuery = policyRequirementQuery; _applicationCacheService = applicationCacheService; _featureService = featureService; + _pricingClient = pricingClient; } public async Task SaveAsync(Cipher cipher, Guid savingUserId, DateTime? lastKnownRevisionDate, @@ -943,10 +947,19 @@ public class CipherService : ICipherService } else { - // Users that get access to file storage/premium from their organization get the default - // 1 GB max storage. - storageBytesRemaining = user.StorageBytesRemaining( - _globalSettings.SelfHosted ? Constants.SelfHostedMaxStorageGb : (short)1); + // Users that get access to file storage/premium from their organization get storage + // based on the current premium plan from the pricing service + short provided; + if (_globalSettings.SelfHosted) + { + provided = Constants.SelfHostedMaxStorageGb; + } + else + { + var premiumPlan = await _pricingClient.GetAvailablePremiumPlan(); + provided = (short)premiumPlan.Storage.Provided; + } + storageBytesRemaining = user.StorageBytesRemaining(provided); } } else if (cipher.OrganizationId.HasValue) diff --git a/test/Core.Test/Tools/Services/SendValidationServiceTests.cs b/test/Core.Test/Tools/Services/SendValidationServiceTests.cs new file mode 100644 index 0000000000..8adce1a29f --- /dev/null +++ b/test/Core.Test/Tools/Services/SendValidationServiceTests.cs @@ -0,0 +1,120 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Pricing.Premium; +using Bit.Core.Entities; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Tools.Entities; +using Bit.Core.Tools.Enums; +using Bit.Core.Tools.Services; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.Tools.Services; + +[SutProviderCustomize] +public class SendValidationServiceTests +{ + [Theory, BitAutoData] + public async Task StorageRemainingForSendAsync_OrgGrantedPremiumUser_UsesPricingService( + SutProvider sutProvider, + Send send, + User user) + { + // Arrange + send.UserId = user.Id; + send.OrganizationId = null; + send.Type = SendType.File; + user.Premium = false; + user.Storage = 1024L * 1024L * 1024L; // 1 GB used + user.EmailVerified = true; + + sutProvider.GetDependency().SelfHosted = false; + sutProvider.GetDependency().GetByIdAsync(user.Id).Returns(user); + sutProvider.GetDependency().CanAccessPremium(user).Returns(true); + + var premiumPlan = new Plan + { + Storage = new Purchasable { Provided = 5 } + }; + sutProvider.GetDependency().GetAvailablePremiumPlan().Returns(premiumPlan); + + // Act + var result = await sutProvider.Sut.StorageRemainingForSendAsync(send); + + // Assert + await sutProvider.GetDependency().Received(1).GetAvailablePremiumPlan(); + Assert.True(result > 0); + } + + [Theory, BitAutoData] + public async Task StorageRemainingForSendAsync_IndividualPremium_DoesNotCallPricingService( + SutProvider sutProvider, + Send send, + User user) + { + // Arrange + send.UserId = user.Id; + send.OrganizationId = null; + send.Type = SendType.File; + user.Premium = true; + user.MaxStorageGb = 10; + user.EmailVerified = true; + + sutProvider.GetDependency().GetByIdAsync(user.Id).Returns(user); + sutProvider.GetDependency().CanAccessPremium(user).Returns(true); + + // Act + var result = await sutProvider.Sut.StorageRemainingForSendAsync(send); + + // Assert - should NOT call pricing service for individual premium users + await sutProvider.GetDependency().DidNotReceive().GetAvailablePremiumPlan(); + } + + [Theory, BitAutoData] + public async Task StorageRemainingForSendAsync_SelfHosted_DoesNotCallPricingService( + SutProvider sutProvider, + Send send, + User user) + { + // Arrange + send.UserId = user.Id; + send.OrganizationId = null; + send.Type = SendType.File; + user.Premium = false; + user.EmailVerified = true; + + sutProvider.GetDependency().SelfHosted = true; + sutProvider.GetDependency().GetByIdAsync(user.Id).Returns(user); + sutProvider.GetDependency().CanAccessPremium(user).Returns(true); + + // Act + var result = await sutProvider.Sut.StorageRemainingForSendAsync(send); + + // Assert - should NOT call pricing service for self-hosted + await sutProvider.GetDependency().DidNotReceive().GetAvailablePremiumPlan(); + } + + [Theory, BitAutoData] + public async Task StorageRemainingForSendAsync_OrgSend_DoesNotCallPricingService( + SutProvider sutProvider, + Send send, + Organization org) + { + // Arrange + send.UserId = null; + send.OrganizationId = org.Id; + send.Type = SendType.File; + org.MaxStorageGb = 100; + + sutProvider.GetDependency().GetByIdAsync(org.Id).Returns(org); + + // Act + var result = await sutProvider.Sut.StorageRemainingForSendAsync(send); + + // Assert - should NOT call pricing service for org sends + await sutProvider.GetDependency().DidNotReceive().GetAvailablePremiumPlan(); + } +} diff --git a/test/Core.Test/Vault/Services/CipherServiceTests.cs b/test/Core.Test/Vault/Services/CipherServiceTests.cs index 058c6f68ab..5fc92a9d39 100644 --- a/test/Core.Test/Vault/Services/CipherServiceTests.cs +++ b/test/Core.Test/Vault/Services/CipherServiceTests.cs @@ -6,6 +6,8 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Pricing.Premium; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -2228,10 +2230,6 @@ public class CipherServiceTests .PushSyncCiphersAsync(deletingUserId); } - - - - [Theory] [OrganizationCipherCustomize] [BitAutoData] @@ -2387,6 +2385,186 @@ public class CipherServiceTests ids.Count() == cipherIds.Length && ids.All(id => cipherIds.Contains(id)))); } + [Theory, BitAutoData] + public async Task CreateAttachmentAsync_UserWithOrgGrantedPremium_UsesStorageFromPricingClient( + SutProvider sutProvider, CipherDetails cipher, Guid savingUserId) + { + var stream = new MemoryStream(new byte[100]); + var fileName = "test.txt"; + var key = "test-key"; + + // Setup cipher with user ownership + cipher.UserId = savingUserId; + cipher.OrganizationId = null; + + // Setup user WITHOUT personal premium (Premium = false), but with org-granted premium access + var user = new User + { + Id = savingUserId, + Premium = false, // User does not have personal premium + MaxStorageGb = null, // No personal storage allocation + Storage = 0 // No storage used yet + }; + + sutProvider.GetDependency() + .GetByIdAsync(savingUserId) + .Returns(user); + + // User has premium access through their organization + sutProvider.GetDependency() + .CanAccessPremium(user) + .Returns(true); + + // Mock GlobalSettings to indicate cloud (not self-hosted) + sutProvider.GetDependency().SelfHosted = false; + + // Mock the PricingClient to return a premium plan with 1 GB of storage + var premiumPlan = new Plan + { + Name = "Premium", + Available = true, + Seat = new Purchasable { StripePriceId = "price_123", Price = 10, Provided = 1 }, + Storage = new Purchasable { StripePriceId = "price_456", Price = 4, Provided = 1 } + }; + sutProvider.GetDependency() + .GetAvailablePremiumPlan() + .Returns(premiumPlan); + + sutProvider.GetDependency() + .UploadNewAttachmentAsync(Arg.Any(), cipher, Arg.Any()) + .Returns(Task.CompletedTask); + + sutProvider.GetDependency() + .ValidateFileAsync(cipher, Arg.Any(), Arg.Any()) + .Returns((true, 100L)); + + sutProvider.GetDependency() + .UpdateAttachmentAsync(Arg.Any()) + .Returns(Task.CompletedTask); + + sutProvider.GetDependency() + .ReplaceAsync(Arg.Any()) + .Returns(Task.CompletedTask); + + // Act + await sutProvider.Sut.CreateAttachmentAsync(cipher, stream, fileName, key, 100, savingUserId, false, cipher.RevisionDate); + + // Assert - PricingClient was called to get the premium plan storage + await sutProvider.GetDependency().Received(1).GetAvailablePremiumPlan(); + + // Assert - Attachment was uploaded successfully + await sutProvider.GetDependency().Received(1) + .UploadNewAttachmentAsync(Arg.Any(), cipher, Arg.Any()); + } + + [Theory, BitAutoData] + public async Task CreateAttachmentAsync_UserWithOrgGrantedPremium_ExceedsStorage_ThrowsBadRequest( + SutProvider sutProvider, CipherDetails cipher, Guid savingUserId) + { + var stream = new MemoryStream(new byte[100]); + var fileName = "test.txt"; + var key = "test-key"; + + // Setup cipher with user ownership + cipher.UserId = savingUserId; + cipher.OrganizationId = null; + + // Setup user WITHOUT personal premium, with org-granted access, but storage is full + var user = new User + { + Id = savingUserId, + Premium = false, + MaxStorageGb = null, + Storage = 1073741824 // 1 GB already used (equals the provided storage) + }; + + sutProvider.GetDependency() + .GetByIdAsync(savingUserId) + .Returns(user); + + sutProvider.GetDependency() + .CanAccessPremium(user) + .Returns(true); + + sutProvider.GetDependency().SelfHosted = false; + + // Premium plan provides 1 GB of storage + var premiumPlan = new Plan + { + Name = "Premium", + Available = true, + Seat = new Purchasable { StripePriceId = "price_123", Price = 10, Provided = 1 }, + Storage = new Purchasable { StripePriceId = "price_456", Price = 4, Provided = 1 } + }; + sutProvider.GetDependency() + .GetAvailablePremiumPlan() + .Returns(premiumPlan); + + // Act & Assert - Should throw because storage is full + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.CreateAttachmentAsync(cipher, stream, fileName, key, 100, savingUserId, false, cipher.RevisionDate)); + Assert.Contains("Not enough storage available", exception.Message); + } + + [Theory, BitAutoData] + public async Task CreateAttachmentAsync_UserWithOrgGrantedPremium_SelfHosted_UsesConstantStorage( + SutProvider sutProvider, CipherDetails cipher, Guid savingUserId) + { + var stream = new MemoryStream(new byte[100]); + var fileName = "test.txt"; + var key = "test-key"; + + // Setup cipher with user ownership + cipher.UserId = savingUserId; + cipher.OrganizationId = null; + + // Setup user WITHOUT personal premium, but with org-granted premium access + var user = new User + { + Id = savingUserId, + Premium = false, + MaxStorageGb = null, + Storage = 0 + }; + + sutProvider.GetDependency() + .GetByIdAsync(savingUserId) + .Returns(user); + + sutProvider.GetDependency() + .CanAccessPremium(user) + .Returns(true); + + // Mock GlobalSettings to indicate self-hosted + sutProvider.GetDependency().SelfHosted = true; + + sutProvider.GetDependency() + .UploadNewAttachmentAsync(Arg.Any(), cipher, Arg.Any()) + .Returns(Task.CompletedTask); + + sutProvider.GetDependency() + .ValidateFileAsync(cipher, Arg.Any(), Arg.Any()) + .Returns((true, 100L)); + + sutProvider.GetDependency() + .UpdateAttachmentAsync(Arg.Any()) + .Returns(Task.CompletedTask); + + sutProvider.GetDependency() + .ReplaceAsync(Arg.Any()) + .Returns(Task.CompletedTask); + + // Act + await sutProvider.Sut.CreateAttachmentAsync(cipher, stream, fileName, key, 100, savingUserId, false, cipher.RevisionDate); + + // Assert - PricingClient should NOT be called for self-hosted + await sutProvider.GetDependency().DidNotReceive().GetAvailablePremiumPlan(); + + // Assert - Attachment was uploaded successfully + await sutProvider.GetDependency().Received(1) + .UploadNewAttachmentAsync(Arg.Any(), cipher, Arg.Any()); + } + private async Task AssertNoActionsAsync(SutProvider sutProvider) { await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetManyOrganizationDetailsByOrganizationIdAsync(default); From bab4750caa472449b82681bbc71f14be5db26cf9 Mon Sep 17 00:00:00 2001 From: Vincent Salucci <26154748+vincentsalucci@users.noreply.github.com> Date: Thu, 22 Jan 2026 11:23:18 -0600 Subject: [PATCH 12/12] chore: add feature flag definition, refs PM-26463 (#6882) --- src/Core/Constants.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 10c68ddc42..9ffe199f1d 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -143,6 +143,7 @@ public static class FeatureFlagKeys public const string BlockClaimedDomainAccountCreation = "pm-28297-block-uninvited-claimed-domain-registration"; public const string IncreaseBulkReinviteLimitForCloud = "pm-28251-increase-bulk-reinvite-limit-for-cloud"; public const string PremiumAccessQuery = "pm-29495-refactor-premium-interface"; + public const string RefactorMembersComponent = "pm-29503-refactor-members-inheritance"; /* Architecture */ public const string DesktopMigrationMilestone1 = "desktop-ui-migration-milestone-1";
+ - +
-

+

© 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa Barbara, CA, USA

Always confirm you are on a trusted Bitwarden domain before logging in:
- bitwarden.com | - Learn why we include this + bitwarden.com | + Learn why we include this

-
As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. - This year's renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax.
+
As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this year's renewal. + This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax.
-
As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. - This year's renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually.
+
As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this year's renewal. + This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax.