diff --git a/bitwarden_license/src/Scim/appsettings.Production.json b/bitwarden_license/src/Scim/appsettings.Production.json index d9efbcda12..a6578c08dc 100644 --- a/bitwarden_license/src/Scim/appsettings.Production.json +++ b/bitwarden_license/src/Scim/appsettings.Production.json @@ -23,11 +23,9 @@ } }, "Logging": { - "IncludeScopes": false, "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" + "Default": "Information", + "Microsoft.AspNetCore": "Warning" }, "Console": { "IncludeScopes": true, diff --git a/dev/verify_migrations.ps1 b/dev/verify_migrations.ps1 index ad0d34cef1..ce1754e684 100644 --- a/dev/verify_migrations.ps1 +++ b/dev/verify_migrations.ps1 @@ -5,12 +5,19 @@ Validates that new database migration files follow naming conventions and chronological order. .DESCRIPTION - This script validates migration files in util/Migrator/DbScripts/ to ensure: + This script validates migration files to ensure: + + For SQL migrations in util/Migrator/DbScripts/: 1. New migrations follow the naming format: YYYY-MM-DD_NN_Description.sql 2. New migrations are chronologically ordered (filename sorts after existing migrations) 3. Dates use leading zeros (e.g., 2025-01-05, not 2025-1-5) 4. A 2-digit sequence number is included (e.g., _00, _01) + For Entity Framework migrations in util/MySqlMigrations, util/PostgresMigrations, util/SqliteMigrations: + 1. New migrations follow the naming format: YYYYMMDDHHMMSS_Description.cs + 2. Each migration has both .cs and .Designer.cs files + 3. New migrations are chronologically ordered (timestamp sorts after existing migrations) + .PARAMETER BaseRef The base git reference to compare against (e.g., 'main', 'HEAD~1') @@ -58,75 +65,288 @@ $currentMigrations = git ls-tree -r --name-only $CurrentRef -- "$migrationPath/" # Find added migrations $addedMigrations = $currentMigrations | Where-Object { $_ -notin $baseMigrations } +$sqlValidationFailed = $false + if ($addedMigrations.Count -eq 0) { - Write-Host "No new migration files added." - exit 0 + Write-Host "No new SQL migration files added." + Write-Host "" +} +else { + Write-Host "New SQL migration files detected:" + $addedMigrations | ForEach-Object { Write-Host " $_" } + Write-Host "" + + # Get the last migration from base reference + if ($baseMigrations.Count -eq 0) { + Write-Host "No previous SQL migrations found (initial commit?). Skipping chronological validation." + Write-Host "" + } + else { + $lastBaseMigration = Split-Path -Leaf ($baseMigrations | Select-Object -Last 1) + Write-Host "Last SQL migration in base reference: $lastBaseMigration" + Write-Host "" + + # Required format regex: YYYY-MM-DD_NN_Description.sql + $formatRegex = '^[0-9]{4}-[0-9]{2}-[0-9]{2}_[0-9]{2}_.+\.sql$' + + foreach ($migration in $addedMigrations) { + $migrationName = Split-Path -Leaf $migration + + # Validate NEW migration filename format + if ($migrationName -notmatch $formatRegex) { + Write-Host "ERROR: Migration '$migrationName' does not match required format" + Write-Host "Required format: YYYY-MM-DD_NN_Description.sql" + Write-Host " - YYYY: 4-digit year" + Write-Host " - MM: 2-digit month with leading zero (01-12)" + Write-Host " - DD: 2-digit day with leading zero (01-31)" + Write-Host " - NN: 2-digit sequence number (00, 01, 02, etc.)" + Write-Host "Example: 2025-01-15_00_MyMigration.sql" + $sqlValidationFailed = $true + continue + } + + # Compare migration name with last base migration (using ordinal string comparison) + if ([string]::CompareOrdinal($migrationName, $lastBaseMigration) -lt 0) { + Write-Host "ERROR: New migration '$migrationName' is not chronologically after '$lastBaseMigration'" + $sqlValidationFailed = $true + } + else { + Write-Host "OK: '$migrationName' is chronologically after '$lastBaseMigration'" + } + } + + Write-Host "" + } + + if ($sqlValidationFailed) { + Write-Host "FAILED: One or more SQL migrations are incorrectly named or not in chronological order" + Write-Host "" + Write-Host "All new SQL migration files must:" + Write-Host " 1. Follow the naming format: YYYY-MM-DD_NN_Description.sql" + Write-Host " 2. Use leading zeros in dates (e.g., 2025-01-05, not 2025-1-5)" + Write-Host " 3. Include a 2-digit sequence number (e.g., _00, _01)" + Write-Host " 4. Have a filename that sorts after the last migration in base" + Write-Host "" + Write-Host "To fix this issue:" + Write-Host " 1. Locate your migration file(s) in util/Migrator/DbScripts/" + Write-Host " 2. Rename to follow format: YYYY-MM-DD_NN_Description.sql" + Write-Host " 3. Ensure the date is after $lastBaseMigration" + Write-Host "" + Write-Host "Example: 2025-01-15_00_AddNewFeature.sql" + } + else { + Write-Host "SUCCESS: All new SQL migrations are correctly named and in chronological order" + } + + Write-Host "" } -Write-Host "New migration files detected:" -$addedMigrations | ForEach-Object { Write-Host " $_" } +# =========================================================================================== +# Validate Entity Framework Migrations +# =========================================================================================== + +Write-Host "===================================================================" +Write-Host "Validating Entity Framework Migrations" +Write-Host "===================================================================" Write-Host "" -# Get the last migration from base reference -if ($baseMigrations.Count -eq 0) { - Write-Host "No previous migrations found (initial commit?). Skipping validation." - exit 0 -} +$efMigrationPaths = @( + @{ Path = "util/MySqlMigrations/Migrations"; Name = "MySQL" }, + @{ Path = "util/PostgresMigrations/Migrations"; Name = "Postgres" }, + @{ Path = "util/SqliteMigrations/Migrations"; Name = "SQLite" } +) -$lastBaseMigration = Split-Path -Leaf ($baseMigrations | Select-Object -Last 1) -Write-Host "Last migration in base reference: $lastBaseMigration" -Write-Host "" +$efValidationFailed = $false -# Required format regex: YYYY-MM-DD_NN_Description.sql -$formatRegex = '^[0-9]{4}-[0-9]{2}-[0-9]{2}_[0-9]{2}_.+\.sql$' +foreach ($migrationPathInfo in $efMigrationPaths) { + $efPath = $migrationPathInfo.Path + $dbName = $migrationPathInfo.Name -$validationFailed = $false + Write-Host "-------------------------------------------------------------------" + Write-Host "Checking $dbName EF migrations in $efPath" + Write-Host "-------------------------------------------------------------------" + Write-Host "" -foreach ($migration in $addedMigrations) { - $migrationName = Split-Path -Leaf $migration + # Get list of migrations from base reference + try { + $baseMigrations = git ls-tree -r --name-only $BaseRef -- "$efPath/" 2>$null | Where-Object { $_ -like "*.cs" -and $_ -notlike "*DatabaseContextModelSnapshot.cs" } | Sort-Object + if ($LASTEXITCODE -ne 0) { + Write-Host "Warning: Could not retrieve $dbName migrations from base reference '$BaseRef'" + $baseMigrations = @() + } + } + catch { + Write-Host "Warning: Could not retrieve $dbName migrations from base reference '$BaseRef'" + $baseMigrations = @() + } - # Validate NEW migration filename format - if ($migrationName -notmatch $formatRegex) { - Write-Host "ERROR: Migration '$migrationName' does not match required format" - Write-Host "Required format: YYYY-MM-DD_NN_Description.sql" - Write-Host " - YYYY: 4-digit year" - Write-Host " - MM: 2-digit month with leading zero (01-12)" - Write-Host " - DD: 2-digit day with leading zero (01-31)" - Write-Host " - NN: 2-digit sequence number (00, 01, 02, etc.)" - Write-Host "Example: 2025-01-15_00_MyMigration.sql" - $validationFailed = $true + # Get list of migrations from current reference + $currentMigrations = git ls-tree -r --name-only $CurrentRef -- "$efPath/" | Where-Object { $_ -like "*.cs" -and $_ -notlike "*DatabaseContextModelSnapshot.cs" } | Sort-Object + + # Find added migrations + $addedMigrations = $currentMigrations | Where-Object { $_ -notin $baseMigrations } + + if ($addedMigrations.Count -eq 0) { + Write-Host "No new $dbName EF migration files added." + Write-Host "" continue } - # Compare migration name with last base migration (using ordinal string comparison) - if ([string]::CompareOrdinal($migrationName, $lastBaseMigration) -lt 0) { - Write-Host "ERROR: New migration '$migrationName' is not chronologically after '$lastBaseMigration'" - $validationFailed = $true + Write-Host "New $dbName EF migration files detected:" + $addedMigrations | ForEach-Object { Write-Host " $_" } + Write-Host "" + + # Get the last migration from base reference + if ($baseMigrations.Count -eq 0) { + Write-Host "No previous $dbName migrations found. Skipping chronological validation." + Write-Host "" } else { - Write-Host "OK: '$migrationName' is chronologically after '$lastBaseMigration'" + $lastBaseMigration = Split-Path -Leaf ($baseMigrations | Select-Object -Last 1) + Write-Host "Last $dbName migration in base reference: $lastBaseMigration" + Write-Host "" } + + # Required format regex: YYYYMMDDHHMMSS_Description.cs or YYYYMMDDHHMMSS_Description.Designer.cs + $efFormatRegex = '^[0-9]{14}_.+\.cs$' + + # Group migrations by base name (without .Designer.cs suffix) + $migrationGroups = @{} + $unmatchedFiles = @() + + foreach ($migration in $addedMigrations) { + $migrationName = Split-Path -Leaf $migration + + # Extract base name (remove .Designer.cs or .cs) + if ($migrationName -match '^([0-9]{14}_.+?)(?:\.Designer)?\.cs$') { + $baseName = $matches[1] + if (-not $migrationGroups.ContainsKey($baseName)) { + $migrationGroups[$baseName] = @() + } + $migrationGroups[$baseName] += $migrationName + } + else { + # Track files that don't match the expected pattern + $unmatchedFiles += $migrationName + } + } + + # Flag any files that don't match the expected pattern + if ($unmatchedFiles.Count -gt 0) { + Write-Host "ERROR: The following migration files do not match the required format:" + foreach ($unmatchedFile in $unmatchedFiles) { + Write-Host " - $unmatchedFile" + } + Write-Host "" + Write-Host "Required format: YYYYMMDDHHMMSS_Description.cs or YYYYMMDDHHMMSS_Description.Designer.cs" + Write-Host " - YYYYMMDDHHMMSS: 14-digit timestamp (Year, Month, Day, Hour, Minute, Second)" + Write-Host " - Description: Descriptive name using PascalCase" + Write-Host "Example: 20250115120000_AddNewFeature.cs and 20250115120000_AddNewFeature.Designer.cs" + Write-Host "" + $efValidationFailed = $true + } + + foreach ($baseName in $migrationGroups.Keys | Sort-Object) { + $files = $migrationGroups[$baseName] + + # Validate format + $mainFile = "$baseName.cs" + $designerFile = "$baseName.Designer.cs" + + if ($mainFile -notmatch $efFormatRegex) { + Write-Host "ERROR: Migration '$mainFile' does not match required format" + Write-Host "Required format: YYYYMMDDHHMMSS_Description.cs" + Write-Host " - YYYYMMDDHHMMSS: 14-digit timestamp (Year, Month, Day, Hour, Minute, Second)" + Write-Host "Example: 20250115120000_AddNewFeature.cs" + $efValidationFailed = $true + continue + } + + # Check that both .cs and .Designer.cs files exist + $hasCsFile = $files -contains $mainFile + $hasDesignerFile = $files -contains $designerFile + + if (-not $hasCsFile) { + Write-Host "ERROR: Missing main migration file: $mainFile" + $efValidationFailed = $true + } + + if (-not $hasDesignerFile) { + Write-Host "ERROR: Missing designer file: $designerFile" + Write-Host "Each EF migration must have both a .cs and .Designer.cs file" + $efValidationFailed = $true + } + + if (-not $hasCsFile -or -not $hasDesignerFile) { + continue + } + + # Compare migration timestamp with last base migration (using ordinal string comparison) + if ($baseMigrations.Count -gt 0) { + if ([string]::CompareOrdinal($mainFile, $lastBaseMigration) -lt 0) { + Write-Host "ERROR: New migration '$mainFile' is not chronologically after '$lastBaseMigration'" + $efValidationFailed = $true + } + else { + Write-Host "OK: '$mainFile' is chronologically after '$lastBaseMigration'" + } + } + else { + Write-Host "OK: '$mainFile' (no previous migrations to compare)" + } + } + + Write-Host "" +} + +if ($efValidationFailed) { + Write-Host "FAILED: One or more EF migrations are incorrectly named or not in chronological order" + Write-Host "" + Write-Host "All new EF migration files must:" + Write-Host " 1. Follow the naming format: YYYYMMDDHHMMSS_Description.cs" + Write-Host " 2. Include both .cs and .Designer.cs files" + Write-Host " 3. Have a timestamp that sorts after the last migration in base" + Write-Host "" + Write-Host "To fix this issue:" + Write-Host " 1. Locate your migration file(s) in the respective Migrations directory" + Write-Host " 2. Ensure both .cs and .Designer.cs files exist" + Write-Host " 3. Rename to follow format: YYYYMMDDHHMMSS_Description.cs" + Write-Host " 4. Ensure the timestamp is after the last migration" + Write-Host "" + Write-Host "Example: 20250115120000_AddNewFeature.cs and 20250115120000_AddNewFeature.Designer.cs" +} +else { + Write-Host "SUCCESS: All new EF migrations are correctly named and in chronological order" } Write-Host "" +Write-Host "===================================================================" +Write-Host "Validation Summary" +Write-Host "===================================================================" + +if ($sqlValidationFailed -or $efValidationFailed) { + if ($sqlValidationFailed) { + Write-Host "❌ SQL migrations validation FAILED" + } + else { + Write-Host "✓ SQL migrations validation PASSED" + } + + if ($efValidationFailed) { + Write-Host "❌ EF migrations validation FAILED" + } + else { + Write-Host "✓ EF migrations validation PASSED" + } -if ($validationFailed) { - Write-Host "FAILED: One or more migrations are incorrectly named or not in chronological order" Write-Host "" - Write-Host "All new migration files must:" - Write-Host " 1. Follow the naming format: YYYY-MM-DD_NN_Description.sql" - Write-Host " 2. Use leading zeros in dates (e.g., 2025-01-05, not 2025-1-5)" - Write-Host " 3. Include a 2-digit sequence number (e.g., _00, _01)" - Write-Host " 4. Have a filename that sorts after the last migration in base" - Write-Host "" - Write-Host "To fix this issue:" - Write-Host " 1. Locate your migration file(s) in util/Migrator/DbScripts/" - Write-Host " 2. Rename to follow format: YYYY-MM-DD_NN_Description.sql" - Write-Host " 3. Ensure the date is after $lastBaseMigration" - Write-Host "" - Write-Host "Example: 2025-01-15_00_AddNewFeature.sql" + Write-Host "OVERALL RESULT: FAILED" exit 1 } - -Write-Host "SUCCESS: All new migrations are correctly named and in chronological order" -exit 0 +else { + Write-Host "✓ SQL migrations validation PASSED" + Write-Host "✓ EF migrations validation PASSED" + Write-Host "" + Write-Host "OVERALL RESULT: SUCCESS" + exit 0 +} diff --git a/src/Admin/appsettings.Production.json b/src/Admin/appsettings.Production.json index 9f797f3111..1d852abfed 100644 --- a/src/Admin/appsettings.Production.json +++ b/src/Admin/appsettings.Production.json @@ -20,11 +20,9 @@ } }, "Logging": { - "IncludeScopes": false, "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" + "Default": "Information", + "Microsoft.AspNetCore": "Warning" }, "Console": { "IncludeScopes": true, diff --git a/src/Api/appsettings.Production.json b/src/Api/appsettings.Production.json index d9efbcda12..a6578c08dc 100644 --- a/src/Api/appsettings.Production.json +++ b/src/Api/appsettings.Production.json @@ -23,11 +23,9 @@ } }, "Logging": { - "IncludeScopes": false, "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" + "Default": "Information", + "Microsoft.AspNetCore": "Warning" }, "Console": { "IncludeScopes": true, diff --git a/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs b/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs index c10368d8c0..9e20bd3191 100644 --- a/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs +++ b/src/Billing/Services/Implementations/SubscriptionUpdatedHandler.cs @@ -275,17 +275,24 @@ public class SubscriptionUpdatedHandler : ISubscriptionUpdatedHandler .PreviousAttributes .ToObject() as Subscription; + // Get all plan IDs that include Secrets Manager support to check if the organization has secret manager in the + // previous and/or current subscriptions. + var planIdsOfPlansWithSecretManager = (await _pricingClient.ListPlans()) + .Where(orgPlan => orgPlan.SupportsSecretsManager && orgPlan.SecretsManager.StripeSeatPlanId != null) + .Select(orgPlan => orgPlan.SecretsManager.StripeSeatPlanId) + .ToHashSet(); + // This being false doesn't necessarily mean that the organization doesn't subscribe to Secrets Manager. // If there are changes to any subscription item, Stripe sends every item in the subscription, both // changed and unchanged. var previousSubscriptionHasSecretsManager = previousSubscription?.Items is not null && previousSubscription.Items.Any( - previousSubscriptionItem => previousSubscriptionItem.Plan.Id == plan.SecretsManager.StripeSeatPlanId); + previousSubscriptionItem => planIdsOfPlansWithSecretManager.Contains(previousSubscriptionItem.Plan.Id)); var currentSubscriptionHasSecretsManager = subscription.Items.Any( - currentSubscriptionItem => currentSubscriptionItem.Plan.Id == plan.SecretsManager.StripeSeatPlanId); + currentSubscriptionItem => planIdsOfPlansWithSecretManager.Contains(currentSubscriptionItem.Plan.Id)); if (!previousSubscriptionHasSecretsManager || currentSubscriptionHasSecretsManager) { diff --git a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs index 004828dc48..ae2a76a7ce 100644 --- a/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs +++ b/src/Billing/Services/Implementations/UpcomingInvoiceHandler.cs @@ -627,7 +627,7 @@ public class UpcomingInvoiceHandler( { BaseMonthlyRenewalPrice = (premiumPlan.Seat.Price / 12).ToString("C", new CultureInfo("en-US")), DiscountAmount = $"{coupon.PercentOff}%", - DiscountedMonthlyRenewalPrice = (discountedAnnualRenewalPrice / 12).ToString("C", new CultureInfo("en-US")) + DiscountedAnnualRenewalPrice = discountedAnnualRenewalPrice.ToString("C", new CultureInfo("en-US")) } }; diff --git a/src/Core/AdminConsole/OrganizationFeatures/Collections/CollectionUtils.cs b/src/Core/AdminConsole/OrganizationFeatures/Collections/CollectionUtils.cs new file mode 100644 index 0000000000..116992146f --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/Collections/CollectionUtils.cs @@ -0,0 +1,53 @@ +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Utilities; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.Collections; + +public static class CollectionUtils +{ + /// + /// Arranges Collection and CollectionUser objects to create default user collections. + /// + /// The organization ID. + /// The IDs for organization users who need default collections. + /// The encrypted string to use as the default collection name. + /// A tuple containing the collections and collection users. + public static (ICollection collections, ICollection collectionUsers) + BuildDefaultUserCollections(Guid organizationId, IEnumerable organizationUserIds, + string defaultCollectionName) + { + var now = DateTime.UtcNow; + + var collectionUsers = new List(); + var collections = new List(); + + foreach (var orgUserId in organizationUserIds) + { + var collectionId = CoreHelpers.GenerateComb(); + + collections.Add(new Collection + { + Id = collectionId, + OrganizationId = organizationId, + Name = defaultCollectionName, + CreationDate = now, + RevisionDate = now, + Type = CollectionType.DefaultUserCollection, + DefaultUserCollectionEmail = null + + }); + + collectionUsers.Add(new CollectionUser + { + CollectionId = collectionId, + OrganizationUserId = orgUserId, + ReadOnly = false, + HidePasswords = false, + Manage = true, + }); + } + + return (collections, collectionUsers); + } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUserCommand.cs index 1b488677ae..0292381857 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUser/AutomaticallyConfirmOrganizationUserCommand.cs @@ -4,9 +4,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.OrganizationConfirmation; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; -using Bit.Core.Entities; using Bit.Core.Enums; -using Bit.Core.Models.Data; using Bit.Core.Platform.Push; using Bit.Core.Repositories; using Bit.Core.Services; @@ -83,19 +81,10 @@ public class AutomaticallyConfirmOrganizationUserCommand(IOrganizationUserReposi return; } - await collectionRepository.CreateAsync( - new Collection - { - OrganizationId = request.Organization!.Id, - Name = request.DefaultUserCollectionName, - Type = CollectionType.DefaultUserCollection - }, - groups: null, - [new CollectionAccessSelection - { - Id = request.OrganizationUser!.Id, - Manage = true - }]); + await collectionRepository.CreateDefaultCollectionsAsync( + request.Organization!.Id, + [request.OrganizationUser!.Id], + request.DefaultUserCollectionName); } catch (Exception ex) { diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs index 0b82ac7ea4..02f3346ba6 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs @@ -14,7 +14,6 @@ using Bit.Core.Billing.Enums; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; -using Bit.Core.Models.Data; using Bit.Core.Platform.Push; using Bit.Core.Repositories; using Bit.Core.Services; @@ -294,21 +293,10 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand return; } - var defaultCollection = new Collection - { - OrganizationId = organizationUser.OrganizationId, - Name = defaultUserCollectionName, - Type = CollectionType.DefaultUserCollection - }; - var collectionUser = new CollectionAccessSelection - { - Id = organizationUser.Id, - ReadOnly = false, - HidePasswords = false, - Manage = true - }; - - await _collectionRepository.CreateAsync(defaultCollection, groups: null, users: [collectionUser]); + await _collectionRepository.CreateDefaultCollectionsAsync( + organizationUser.OrganizationId, + [organizationUser.Id], + defaultUserCollectionName); } /// @@ -339,7 +327,7 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand return; } - await _collectionRepository.UpsertDefaultCollectionsAsync(organizationId, eligibleOrganizationUserIds, defaultUserCollectionName); + await _collectionRepository.CreateDefaultCollectionsAsync(organizationId, eligibleOrganizationUserIds, defaultUserCollectionName); } /// diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs index ec42c8b402..c5b7314730 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/v1/RestoreOrganizationUserCommand.cs @@ -93,7 +93,7 @@ public class RestoreOrganizationUserCommand( .twoFactorIsEnabled; } - if (organization.PlanType == PlanType.Free) + if (organization.PlanType == PlanType.Free && organizationUser.UserId.HasValue) { await CheckUserForOtherFreeOrganizationOwnershipAsync(organizationUser); } diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidator.cs index 7a47baa65a..104a5751ff 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidator.cs @@ -57,14 +57,15 @@ public class OrganizationDataOwnershipPolicyValidator( var userOrgIds = requirements .Select(requirement => requirement.GetDefaultCollectionRequestOnPolicyEnable(policyUpdate.OrganizationId)) .Where(request => request.ShouldCreateDefaultCollection) - .Select(request => request.OrganizationUserId); + .Select(request => request.OrganizationUserId) + .ToList(); if (!userOrgIds.Any()) { return; } - await collectionRepository.UpsertDefaultCollectionsAsync( + await collectionRepository.CreateDefaultCollectionsBulkAsync( policyUpdate.OrganizationId, userOrgIds, defaultCollectionName); diff --git a/src/Core/AdminConsole/Repositories/IOrganizationRepository.cs b/src/Core/AdminConsole/Repositories/IOrganizationRepository.cs index da7a77000b..d79923fdd1 100644 --- a/src/Core/AdminConsole/Repositories/IOrganizationRepository.cs +++ b/src/Core/AdminConsole/Repositories/IOrganizationRepository.cs @@ -21,7 +21,9 @@ public interface IOrganizationRepository : IRepository Task> GetOwnerEmailAddressesById(Guid organizationId); /// - /// Gets the organizations that have a verified domain matching the user's email domain. + /// Gets the organizations that have claimed the user's account. Currently, only one organization may claim a user. + /// This requires that the organization has claimed the user's domain and the user is an organization member. + /// It excludes invited members. /// Task> GetByVerifiedUserEmailDomainAsync(Guid userId); diff --git a/src/Core/Billing/Premium/Commands/UpdatePremiumStorageCommand.cs b/src/Core/Billing/Premium/Commands/UpdatePremiumStorageCommand.cs index 176c77bf57..219f450f1d 100644 --- a/src/Core/Billing/Premium/Commands/UpdatePremiumStorageCommand.cs +++ b/src/Core/Billing/Premium/Commands/UpdatePremiumStorageCommand.cs @@ -2,6 +2,7 @@ using Bit.Core.Billing.Constants; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; +using Bit.Core.Billing.Subscriptions.Models; using Bit.Core.Entities; using Bit.Core.Services; using Bit.Core.Utilities; @@ -29,6 +30,7 @@ public interface IUpdatePremiumStorageCommand } public class UpdatePremiumStorageCommand( + IBraintreeService braintreeService, IStripeAdapter stripeAdapter, IUserService userService, IPricingClient pricingClient, @@ -49,7 +51,10 @@ public class UpdatePremiumStorageCommand( // Fetch all premium plans and the user's subscription to find which plan they're on var premiumPlans = await pricingClient.ListPremiumPlans(); - var subscription = await stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId); + var subscription = await stripeAdapter.GetSubscriptionAsync(user.GatewaySubscriptionId, new SubscriptionGetOptions + { + Expand = ["customer"] + }); // Find the password manager subscription item (seat, not storage) and match it to a plan var passwordManagerItem = subscription.Items.Data.FirstOrDefault(i => @@ -127,13 +132,41 @@ public class UpdatePremiumStorageCommand( }); } - var subscriptionUpdateOptions = new SubscriptionUpdateOptions - { - Items = subscriptionItemOptions, - ProrationBehavior = ProrationBehavior.AlwaysInvoice - }; + var usingPayPal = subscription.Customer.Metadata.ContainsKey(MetadataKeys.BraintreeCustomerId); - await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, subscriptionUpdateOptions); + if (usingPayPal) + { + var options = new SubscriptionUpdateOptions + { + Items = subscriptionItemOptions, + ProrationBehavior = ProrationBehavior.CreateProrations + }; + + await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, options); + + var draftInvoice = await stripeAdapter.CreateInvoiceAsync(new InvoiceCreateOptions + { + Customer = subscription.CustomerId, + Subscription = subscription.Id, + AutoAdvance = false, + CollectionMethod = CollectionMethod.ChargeAutomatically + }); + + var finalizedInvoice = await stripeAdapter.FinalizeInvoiceAsync(draftInvoice.Id, + new InvoiceFinalizeOptions { AutoAdvance = false, Expand = ["customer"] }); + + await braintreeService.PayInvoice(new UserId(user.Id), finalizedInvoice); + } + else + { + var options = new SubscriptionUpdateOptions + { + Items = subscriptionItemOptions, + ProrationBehavior = ProrationBehavior.AlwaysInvoice + }; + + await stripeAdapter.UpdateSubscriptionAsync(subscription.Id, options); + } // Update the user's max storage user.MaxStorageGb = maxStorageGb; diff --git a/src/Core/Billing/Services/IStripeAdapter.cs b/src/Core/Billing/Services/IStripeAdapter.cs index 5ec732920e..12ea3d5a7c 100644 --- a/src/Core/Billing/Services/IStripeAdapter.cs +++ b/src/Core/Billing/Services/IStripeAdapter.cs @@ -24,6 +24,7 @@ public interface IStripeAdapter Task CancelSubscriptionAsync(string id, SubscriptionCancelOptions options = null); Task GetInvoiceAsync(string id, InvoiceGetOptions options); Task> ListInvoicesAsync(StripeInvoiceListOptions options); + Task CreateInvoiceAsync(InvoiceCreateOptions options); Task CreateInvoicePreviewAsync(InvoiceCreatePreviewOptions options); Task> SearchInvoiceAsync(InvoiceSearchOptions options); Task UpdateInvoiceAsync(string id, InvoiceUpdateOptions options); diff --git a/src/Core/Billing/Services/Implementations/StripeAdapter.cs b/src/Core/Billing/Services/Implementations/StripeAdapter.cs index cdc7645042..5b90500021 100644 --- a/src/Core/Billing/Services/Implementations/StripeAdapter.cs +++ b/src/Core/Billing/Services/Implementations/StripeAdapter.cs @@ -116,6 +116,9 @@ public class StripeAdapter : IStripeAdapter return invoices; } + public Task CreateInvoiceAsync(InvoiceCreateOptions options) => + _invoiceService.CreateAsync(options); + public Task CreateInvoicePreviewAsync(InvoiceCreatePreviewOptions options) => _invoiceService.CreatePreviewAsync(options); diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 6f42778b6b..9ffe199f1d 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -143,6 +143,7 @@ public static class FeatureFlagKeys public const string BlockClaimedDomainAccountCreation = "pm-28297-block-uninvited-claimed-domain-registration"; public const string IncreaseBulkReinviteLimitForCloud = "pm-28251-increase-bulk-reinvite-limit-for-cloud"; public const string PremiumAccessQuery = "pm-29495-refactor-premium-interface"; + public const string RefactorMembersComponent = "pm-29503-refactor-members-inheritance"; /* Architecture */ public const string DesktopMigrationMilestone1 = "desktop-ui-migration-milestone-1"; @@ -162,7 +163,6 @@ public static class FeatureFlagKeys public const string MjmlWelcomeEmailTemplates = "pm-21741-mjml-welcome-email"; public const string OrganizationConfirmationEmail = "pm-28402-update-confirmed-to-org-email-template"; public const string MarketingInitiatedPremiumFlow = "pm-26140-marketing-initiated-premium-flow"; - public const string RedirectOnSsoRequired = "pm-1632-redirect-on-sso-required"; public const string PrefetchPasswordPrelogin = "pm-23801-prefetch-password-prelogin"; public const string PM27086_UpdateAuthenticationApisForInputPassword = "pm-27086-update-authentication-apis-for-input-password"; diff --git a/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2019-renewal.mjml b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2019-renewal.mjml index 092ae303de..06f60e7724 100644 --- a/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2019-renewal.mjml +++ b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/families-2019-renewal.mjml @@ -18,7 +18,7 @@ at {{BaseAnnualRenewalPrice}} + tax. - As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. + As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this year's renewal. This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax. diff --git a/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/premium-renewal.mjml b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/premium-renewal.mjml index a460442a7c..defec91f0e 100644 --- a/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/premium-renewal.mjml +++ b/src/Core/MailTemplates/Mjml/emails/Billing/Renewals/premium-renewal.mjml @@ -17,8 +17,8 @@ Your Bitwarden Premium subscription renews in 15 days. The price is updating to {{BaseMonthlyRenewalPrice}}/month, billed annually. - As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. - This renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually. + As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this year's renewal. + This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax. Questions? Contact diff --git a/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.html.hbs b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.html.hbs index 227613999b..2d7c9edf35 100644 --- a/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.html.hbs +++ b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.html.hbs @@ -202,7 +202,7 @@ -
As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. +
As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this year's renewal. This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax.
@@ -271,12 +271,12 @@ - + -
+
- +
@@ -364,8 +364,8 @@ - -
- + +
@@ -381,13 +381,13 @@
- +
+ - @@ -404,13 +404,13 @@ -
+ - +
- +
+ - @@ -427,13 +427,13 @@ -
+ - +
- +
+ - @@ -450,13 +450,13 @@ -
+ - +
- +
+ - @@ -473,13 +473,13 @@ -
+ - +
- +
+ - @@ -496,13 +496,13 @@ -
+ - +
- +
+ - @@ -519,13 +519,13 @@ -
+ - +
- +
+ - @@ -546,15 +546,15 @@ diff --git a/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.text.hbs b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.text.hbs index 88d64f9acf..9f40c88329 100644 --- a/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.text.hbs +++ b/src/Core/Models/Mail/Billing/Renewal/Families2019Renewal/Families2019RenewalMailView.text.hbs @@ -1,7 +1,7 @@ Your Bitwarden Families subscription renews in 15 days. The price is updating to {{BaseMonthlyRenewalPrice}}/month, billed annually at {{BaseAnnualRenewalPrice}} + tax. -As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. +As a long time Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this year's renewal. This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax. Questions? Contact support@bitwarden.com diff --git a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.cs b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.cs index 4006c92a63..0798c7dbc8 100644 --- a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.cs +++ b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.cs @@ -5,7 +5,7 @@ namespace Bit.Core.Models.Mail.Billing.Renewal.Premium; public class PremiumRenewalMailView : BaseMailView { public required string BaseMonthlyRenewalPrice { get; set; } - public required string DiscountedMonthlyRenewalPrice { get; set; } + public required string DiscountedAnnualRenewalPrice { get; set; } public required string DiscountAmount { get; set; } } diff --git a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.html.hbs b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.html.hbs index a6b2fda0f7..db76520eed 100644 --- a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.html.hbs +++ b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.html.hbs @@ -201,8 +201,8 @@ @@ -270,12 +270,12 @@
+ - +
-

+

© 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa Barbara, CA, USA

Always confirm you are on a trusted Bitwarden domain before logging in:
- bitwarden.com | - Learn why we include this + bitwarden.com | + Learn why we include this

-
As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. - This renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually.
+
As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this year's renewal. + This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax.
- + -
+
- +
@@ -363,8 +363,8 @@ - -
- + +
@@ -380,13 +380,13 @@
- +
+ - @@ -403,13 +403,13 @@ -
+ - +
- +
+ - @@ -426,13 +426,13 @@ -
+ - +
- +
+ - @@ -449,13 +449,13 @@ -
+ - +
- +
+ - @@ -472,13 +472,13 @@ -
+ - +
- +
+ - @@ -495,13 +495,13 @@ -
+ - +
- +
+ - @@ -518,13 +518,13 @@ -
+ - +
- +
+ - @@ -545,15 +545,15 @@ diff --git a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.text.hbs b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.text.hbs index 41300d0f96..4b79826f71 100644 --- a/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.text.hbs +++ b/src/Core/Models/Mail/Billing/Renewal/Premium/PremiumRenewalMailView.text.hbs @@ -1,6 +1,6 @@ Your Bitwarden Premium subscription renews in 15 days. The price is updating to {{BaseMonthlyRenewalPrice}}/month, billed annually. -As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this renewal. -This renewal now will be {{DiscountedMonthlyRenewalPrice}}/month, billed annually. +As an existing Bitwarden customer, you will receive a one-time {{DiscountAmount}} loyalty discount for this year's renewal. +This renewal will now be billed annually at {{DiscountedAnnualRenewalPrice}} + tax. Questions? Contact support@bitwarden.com diff --git a/src/Core/Repositories/ICollectionRepository.cs b/src/Core/Repositories/ICollectionRepository.cs index f86147ca7d..3f3b71d2d5 100644 --- a/src/Core/Repositories/ICollectionRepository.cs +++ b/src/Core/Repositories/ICollectionRepository.cs @@ -64,11 +64,22 @@ public interface ICollectionRepository : IRepository IEnumerable users, IEnumerable groups); /// - /// Creates default user collections for the specified organization users if they do not already have one. + /// Creates default user collections for the specified organization users. + /// Filters internally to only create collections for users who don't already have one. /// /// The Organization ID. /// The Organization User IDs to create default collections for. /// The encrypted string to use as the default collection name. - /// - Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName); + Task CreateDefaultCollectionsAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName); + + /// + /// Creates default user collections for the specified organization users using bulk insert operations. + /// Use this if you need to create collections for > ~1k users. + /// Filters internally to only create collections for users who don't already have one. + /// + /// The Organization ID. + /// The Organization User IDs to create default collections for. + /// The encrypted string to use as the default collection name. + Task CreateDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName); + } diff --git a/src/Core/Tools/SendFeatures/Services/SendValidationService.cs b/src/Core/Tools/SendFeatures/Services/SendValidationService.cs index c545c8b35f..bd987bb396 100644 --- a/src/Core/Tools/SendFeatures/Services/SendValidationService.cs +++ b/src/Core/Tools/SendFeatures/Services/SendValidationService.cs @@ -6,6 +6,7 @@ using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Services; +using Bit.Core.Billing.Pricing; using Bit.Core.Context; using Bit.Core.Exceptions; using Bit.Core.Repositories; @@ -27,6 +28,7 @@ public class SendValidationService : ISendValidationService private readonly GlobalSettings _globalSettings; private readonly ICurrentContext _currentContext; private readonly IPolicyRequirementQuery _policyRequirementQuery; + private readonly IPricingClient _pricingClient; @@ -38,7 +40,7 @@ public class SendValidationService : ISendValidationService IUserService userService, IPolicyRequirementQuery policyRequirementQuery, GlobalSettings globalSettings, - + IPricingClient pricingClient, ICurrentContext currentContext) { _userRepository = userRepository; @@ -48,6 +50,7 @@ public class SendValidationService : ISendValidationService _userService = userService; _policyRequirementQuery = policyRequirementQuery; _globalSettings = globalSettings; + _pricingClient = pricingClient; _currentContext = currentContext; } @@ -123,10 +126,19 @@ public class SendValidationService : ISendValidationService } else { - // Users that get access to file storage/premium from their organization get the default - // 1 GB max storage. - short limit = _globalSettings.SelfHosted ? Constants.SelfHostedMaxStorageGb : (short)1; - storageBytesRemaining = user.StorageBytesRemaining(limit); + // Users that get access to file storage/premium from their organization get storage + // based on the current premium plan from the pricing service + short provided; + if (_globalSettings.SelfHosted) + { + provided = Constants.SelfHostedMaxStorageGb; + } + else + { + var premiumPlan = await _pricingClient.GetAvailablePremiumPlan(); + provided = (short)premiumPlan.Storage.Provided; + } + storageBytesRemaining = user.StorageBytesRemaining(provided); } } else if (send.OrganizationId.HasValue) diff --git a/src/Core/Utilities/LoggerFactoryExtensions.cs b/src/Core/Utilities/LoggerFactoryExtensions.cs index b950e30d5d..f3330f0792 100644 --- a/src/Core/Utilities/LoggerFactoryExtensions.cs +++ b/src/Core/Utilities/LoggerFactoryExtensions.cs @@ -1,4 +1,5 @@ -using Microsoft.AspNetCore.Hosting; +using System.Globalization; +using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; @@ -8,7 +9,7 @@ namespace Bit.Core.Utilities; public static class LoggerFactoryExtensions { /// - /// + /// /// /// /// @@ -21,10 +22,12 @@ public static class LoggerFactoryExtensions return; } + IConfiguration loggingConfiguration; + // If they have begun using the new settings location, use that if (!string.IsNullOrEmpty(context.Configuration["Logging:PathFormat"])) { - logging.AddFile(context.Configuration.GetSection("Logging")); + loggingConfiguration = context.Configuration.GetSection("Logging"); } else { @@ -40,28 +43,35 @@ public static class LoggerFactoryExtensions var projectName = loggingOptions.ProjectName ?? context.HostingEnvironment.ApplicationName; + string pathFormat; + if (loggingOptions.LogRollBySizeLimit.HasValue) { - var pathFormat = loggingOptions.LogDirectoryByProject + pathFormat = loggingOptions.LogDirectoryByProject ? Path.Combine(loggingOptions.LogDirectory, projectName, "log.txt") : Path.Combine(loggingOptions.LogDirectory, $"{projectName.ToLowerInvariant()}.log"); - - logging.AddFile( - pathFormat: pathFormat, - fileSizeLimitBytes: loggingOptions.LogRollBySizeLimit.Value - ); } else { - var pathFormat = loggingOptions.LogDirectoryByProject + pathFormat = loggingOptions.LogDirectoryByProject ? Path.Combine(loggingOptions.LogDirectory, projectName, "{Date}.txt") : Path.Combine(loggingOptions.LogDirectory, $"{projectName.ToLowerInvariant()}_{{Date}}.log"); - - logging.AddFile( - pathFormat: pathFormat - ); } + + // We want to rely on Serilog using the configuration section to have customization of the log levels + // so we make a custom configuration source for them based on the legacy values and allow overrides from + // the new location. + loggingConfiguration = new ConfigurationBuilder() + .AddInMemoryCollection(new Dictionary + { + {"PathFormat", pathFormat}, + {"FileSizeLimitBytes", loggingOptions.LogRollBySizeLimit?.ToString(CultureInfo.InvariantCulture)} + }) + .AddConfiguration(context.Configuration.GetSection("Logging")) + .Build(); } + + logging.AddFile(loggingConfiguration); }); } diff --git a/src/Core/Vault/Services/Implementations/CipherService.cs b/src/Core/Vault/Services/Implementations/CipherService.cs index fa2cfbb209..140399a37a 100644 --- a/src/Core/Vault/Services/Implementations/CipherService.cs +++ b/src/Core/Vault/Services/Implementations/CipherService.cs @@ -7,6 +7,7 @@ using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Services; +using Bit.Core.Billing.Pricing; using Bit.Core.Enums; using Bit.Core.Exceptions; using Bit.Core.Platform.Push; @@ -46,6 +47,7 @@ public class CipherService : ICipherService private readonly IPolicyRequirementQuery _policyRequirementQuery; private readonly IApplicationCacheService _applicationCacheService; private readonly IFeatureService _featureService; + private readonly IPricingClient _pricingClient; public CipherService( ICipherRepository cipherRepository, @@ -65,7 +67,8 @@ public class CipherService : ICipherService IGetCipherPermissionsForUserQuery getCipherPermissionsForUserQuery, IPolicyRequirementQuery policyRequirementQuery, IApplicationCacheService applicationCacheService, - IFeatureService featureService) + IFeatureService featureService, + IPricingClient pricingClient) { _cipherRepository = cipherRepository; _folderRepository = folderRepository; @@ -85,6 +88,7 @@ public class CipherService : ICipherService _policyRequirementQuery = policyRequirementQuery; _applicationCacheService = applicationCacheService; _featureService = featureService; + _pricingClient = pricingClient; } public async Task SaveAsync(Cipher cipher, Guid savingUserId, DateTime? lastKnownRevisionDate, @@ -943,10 +947,19 @@ public class CipherService : ICipherService } else { - // Users that get access to file storage/premium from their organization get the default - // 1 GB max storage. - storageBytesRemaining = user.StorageBytesRemaining( - _globalSettings.SelfHosted ? Constants.SelfHostedMaxStorageGb : (short)1); + // Users that get access to file storage/premium from their organization get storage + // based on the current premium plan from the pricing service + short provided; + if (_globalSettings.SelfHosted) + { + provided = Constants.SelfHostedMaxStorageGb; + } + else + { + var premiumPlan = await _pricingClient.GetAvailablePremiumPlan(); + provided = (short)premiumPlan.Storage.Provided; + } + storageBytesRemaining = user.StorageBytesRemaining(provided); } } else if (cipher.OrganizationId.HasValue) diff --git a/src/Events/appsettings.Production.json b/src/Events/appsettings.Production.json index 010f02f8cd..9a10621264 100644 --- a/src/Events/appsettings.Production.json +++ b/src/Events/appsettings.Production.json @@ -17,11 +17,9 @@ } }, "Logging": { - "IncludeScopes": false, "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" + "Default": "Information", + "Microsoft.AspNetCore": "Warning" }, "Console": { "IncludeScopes": true, diff --git a/src/EventsProcessor/appsettings.Production.json b/src/EventsProcessor/appsettings.Production.json index 1cce4a9ed3..d57bf98b55 100644 --- a/src/EventsProcessor/appsettings.Production.json +++ b/src/EventsProcessor/appsettings.Production.json @@ -1,10 +1,8 @@ { "Logging": { - "IncludeScopes": false, "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" + "Default": "Information", + "Microsoft.AspNetCore": "Warning" }, "Console": { "IncludeScopes": true, diff --git a/src/Icons/appsettings.Production.json b/src/Icons/appsettings.Production.json index 828e8c61cc..19d21f7260 100644 --- a/src/Icons/appsettings.Production.json +++ b/src/Icons/appsettings.Production.json @@ -17,11 +17,9 @@ } }, "Logging": { - "IncludeScopes": false, "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" + "Default": "Information", + "Microsoft.AspNetCore": "Warning" }, "Console": { "IncludeScopes": true, diff --git a/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs index e07446d49f..289feebdb2 100644 --- a/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/BaseRequestValidator.cs @@ -4,7 +4,6 @@ using System.Security.Claims; using Bit.Core; -using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.Entities; @@ -233,56 +232,14 @@ public abstract class BaseRequestValidator where T : class private async Task ValidateSsoAsync(T context, ValidatedTokenRequest request, CustomValidatorRequestContext validatorContext) { - // TODO: Clean up Feature Flag: Remove this if block: PM-28281 - if (!_featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired)) + var ssoValid = await _ssoRequestValidator.ValidateAsync(validatorContext.User, request, validatorContext); + if (ssoValid) { - validatorContext.SsoRequired = await RequireSsoLoginAsync(validatorContext.User, request.GrantType); - if (!validatorContext.SsoRequired) - { - return true; - } - - // Users without SSO requirement requesting 2FA recovery will be fast-forwarded through login and are - // presented with their 2FA management area as a reminder to re-evaluate their 2FA posture after recovery and - // review their new recovery token if desired. - // SSO users cannot be assumed to be authenticated, and must prove authentication with their IdP after recovery. - // As described in validation order determination, if TwoFactorRequired, the 2FA validation scheme will have been - // evaluated, and recovery will have been performed if requested. - // We will send a descriptive message in these cases so clients can give the appropriate feedback and redirect - // to /login. - if (validatorContext.TwoFactorRequired && - validatorContext.TwoFactorRecoveryRequested) - { - SetSsoResult(context, - new Dictionary - { - { - "ErrorModel", - new ErrorResponseModel( - "Two-factor recovery has been performed. SSO authentication is required.") - } - }); - return false; - } - - SetSsoResult(context, - new Dictionary - { - { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") } - }); - return false; + return true; } - else - { - var ssoValid = await _ssoRequestValidator.ValidateAsync(validatorContext.User, request, validatorContext); - if (ssoValid) - { - return true; - } - SetValidationErrorResult(context, validatorContext); - return ssoValid; - } + SetValidationErrorResult(context, validatorContext); + return ssoValid; } /// @@ -521,9 +478,6 @@ public abstract class BaseRequestValidator where T : class [Obsolete("Consider using SetValidationErrorResult instead.")] protected abstract void SetTwoFactorResult(T context, Dictionary customResponse); - [Obsolete("Consider using SetValidationErrorResult instead.")] - protected abstract void SetSsoResult(T context, Dictionary customResponse); - [Obsolete("Consider using SetValidationErrorResult instead.")] protected abstract void SetErrorResult(T context, Dictionary customResponse); @@ -540,41 +494,6 @@ public abstract class BaseRequestValidator where T : class protected abstract ClaimsPrincipal GetSubject(T context); - /// - /// Check if the user is required to authenticate via SSO. If the user requires SSO, but they are - /// logging in using an API Key (client_credentials) then they are allowed to bypass the SSO requirement. - /// If the GrantType is authorization_code or client_credentials we know the user is trying to login - /// using the SSO flow so they are allowed to continue. - /// - /// user trying to login - /// magic string identifying the grant type requested - /// true if sso required; false if not required or already in process - [Obsolete( - "This method is deprecated and will be removed in future versions, PM-28281. Please use the SsoRequestValidator scheme instead.")] - private async Task RequireSsoLoginAsync(User user, string grantType) - { - if (grantType == "authorization_code" || grantType == "client_credentials") - { - // Already using SSO to authenticate, or logging-in via api key to skip SSO requirement - // allow to authenticate successfully - return false; - } - - // Check if user belongs to any organization with an active SSO policy - var ssoRequired = _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements) - ? (await PolicyRequirementQuery.GetAsync(user.Id)) - .SsoRequired - : await PolicyService.AnyPoliciesApplicableToUserAsync( - user.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed); - if (ssoRequired) - { - return true; - } - - // Default - SSO is not required - return false; - } - private async Task ResetFailedAuthDetailsAsync(User user) { // Early escape if db hit not necessary diff --git a/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs b/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs index 38a4813ecd..2412c52308 100644 --- a/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/CustomTokenRequestValidator.cs @@ -194,17 +194,6 @@ public class CustomTokenRequestValidator : BaseRequestValidator customResponse) - { - Debug.Assert(context.Result is not null); - context.Result.Error = "invalid_grant"; - context.Result.ErrorDescription = "Sso authentication required."; - context.Result.IsError = true; - context.Result.CustomResponse = customResponse; - } - [Obsolete("Consider using SetGrantValidationErrorResult instead.")] protected override void SetErrorResult(CustomTokenRequestValidationContext context, Dictionary customResponse) diff --git a/src/Identity/IdentityServer/RequestValidators/ResourceOwnerPasswordValidator.cs b/src/Identity/IdentityServer/RequestValidators/ResourceOwnerPasswordValidator.cs index ea2c021f63..8bfddf24f3 100644 --- a/src/Identity/IdentityServer/RequestValidators/ResourceOwnerPasswordValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/ResourceOwnerPasswordValidator.cs @@ -152,14 +152,6 @@ public class ResourceOwnerPasswordValidator : BaseRequestValidator customResponse) - { - context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, "Sso authentication required.", - customResponse); - } - [Obsolete("Consider using SetGrantValidationErrorResult instead.")] protected override void SetErrorResult(ResourceOwnerPasswordValidationContext context, Dictionary customResponse) diff --git a/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs b/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs index e4cd60827e..1563831b81 100644 --- a/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs +++ b/src/Identity/IdentityServer/RequestValidators/WebAuthnGrantValidator.cs @@ -142,14 +142,6 @@ public class WebAuthnGrantValidator : BaseRequestValidator customResponse) - { - context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant, "Sso authentication required.", - customResponse); - } - [Obsolete("Consider using SetValidationErrorResult instead.")] protected override void SetErrorResult(ExtensionGrantValidationContext context, Dictionary customResponse) { diff --git a/src/Identity/appsettings.Production.json b/src/Identity/appsettings.Production.json index 4897a7d8b1..14471b5fb6 100644 --- a/src/Identity/appsettings.Production.json +++ b/src/Identity/appsettings.Production.json @@ -20,11 +20,9 @@ } }, "Logging": { - "IncludeScopes": false, "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" + "Default": "Information", + "Microsoft.AspNetCore": "Warning" }, "Console": { "IncludeScopes": true, diff --git a/src/Infrastructure.Dapper/DapperHelpers.cs b/src/Infrastructure.Dapper/DapperHelpers.cs index 9a119e1e32..4384a6f752 100644 --- a/src/Infrastructure.Dapper/DapperHelpers.cs +++ b/src/Infrastructure.Dapper/DapperHelpers.cs @@ -160,6 +160,21 @@ public static class DapperHelpers return ids.ToArrayTVP("GuidId"); } + public static DataTable ToTwoGuidIdArrayTVP(this IEnumerable<(Guid id1, Guid id2)> values) + { + var table = new DataTable(); + table.SetTypeName("[dbo].[TwoGuidIdArray]"); + table.Columns.Add("Id1", typeof(Guid)); + table.Columns.Add("Id2", typeof(Guid)); + + foreach (var value in values) + { + table.Rows.Add(value.id1, value.id2); + } + + return table; + } + public static DataTable ToArrayTVP(this IEnumerable values, string columnName) { var table = new DataTable(); diff --git a/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs b/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs index 9985b41d56..1531703427 100644 --- a/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs @@ -1,6 +1,7 @@ using System.Data; using System.Diagnostics.CodeAnalysis; using System.Text.Json; +using Bit.Core.AdminConsole.OrganizationFeatures.Collections; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Data; @@ -360,7 +361,45 @@ public class CollectionRepository : Repository, ICollectionRep } } - public async Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName) + public async Task CreateDefaultCollectionsAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName) + { + organizationUserIds = organizationUserIds.ToList(); + if (!organizationUserIds.Any()) + { + return; + } + + var organizationUserCollectionIds = organizationUserIds + .Select(ou => (ou, CoreHelpers.GenerateComb())) + .ToTwoGuidIdArrayTVP(); + + await using var connection = new SqlConnection(ConnectionString); + await connection.OpenAsync(); + await using var transaction = connection.BeginTransaction(); + + try + { + await connection.ExecuteAsync( + "[dbo].[Collection_CreateDefaultCollections]", + new + { + OrganizationId = organizationId, + DefaultCollectionName = defaultCollectionName, + OrganizationUserCollectionIds = organizationUserCollectionIds + }, + commandType: CommandType.StoredProcedure, + transaction: transaction); + + await transaction.CommitAsync(); + } + catch + { + await transaction.RollbackAsync(); + throw; + } + } + + public async Task CreateDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName) { organizationUserIds = organizationUserIds.ToList(); if (!organizationUserIds.Any()) @@ -377,7 +416,8 @@ public class CollectionRepository : Repository, ICollectionRep var missingDefaultCollectionUserIds = organizationUserIds.Except(orgUserIdWithDefaultCollection); - var (collectionUsers, collections) = BuildDefaultCollectionForUsers(organizationId, missingDefaultCollectionUserIds, defaultCollectionName); + var (collections, collectionUsers) = + CollectionUtils.BuildDefaultUserCollections(organizationId, missingDefaultCollectionUserIds, defaultCollectionName); if (!collectionUsers.Any() || !collections.Any()) { @@ -387,11 +427,11 @@ public class CollectionRepository : Repository, ICollectionRep await BulkResourceCreationService.CreateCollectionsAsync(connection, transaction, collections); await BulkResourceCreationService.CreateCollectionsUsersAsync(connection, transaction, collectionUsers); - transaction.Commit(); + await transaction.CommitAsync(); } catch { - transaction.Rollback(); + await transaction.RollbackAsync(); throw; } } @@ -421,40 +461,6 @@ public class CollectionRepository : Repository, ICollectionRep return organizationUserIds.ToHashSet(); } - private (List collectionUser, List collection) BuildDefaultCollectionForUsers(Guid organizationId, IEnumerable missingDefaultCollectionUserIds, string defaultCollectionName) - { - var collectionUsers = new List(); - var collections = new List(); - - foreach (var orgUserId in missingDefaultCollectionUserIds) - { - var collectionId = CoreHelpers.GenerateComb(); - - collections.Add(new Collection - { - Id = collectionId, - OrganizationId = organizationId, - Name = defaultCollectionName, - CreationDate = DateTime.UtcNow, - RevisionDate = DateTime.UtcNow, - Type = CollectionType.DefaultUserCollection, - DefaultUserCollectionEmail = null - - }); - - collectionUsers.Add(new CollectionUser - { - CollectionId = collectionId, - OrganizationUserId = orgUserId, - ReadOnly = false, - HidePasswords = false, - Manage = true, - }); - } - - return (collectionUsers, collections); - } - public class CollectionWithGroupsAndUsers : Collection { public CollectionWithGroupsAndUsers() { } diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs index 88410facf5..93c8cd304c 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationRepository.cs @@ -325,7 +325,8 @@ public class OrganizationRepository : Repository od.OrganizationId == _organizationId && od.VerifiedDate != null && diff --git a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs index 5aa156d1f8..74150246b1 100644 --- a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs @@ -1,11 +1,10 @@ using AutoMapper; +using Bit.Core.AdminConsole.OrganizationFeatures.Collections; using Bit.Core.Enums; using Bit.Core.Models.Data; using Bit.Core.Repositories; -using Bit.Core.Utilities; using Bit.Infrastructure.EntityFramework.Models; using Bit.Infrastructure.EntityFramework.Repositories.Queries; -using LinqToDB.EntityFrameworkCore; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; @@ -794,7 +793,7 @@ public class CollectionRepository : Repository organizationUserIds, string defaultCollectionName) + public async Task CreateDefaultCollectionsAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName) { organizationUserIds = organizationUserIds.ToList(); if (!organizationUserIds.Any()) @@ -808,15 +807,15 @@ public class CollectionRepository : Repository>(collections)); + await dbContext.CollectionUsers.AddRangeAsync(Mapper.Map>(collectionUsers)); await dbContext.SaveChangesAsync(); } @@ -844,37 +843,7 @@ public class CollectionRepository : Repository collectionUser, List collection) BuildDefaultCollectionForUsers(Guid organizationId, IEnumerable missingDefaultCollectionUserIds, string defaultCollectionName) - { - var collectionUsers = new List(); - var collections = new List(); - - foreach (var orgUserId in missingDefaultCollectionUserIds) - { - var collectionId = CoreHelpers.GenerateComb(); - - collections.Add(new Collection - { - Id = collectionId, - OrganizationId = organizationId, - Name = defaultCollectionName, - CreationDate = DateTime.UtcNow, - RevisionDate = DateTime.UtcNow, - Type = CollectionType.DefaultUserCollection, - DefaultUserCollectionEmail = null - - }); - - collectionUsers.Add(new CollectionUser - { - CollectionId = collectionId, - OrganizationUserId = orgUserId, - ReadOnly = false, - HidePasswords = false, - Manage = true, - }); - } - - return (collectionUsers, collections); - } + public Task CreateDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable organizationUserIds, + string defaultCollectionName) => + CreateDefaultCollectionsAsync(organizationId, organizationUserIds, defaultCollectionName); } diff --git a/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs b/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs index 7b67a63912..a0ee0376c0 100644 --- a/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs +++ b/src/Infrastructure.EntityFramework/Repositories/DatabaseContext.cs @@ -17,8 +17,6 @@ using Microsoft.EntityFrameworkCore.Storage.ValueConversion; using DP = Microsoft.AspNetCore.DataProtection; -#nullable enable - namespace Bit.Infrastructure.EntityFramework.Repositories; public class DatabaseContext : DbContext diff --git a/src/Notifications/appsettings.Production.json b/src/Notifications/appsettings.Production.json index 010f02f8cd..735c70e481 100644 --- a/src/Notifications/appsettings.Production.json +++ b/src/Notifications/appsettings.Production.json @@ -17,11 +17,9 @@ } }, "Logging": { - "IncludeScopes": false, "LogLevel": { - "Default": "Debug", - "System": "Information", - "Microsoft": "Information" + "Default": "Information", + "Microsoft": "Warning" }, "Console": { "IncludeScopes": true, diff --git a/src/Sql/dbo/AdminConsole/Stored Procedures/Collection_CreateDefaultCollections.sql b/src/Sql/dbo/AdminConsole/Stored Procedures/Collection_CreateDefaultCollections.sql new file mode 100644 index 0000000000..4e671bd1e4 --- /dev/null +++ b/src/Sql/dbo/AdminConsole/Stored Procedures/Collection_CreateDefaultCollections.sql @@ -0,0 +1,69 @@ +-- Creates default user collections for organization users +-- Filters out existing default collections at database level +CREATE PROCEDURE [dbo].[Collection_CreateDefaultCollections] + @OrganizationId UNIQUEIDENTIFIER, + @DefaultCollectionName VARCHAR(MAX), + @OrganizationUserCollectionIds AS [dbo].[TwoGuidIdArray] READONLY -- OrganizationUserId, CollectionId +AS +BEGIN + SET NOCOUNT ON + + DECLARE @Now DATETIME2(7) = GETUTCDATE() + + -- Filter to only users who don't have default collections + SELECT ids.Id1, ids.Id2 + INTO #FilteredIds + FROM @OrganizationUserCollectionIds ids + WHERE NOT EXISTS ( + SELECT 1 + FROM [dbo].[CollectionUser] cu + INNER JOIN [dbo].[Collection] c ON c.Id = cu.CollectionId + WHERE c.OrganizationId = @OrganizationId + AND c.[Type] = 1 -- CollectionType.DefaultUserCollection + AND cu.OrganizationUserId = ids.Id1 + ); + + -- Insert collections only for users who don't have default collections yet + INSERT INTO [dbo].[Collection] + ( + [Id], + [OrganizationId], + [Name], + [CreationDate], + [RevisionDate], + [Type], + [ExternalId], + [DefaultUserCollectionEmail] + ) + SELECT + ids.Id2, -- CollectionId + @OrganizationId, + @DefaultCollectionName, + @Now, + @Now, + 1, -- CollectionType.DefaultUserCollection + NULL, + NULL + FROM + #FilteredIds ids; + + -- Insert collection user mappings + INSERT INTO [dbo].[CollectionUser] + ( + [CollectionId], + [OrganizationUserId], + [ReadOnly], + [HidePasswords], + [Manage] + ) + SELECT + ids.Id2, -- CollectionId + ids.Id1, -- OrganizationUserId + 0, -- ReadOnly = false + 0, -- HidePasswords = false + 1 -- Manage = true + FROM + #FilteredIds ids; + + DROP TABLE #FilteredIds; +END diff --git a/src/Sql/dbo/Stored Procedures/OrganizationUser_ReadByOrganizationIdWithClaimedDomains_V2.sql b/src/Sql/dbo/Stored Procedures/OrganizationUser_ReadByOrganizationIdWithClaimedDomains_V2.sql index 64f3d81e08..4f781d2cc9 100644 --- a/src/Sql/dbo/Stored Procedures/OrganizationUser_ReadByOrganizationIdWithClaimedDomains_V2.sql +++ b/src/Sql/dbo/Stored Procedures/OrganizationUser_ReadByOrganizationIdWithClaimedDomains_V2.sql @@ -8,13 +8,14 @@ BEGIN SELECT * FROM [dbo].[OrganizationUserView] WHERE [OrganizationId] = @OrganizationId + AND [Status] != 0 -- Exclude invited users ), UserDomains AS ( SELECT U.[Id], U.[EmailDomain] FROM [dbo].[UserEmailDomainView] U WHERE EXISTS ( SELECT 1 - FROM [dbo].[OrganizationDomainView] OD + FROM [dbo].[OrganizationDomainView] OD WHERE OD.[OrganizationId] = @OrganizationId AND OD.[VerifiedDate] IS NOT NULL AND OD.[DomainName] = U.[EmailDomain] diff --git a/src/Sql/dbo/Stored Procedures/Organization_ReadByClaimedUserEmailDomain.sql b/src/Sql/dbo/Stored Procedures/Organization_ReadByClaimedUserEmailDomain.sql index 583f548c8b..ee14c2c52a 100644 --- a/src/Sql/dbo/Stored Procedures/Organization_ReadByClaimedUserEmailDomain.sql +++ b/src/Sql/dbo/Stored Procedures/Organization_ReadByClaimedUserEmailDomain.sql @@ -6,7 +6,7 @@ BEGIN WITH CTE_User AS ( SELECT - U.*, + U.[Id], SUBSTRING(U.Email, CHARINDEX('@', U.Email) + 1, LEN(U.Email)) AS EmailDomain FROM dbo.[UserView] U WHERE U.[Id] = @UserId @@ -19,4 +19,5 @@ BEGIN WHERE OD.[VerifiedDate] IS NOT NULL AND CU.EmailDomain = OD.[DomainName] AND O.[Enabled] = 1 + AND OU.[Status] != 0 -- Exclude invited users END diff --git a/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs b/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs index 182f09e163..2259d846b7 100644 --- a/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs +++ b/test/Billing.Test/Services/SubscriptionUpdatedHandlerTests.cs @@ -11,6 +11,7 @@ using Bit.Core.Billing.Pricing; using Bit.Core.OrganizationFeatures.OrganizationSponsorships.FamiliesForEnterprise.Interfaces; using Bit.Core.Repositories; using Bit.Core.Services; +using Bit.Core.Test.Billing.Mocks; using Bit.Core.Test.Billing.Mocks.Plans; using Microsoft.Extensions.Logging; using Newtonsoft.Json.Linq; @@ -654,6 +655,8 @@ public class SubscriptionUpdatedHandlerTests var plan = new Enterprise2023Plan(true); _pricingClient.GetPlanOrThrow(organization.PlanType) .Returns(plan); + _pricingClient.ListPlans() + .Returns(MockPlans.Plans); var parsedEvent = new Event { @@ -693,6 +696,92 @@ public class SubscriptionUpdatedHandlerTests await _stripeFacade.Received(1).DeleteCustomerDiscount(subscription.CustomerId); await _stripeFacade.Received(1).DeleteSubscriptionDiscount(subscription.Id); } + [Fact] + public async Task + HandleAsync_WhenUpgradingPlan_AndPreviousPlanHasSecretsManagerTrial_AndCurrentPlanHasSecretsManagerTrial_DoesNotRemovePasswordManagerCoupon() + { + // Arrange + var organizationId = Guid.NewGuid(); + var subscription = new Subscription + { + Id = "sub_123", + Status = StripeSubscriptionStatus.Active, + CustomerId = "cus_123", + Items = new StripeList + { + Data = + [ + new SubscriptionItem + { + CurrentPeriodEnd = DateTime.UtcNow.AddDays(10), + Plan = new Plan { Id = "2023-enterprise-org-seat-annually" } + }, + new SubscriptionItem + { + CurrentPeriodEnd = DateTime.UtcNow.AddDays(10), + Plan = new Plan { Id = "secrets-manager-enterprise-seat-annually" } + } + ] + }, + Customer = new Customer + { + Balance = 0, + Discount = new Discount { Coupon = new Coupon { Id = "sm-standalone" } } + }, + Discounts = [new Discount { Coupon = new Coupon { Id = "sm-standalone" } }], + Metadata = new Dictionary { { "organizationId", organizationId.ToString() } } + }; + + // Note: The organization plan is still the previous plan because the subscription is updated before the organization is updated + var organization = new Organization { Id = organizationId, PlanType = PlanType.TeamsAnnually2023 }; + + var plan = new Teams2023Plan(true); + _pricingClient.GetPlanOrThrow(organization.PlanType) + .Returns(plan); + _pricingClient.ListPlans() + .Returns(MockPlans.Plans); + + var parsedEvent = new Event + { + Data = new EventData + { + Object = subscription, + PreviousAttributes = JObject.FromObject(new + { + items = new + { + data = new[] + { + new { plan = new { id = "secrets-manager-teams-seat-annually" } }, + } + }, + Items = new StripeList + { + Data = + [ + new SubscriptionItem { Plan = new Stripe.Plan { Id = "secrets-manager-teams-seat-annually" } }, + ] + } + }) + } + }; + + _stripeEventService.GetSubscription(Arg.Any(), Arg.Any(), Arg.Any>()) + .Returns(subscription); + + _stripeEventUtilityService.GetIdsFromMetadata(Arg.Any>()) + .Returns(Tuple.Create(organizationId, null, null)); + + _organizationRepository.GetByIdAsync(organizationId) + .Returns(organization); + + // Act + await _sut.HandleAsync(parsedEvent); + + // Assert + await _stripeFacade.DidNotReceive().DeleteCustomerDiscount(subscription.CustomerId); + await _stripeFacade.DidNotReceive().DeleteSubscriptionDiscount(subscription.Id); + } [Theory] [MemberData(nameof(GetNonActiveSubscriptions))] diff --git a/test/Billing.Test/Services/UpcomingInvoiceHandlerTests.cs b/test/Billing.Test/Services/UpcomingInvoiceHandlerTests.cs index 3b133c7d37..82d6c8acfd 100644 --- a/test/Billing.Test/Services/UpcomingInvoiceHandlerTests.cs +++ b/test/Billing.Test/Services/UpcomingInvoiceHandlerTests.cs @@ -280,7 +280,7 @@ public class UpcomingInvoiceHandlerTests email.ToEmails.Contains("user@example.com") && email.Subject == "Your Bitwarden Premium renewal is updating" && email.View.BaseMonthlyRenewalPrice == (plan.Seat.Price / 12).ToString("C", new CultureInfo("en-US")) && - email.View.DiscountedMonthlyRenewalPrice == (discountedPrice / 12).ToString("C", new CultureInfo("en-US")) && + email.View.DiscountedAnnualRenewalPrice == discountedPrice.ToString("C", new CultureInfo("en-US")) && email.View.DiscountAmount == $"{coupon.PercentOff}%" )); } @@ -2436,7 +2436,7 @@ public class UpcomingInvoiceHandlerTests email.Subject == "Your Bitwarden Premium renewal is updating" && email.View.BaseMonthlyRenewalPrice == (plan.Seat.Price / 12).ToString("C", new CultureInfo("en-US")) && email.View.DiscountAmount == "30%" && - email.View.DiscountedMonthlyRenewalPrice == (expectedDiscountedPrice / 12).ToString("C", new CultureInfo("en-US")) + email.View.DiscountedAnnualRenewalPrice == expectedDiscountedPrice.ToString("C", new CultureInfo("en-US")) )); await _mailService.DidNotReceive().SendInvoiceUpcoming( diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmUsersCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmUsersCommandTests.cs index 180750a9d0..252fb89c87 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmUsersCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/AutoConfirmUsers/AutomaticallyConfirmUsersCommandTests.cs @@ -10,7 +10,6 @@ using Bit.Core.AdminConsole.Utilities.v2; using Bit.Core.AdminConsole.Utilities.v2.Validation; using Bit.Core.Entities; using Bit.Core.Enums; -using Bit.Core.Models.Data; using Bit.Core.Platform.Push; using Bit.Core.Repositories; using Bit.Core.Services; @@ -204,14 +203,10 @@ public class AutomaticallyConfirmUsersCommandTests await sutProvider.GetDependency() .Received(1) - .CreateAsync( - Arg.Is(c => - c.OrganizationId == organization.Id && - c.Name == defaultCollectionName && - c.Type == CollectionType.DefaultUserCollection), - Arg.Is>(groups => groups == null), - Arg.Is>(access => - access.FirstOrDefault(x => x.Id == organizationUser.Id && x.Manage) != null)); + .CreateDefaultCollectionsAsync( + organization.Id, + Arg.Is>(ids => ids.Single() == organizationUser.Id), + defaultCollectionName); } [Theory] @@ -253,9 +248,7 @@ public class AutomaticallyConfirmUsersCommandTests await sutProvider.GetDependency() .DidNotReceive() - .CreateAsync(Arg.Any(), - Arg.Any>(), - Arg.Any>()); + .CreateDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } [Theory] @@ -291,9 +284,7 @@ public class AutomaticallyConfirmUsersCommandTests var collectionException = new Exception("Collection creation failed"); sutProvider.GetDependency() - .CreateAsync(Arg.Any(), - Arg.Any>(), - Arg.Any>()) + .CreateDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()) .ThrowsAsync(collectionException); // Act diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs index 65359b8304..6643f26eb5 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommandTests.cs @@ -13,7 +13,6 @@ using Bit.Core.Billing.Enums; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; -using Bit.Core.Models.Data; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Platform.Push; using Bit.Core.Repositories; @@ -493,15 +492,10 @@ public class ConfirmOrganizationUserCommandTests await sutProvider.GetDependency() .Received(1) - .CreateAsync( - Arg.Is(c => - c.Name == collectionName && - c.OrganizationId == organization.Id && - c.Type == CollectionType.DefaultUserCollection), - Arg.Any>(), - Arg.Is>(cu => - cu.Single().Id == orgUser.Id && - cu.Single().Manage)); + .CreateDefaultCollectionsAsync( + organization.Id, + Arg.Is>(ids => ids.Single() == orgUser.Id), + collectionName); } [Theory, BitAutoData] @@ -522,7 +516,7 @@ public class ConfirmOrganizationUserCommandTests await sutProvider.GetDependency() .DidNotReceive() - .UpsertDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + .CreateDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } [Theory, BitAutoData] @@ -539,24 +533,15 @@ public class ConfirmOrganizationUserCommandTests sutProvider.GetDependency().GetManyAsync(default).ReturnsForAnyArgs(new[] { orgUser }); sutProvider.GetDependency().GetManyAsync(default).ReturnsForAnyArgs(new[] { user }); - var policyDetails = new PolicyDetails - { - OrganizationId = org.Id, - OrganizationUserId = orgUser.Id, - IsProvider = false, - OrganizationUserStatus = orgUser.Status, - OrganizationUserType = orgUser.Type, - PolicyType = PolicyType.OrganizationDataOwnership - }; sutProvider.GetDependency() .GetAsync(orgUser.UserId!.Value) - .Returns(new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Disabled, [policyDetails])); + .Returns(new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Disabled, [])); await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, collectionName); await sutProvider.GetDependency() .DidNotReceive() - .UpsertDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + .CreateDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } [Theory, BitAutoData] diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/RestoreOrganizationUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/RestoreOrganizationUserCommandTests.cs index 4fa5e92abe..a75345a05d 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/RestoreOrganizationUserCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RestoreUser/RestoreOrganizationUserCommandTests.cs @@ -715,6 +715,39 @@ public class RestoreOrganizationUserCommandTests Arg.Is(x => x != OrganizationUserStatusType.Revoked)); } + [Theory, BitAutoData] + public async Task RestoreUser_InvitedUserInFreeOrganization_Success( + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser, + SutProvider sutProvider) + { + organization.PlanType = PlanType.Free; + organizationUser.UserId = null; + organizationUser.Key = null; + organizationUser.Status = OrganizationUserStatusType.Revoked; + + RestoreUser_Setup(organization, owner, organizationUser, sutProvider); + sutProvider.GetDependency() + .GetOccupiedSeatCountByOrganizationIdAsync(organization.Id).Returns(new OrganizationSeatCounts + { + Sponsored = 0, + Users = 1 + }); + + await sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id); + + await sutProvider.GetDependency() + .Received(1) + .RestoreAsync(organizationUser.Id, OrganizationUserStatusType.Invited); + await sutProvider.GetDependency() + .Received(1) + .LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Restored); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .PushSyncOrgKeysAsync(Arg.Any()); + } + [Theory, BitAutoData] public async Task RestoreUsers_Success(Organization organization, [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidatorTests.cs index 93cbde89ec..dd2f1d76e8 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidatorTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/OrganizationDataOwnershipPolicyValidatorTests.cs @@ -38,7 +38,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceive() - .UpsertDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + .CreateDefaultCollectionsBulkAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } [Theory, BitAutoData] @@ -60,7 +60,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceive() - .UpsertDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + .CreateDefaultCollectionsBulkAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } [Theory, BitAutoData] @@ -86,7 +86,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await collectionRepository .DidNotReceive() - .UpsertDefaultCollectionsAsync( + .CreateDefaultCollectionsBulkAsync( Arg.Any(), Arg.Any>(), Arg.Any()); @@ -172,10 +172,10 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Act await sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState); - // Assert + // Assert - Should call with all user IDs (repository does internal filtering) await collectionRepository .Received(1) - .UpsertDefaultCollectionsAsync( + .CreateDefaultCollectionsBulkAsync( policyUpdate.OrganizationId, Arg.Is>(ids => ids.Count() == 3), _defaultUserCollectionName); @@ -210,7 +210,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceive() - .UpsertDefaultCollectionsAsync(Arg.Any(), Arg.Any>(), Arg.Any()); + .CreateDefaultCollectionsBulkAsync(Arg.Any(), Arg.Any>(), Arg.Any()); } private static IPolicyRepository ArrangePolicyRepository(IEnumerable policyDetails) @@ -251,7 +251,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() - .UpsertDefaultCollectionsAsync(default, default, default); + .CreateDefaultCollectionsBulkAsync(default, default, default); } [Theory, BitAutoData] @@ -273,7 +273,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() - .UpsertDefaultCollectionsAsync(default, default, default); + .CreateDefaultCollectionsBulkAsync(default, default, default); } [Theory, BitAutoData] @@ -299,7 +299,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await collectionRepository .DidNotReceiveWithAnyArgs() - .UpsertDefaultCollectionsAsync( + .CreateDefaultCollectionsBulkAsync( default, default, default); @@ -336,10 +336,10 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Act await sut.ExecutePostUpsertSideEffectAsync(policyRequest, postUpdatedPolicy, previousPolicyState); - // Assert + // Assert - Should call with all user IDs (repository does internal filtering) await collectionRepository .Received(1) - .UpsertDefaultCollectionsAsync( + .CreateDefaultCollectionsBulkAsync( policyUpdate.OrganizationId, Arg.Is>(ids => ids.Count() == 3), _defaultUserCollectionName); @@ -367,6 +367,6 @@ public class OrganizationDataOwnershipPolicyValidatorTests // Assert await sutProvider.GetDependency() .DidNotReceiveWithAnyArgs() - .UpsertDefaultCollectionsAsync(default, default, default); + .CreateDefaultCollectionsBulkAsync(default, default, default); } } diff --git a/test/Core.Test/Billing/Premium/Commands/UpdatePremiumStorageCommandTests.cs b/test/Core.Test/Billing/Premium/Commands/UpdatePremiumStorageCommandTests.cs index 7b9b68c757..cd9b323f9d 100644 --- a/test/Core.Test/Billing/Premium/Commands/UpdatePremiumStorageCommandTests.cs +++ b/test/Core.Test/Billing/Premium/Commands/UpdatePremiumStorageCommandTests.cs @@ -1,6 +1,7 @@ using Bit.Core.Billing.Premium.Commands; using Bit.Core.Billing.Pricing; using Bit.Core.Billing.Services; +using Bit.Core.Billing.Subscriptions.Models; using Bit.Core.Entities; using Bit.Core.Services; using Bit.Test.Common.AutoFixture.Attributes; @@ -8,6 +9,7 @@ using Microsoft.Extensions.Logging; using NSubstitute; using Stripe; using Xunit; +using static Bit.Core.Billing.Constants.StripeConstants; using PremiumPlan = Bit.Core.Billing.Pricing.Premium.Plan; using PremiumPurchasable = Bit.Core.Billing.Pricing.Premium.Purchasable; @@ -15,6 +17,7 @@ namespace Bit.Core.Test.Billing.Premium.Commands; public class UpdatePremiumStorageCommandTests { + private readonly IBraintreeService _braintreeService = Substitute.For(); private readonly IStripeAdapter _stripeAdapter = Substitute.For(); private readonly IUserService _userService = Substitute.For(); private readonly IPricingClient _pricingClient = Substitute.For(); @@ -33,13 +36,14 @@ public class UpdatePremiumStorageCommandTests _pricingClient.ListPremiumPlans().Returns([premiumPlan]); _command = new UpdatePremiumStorageCommand( + _braintreeService, _stripeAdapter, _userService, _pricingClient, Substitute.For>()); } - private Subscription CreateMockSubscription(string subscriptionId, int? storageQuantity = null) + private Subscription CreateMockSubscription(string subscriptionId, int? storageQuantity = null, bool isPayPal = false) { var items = new List { @@ -63,9 +67,17 @@ public class UpdatePremiumStorageCommandTests }); } + var customer = new Customer + { + Id = "cus_123", + Metadata = isPayPal ? new Dictionary { { MetadataKeys.BraintreeCustomerId, "braintree_123" } } : new Dictionary() + }; + return new Subscription { Id = subscriptionId, + CustomerId = "cus_123", + Customer = customer, Items = new StripeList { Data = items @@ -97,7 +109,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123", 4); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, -5); @@ -117,7 +129,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123", 4); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, 100); @@ -154,7 +166,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123", 9); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, 0); @@ -176,7 +188,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123", 4); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, 4); @@ -185,7 +197,7 @@ public class UpdatePremiumStorageCommandTests Assert.True(result.IsT0); // Verify subscription was fetched but NOT updated - await _stripeAdapter.Received(1).GetSubscriptionAsync("sub_123"); + await _stripeAdapter.Received(1).GetSubscriptionAsync("sub_123", Arg.Any()); await _stripeAdapter.DidNotReceive().UpdateSubscriptionAsync(Arg.Any(), Arg.Any()); await _userService.DidNotReceive().SaveUserAsync(Arg.Any()); } @@ -200,7 +212,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123", 4); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, 9); @@ -233,7 +245,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123"); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, 9); @@ -262,7 +274,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123", 9); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, 2); @@ -291,7 +303,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123", 9); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, 0); @@ -320,7 +332,7 @@ public class UpdatePremiumStorageCommandTests user.GatewaySubscriptionId = "sub_123"; var subscription = CreateMockSubscription("sub_123", 4); - _stripeAdapter.GetSubscriptionAsync("sub_123").Returns(subscription); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); // Act var result = await _command.Run(user, 99); @@ -335,4 +347,200 @@ public class UpdatePremiumStorageCommandTests await _userService.Received(1).SaveUserAsync(Arg.Is(u => u.MaxStorageGb == 100)); } + + [Theory, BitAutoData] + public async Task Run_IncreaseStorage_PayPal_Success(User user) + { + // Arrange + user.Premium = true; + user.MaxStorageGb = 5; + user.Storage = 2L * 1024 * 1024 * 1024; + user.GatewaySubscriptionId = "sub_123"; + + var subscription = CreateMockSubscription("sub_123", 4, isPayPal: true); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); + + var draftInvoice = new Invoice { Id = "in_draft" }; + _stripeAdapter.CreateInvoiceAsync(Arg.Any()).Returns(draftInvoice); + + var finalizedInvoice = new Invoice + { + Id = "in_finalized", + Customer = new Customer { Id = "cus_123" } + }; + _stripeAdapter.FinalizeInvoiceAsync("in_draft", Arg.Any()).Returns(finalizedInvoice); + + // Act + var result = await _command.Run(user, 9); + + // Assert + Assert.True(result.IsT0); + + // Verify subscription was updated with CreateProrations + await _stripeAdapter.Received(1).UpdateSubscriptionAsync( + "sub_123", + Arg.Is(opts => + opts.Items.Count == 1 && + opts.Items[0].Id == "si_storage" && + opts.Items[0].Quantity == 9 && + opts.ProrationBehavior == "create_prorations")); + + // Verify draft invoice was created + await _stripeAdapter.Received(1).CreateInvoiceAsync( + Arg.Is(opts => + opts.Customer == "cus_123" && + opts.Subscription == "sub_123" && + opts.AutoAdvance == false && + opts.CollectionMethod == "charge_automatically")); + + // Verify invoice was finalized + await _stripeAdapter.Received(1).FinalizeInvoiceAsync( + "in_draft", + Arg.Is(opts => + opts.AutoAdvance == false && + opts.Expand.Contains("customer"))); + + // Verify Braintree payment was processed + await _braintreeService.Received(1).PayInvoice(Arg.Any(), finalizedInvoice); + + // Verify user was saved + await _userService.Received(1).SaveUserAsync(Arg.Is(u => + u.Id == user.Id && + u.MaxStorageGb == 10)); + } + + [Theory, BitAutoData] + public async Task Run_AddStorageFromZero_PayPal_Success(User user) + { + // Arrange + user.Premium = true; + user.MaxStorageGb = 1; + user.Storage = 500L * 1024 * 1024; + user.GatewaySubscriptionId = "sub_123"; + + var subscription = CreateMockSubscription("sub_123", isPayPal: true); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); + + var draftInvoice = new Invoice { Id = "in_draft" }; + _stripeAdapter.CreateInvoiceAsync(Arg.Any()).Returns(draftInvoice); + + var finalizedInvoice = new Invoice + { + Id = "in_finalized", + Customer = new Customer { Id = "cus_123" } + }; + _stripeAdapter.FinalizeInvoiceAsync("in_draft", Arg.Any()).Returns(finalizedInvoice); + + // Act + var result = await _command.Run(user, 9); + + // Assert + Assert.True(result.IsT0); + + // Verify subscription was updated with new storage item + await _stripeAdapter.Received(1).UpdateSubscriptionAsync( + "sub_123", + Arg.Is(opts => + opts.Items.Count == 1 && + opts.Items[0].Price == "price_storage" && + opts.Items[0].Quantity == 9 && + opts.ProrationBehavior == "create_prorations")); + + // Verify invoice creation and payment flow + await _stripeAdapter.Received(1).CreateInvoiceAsync(Arg.Any()); + await _stripeAdapter.Received(1).FinalizeInvoiceAsync("in_draft", Arg.Any()); + await _braintreeService.Received(1).PayInvoice(Arg.Any(), finalizedInvoice); + + await _userService.Received(1).SaveUserAsync(Arg.Is(u => u.MaxStorageGb == 10)); + } + + [Theory, BitAutoData] + public async Task Run_DecreaseStorage_PayPal_Success(User user) + { + // Arrange + user.Premium = true; + user.MaxStorageGb = 10; + user.Storage = 2L * 1024 * 1024 * 1024; + user.GatewaySubscriptionId = "sub_123"; + + var subscription = CreateMockSubscription("sub_123", 9, isPayPal: true); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); + + var draftInvoice = new Invoice { Id = "in_draft" }; + _stripeAdapter.CreateInvoiceAsync(Arg.Any()).Returns(draftInvoice); + + var finalizedInvoice = new Invoice + { + Id = "in_finalized", + Customer = new Customer { Id = "cus_123" } + }; + _stripeAdapter.FinalizeInvoiceAsync("in_draft", Arg.Any()).Returns(finalizedInvoice); + + // Act + var result = await _command.Run(user, 2); + + // Assert + Assert.True(result.IsT0); + + // Verify subscription was updated + await _stripeAdapter.Received(1).UpdateSubscriptionAsync( + "sub_123", + Arg.Is(opts => + opts.Items.Count == 1 && + opts.Items[0].Id == "si_storage" && + opts.Items[0].Quantity == 2 && + opts.ProrationBehavior == "create_prorations")); + + // Verify invoice creation and payment flow + await _stripeAdapter.Received(1).CreateInvoiceAsync(Arg.Any()); + await _stripeAdapter.Received(1).FinalizeInvoiceAsync("in_draft", Arg.Any()); + await _braintreeService.Received(1).PayInvoice(Arg.Any(), finalizedInvoice); + + await _userService.Received(1).SaveUserAsync(Arg.Is(u => u.MaxStorageGb == 3)); + } + + [Theory, BitAutoData] + public async Task Run_RemoveAllAdditionalStorage_PayPal_Success(User user) + { + // Arrange + user.Premium = true; + user.MaxStorageGb = 10; + user.Storage = 500L * 1024 * 1024; + user.GatewaySubscriptionId = "sub_123"; + + var subscription = CreateMockSubscription("sub_123", 9, isPayPal: true); + _stripeAdapter.GetSubscriptionAsync("sub_123", Arg.Any()).Returns(subscription); + + var draftInvoice = new Invoice { Id = "in_draft" }; + _stripeAdapter.CreateInvoiceAsync(Arg.Any()).Returns(draftInvoice); + + var finalizedInvoice = new Invoice + { + Id = "in_finalized", + Customer = new Customer { Id = "cus_123" } + }; + _stripeAdapter.FinalizeInvoiceAsync("in_draft", Arg.Any()).Returns(finalizedInvoice); + + // Act + var result = await _command.Run(user, 0); + + // Assert + Assert.True(result.IsT0); + + // Verify subscription item was deleted + await _stripeAdapter.Received(1).UpdateSubscriptionAsync( + "sub_123", + Arg.Is(opts => + opts.Items.Count == 1 && + opts.Items[0].Id == "si_storage" && + opts.Items[0].Deleted == true && + opts.ProrationBehavior == "create_prorations")); + + // Verify invoice creation and payment flow + await _stripeAdapter.Received(1).CreateInvoiceAsync(Arg.Any()); + await _stripeAdapter.Received(1).FinalizeInvoiceAsync("in_draft", Arg.Any()); + await _braintreeService.Received(1).PayInvoice(Arg.Any(), finalizedInvoice); + + await _userService.Received(1).SaveUserAsync(Arg.Is(u => u.MaxStorageGb == 1)); + } } diff --git a/test/Core.Test/Tools/Services/SendValidationServiceTests.cs b/test/Core.Test/Tools/Services/SendValidationServiceTests.cs new file mode 100644 index 0000000000..8adce1a29f --- /dev/null +++ b/test/Core.Test/Tools/Services/SendValidationServiceTests.cs @@ -0,0 +1,120 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Pricing.Premium; +using Bit.Core.Entities; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Core.Tools.Entities; +using Bit.Core.Tools.Enums; +using Bit.Core.Tools.Services; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.Tools.Services; + +[SutProviderCustomize] +public class SendValidationServiceTests +{ + [Theory, BitAutoData] + public async Task StorageRemainingForSendAsync_OrgGrantedPremiumUser_UsesPricingService( + SutProvider sutProvider, + Send send, + User user) + { + // Arrange + send.UserId = user.Id; + send.OrganizationId = null; + send.Type = SendType.File; + user.Premium = false; + user.Storage = 1024L * 1024L * 1024L; // 1 GB used + user.EmailVerified = true; + + sutProvider.GetDependency().SelfHosted = false; + sutProvider.GetDependency().GetByIdAsync(user.Id).Returns(user); + sutProvider.GetDependency().CanAccessPremium(user).Returns(true); + + var premiumPlan = new Plan + { + Storage = new Purchasable { Provided = 5 } + }; + sutProvider.GetDependency().GetAvailablePremiumPlan().Returns(premiumPlan); + + // Act + var result = await sutProvider.Sut.StorageRemainingForSendAsync(send); + + // Assert + await sutProvider.GetDependency().Received(1).GetAvailablePremiumPlan(); + Assert.True(result > 0); + } + + [Theory, BitAutoData] + public async Task StorageRemainingForSendAsync_IndividualPremium_DoesNotCallPricingService( + SutProvider sutProvider, + Send send, + User user) + { + // Arrange + send.UserId = user.Id; + send.OrganizationId = null; + send.Type = SendType.File; + user.Premium = true; + user.MaxStorageGb = 10; + user.EmailVerified = true; + + sutProvider.GetDependency().GetByIdAsync(user.Id).Returns(user); + sutProvider.GetDependency().CanAccessPremium(user).Returns(true); + + // Act + var result = await sutProvider.Sut.StorageRemainingForSendAsync(send); + + // Assert - should NOT call pricing service for individual premium users + await sutProvider.GetDependency().DidNotReceive().GetAvailablePremiumPlan(); + } + + [Theory, BitAutoData] + public async Task StorageRemainingForSendAsync_SelfHosted_DoesNotCallPricingService( + SutProvider sutProvider, + Send send, + User user) + { + // Arrange + send.UserId = user.Id; + send.OrganizationId = null; + send.Type = SendType.File; + user.Premium = false; + user.EmailVerified = true; + + sutProvider.GetDependency().SelfHosted = true; + sutProvider.GetDependency().GetByIdAsync(user.Id).Returns(user); + sutProvider.GetDependency().CanAccessPremium(user).Returns(true); + + // Act + var result = await sutProvider.Sut.StorageRemainingForSendAsync(send); + + // Assert - should NOT call pricing service for self-hosted + await sutProvider.GetDependency().DidNotReceive().GetAvailablePremiumPlan(); + } + + [Theory, BitAutoData] + public async Task StorageRemainingForSendAsync_OrgSend_DoesNotCallPricingService( + SutProvider sutProvider, + Send send, + Organization org) + { + // Arrange + send.UserId = null; + send.OrganizationId = org.Id; + send.Type = SendType.File; + org.MaxStorageGb = 100; + + sutProvider.GetDependency().GetByIdAsync(org.Id).Returns(org); + + // Act + var result = await sutProvider.Sut.StorageRemainingForSendAsync(send); + + // Assert - should NOT call pricing service for org sends + await sutProvider.GetDependency().DidNotReceive().GetAvailablePremiumPlan(); + } +} diff --git a/test/Core.Test/Utilities/LoggerFactoryExtensionsTests.cs b/test/Core.Test/Utilities/LoggerFactoryExtensionsTests.cs index 81311cb802..ffeb3fa2e7 100644 --- a/test/Core.Test/Utilities/LoggerFactoryExtensionsTests.cs +++ b/test/Core.Test/Utilities/LoggerFactoryExtensionsTests.cs @@ -74,8 +74,7 @@ public class LoggerFactoryExtensionsTests logger.LogWarning("This is a test"); - // Writing to the file is buffered, give it a little time to flush - await Task.Delay(5); + await provider.DisposeAsync(); var logFile = Assert.Single(tempDir.EnumerateFiles("Logs/*.log")); @@ -90,13 +89,67 @@ public class LoggerFactoryExtensionsTests logFileContents ); } + + [Fact] + public async Task AddSerilogFileLogging_LegacyConfig_WithLevelCustomization_InfoLogs_DoNotFillUpFile() + { + await AssertSmallFileAsync((tempDir, config) => + { + config["GlobalSettings:LogDirectory"] = tempDir; + config["Logging:LogLevel:Microsoft.AspNetCore"] = "Warning"; + }); + } + + [Fact] + public async Task AddSerilogFileLogging_NewConfig_WithLevelCustomization_InfoLogs_DoNotFillUpFile() + { + await AssertSmallFileAsync((tempDir, config) => + { + config["Logging:PathFormat"] = Path.Combine(tempDir, "log.txt"); + config["Logging:LogLevel:Microsoft.AspNetCore"] = "Warning"; + }); + } + + private static async Task AssertSmallFileAsync(Action> configure) + { + using var tempDir = new TempDirectory(); + var config = new Dictionary(); + + configure(tempDir.Directory, config); + + var provider = GetServiceProvider(config, "Production"); + + var loggerFactory = provider.GetRequiredService(); + var microsoftLogger = loggerFactory.CreateLogger("Microsoft.AspNetCore.Testing"); + + for (var i = 0; i < 100; i++) + { + microsoftLogger.LogInformation("Tons of useless information"); + } + + var otherLogger = loggerFactory.CreateLogger("Bitwarden"); + + for (var i = 0; i < 5; i++) + { + otherLogger.LogInformation("Mildly more useful information but not as frequent."); + } + + await provider.DisposeAsync(); + + var logFiles = Directory.EnumerateFiles(tempDir.Directory, "*.txt", SearchOption.AllDirectories); + var logFile = Assert.Single(logFiles); + + using var fr = File.OpenRead(logFile); + Assert.InRange(fr.Length, 0, 1024); + } + private static IEnumerable GetProviders(Dictionary initialData, string environment = "Production") { var provider = GetServiceProvider(initialData, environment); return provider.GetServices(); } - private static IServiceProvider GetServiceProvider(Dictionary initialData, string environment) + private static ServiceProvider GetServiceProvider(Dictionary initialData, string environment) { var config = new ConfigurationBuilder() .AddInMemoryCollection(initialData) diff --git a/test/Core.Test/Vault/Services/CipherServiceTests.cs b/test/Core.Test/Vault/Services/CipherServiceTests.cs index 058c6f68ab..5fc92a9d39 100644 --- a/test/Core.Test/Vault/Services/CipherServiceTests.cs +++ b/test/Core.Test/Vault/Services/CipherServiceTests.cs @@ -6,6 +6,8 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements; using Bit.Core.AdminConsole.Services; using Bit.Core.Billing.Enums; +using Bit.Core.Billing.Pricing; +using Bit.Core.Billing.Pricing.Premium; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -2228,10 +2230,6 @@ public class CipherServiceTests .PushSyncCiphersAsync(deletingUserId); } - - - - [Theory] [OrganizationCipherCustomize] [BitAutoData] @@ -2387,6 +2385,186 @@ public class CipherServiceTests ids.Count() == cipherIds.Length && ids.All(id => cipherIds.Contains(id)))); } + [Theory, BitAutoData] + public async Task CreateAttachmentAsync_UserWithOrgGrantedPremium_UsesStorageFromPricingClient( + SutProvider sutProvider, CipherDetails cipher, Guid savingUserId) + { + var stream = new MemoryStream(new byte[100]); + var fileName = "test.txt"; + var key = "test-key"; + + // Setup cipher with user ownership + cipher.UserId = savingUserId; + cipher.OrganizationId = null; + + // Setup user WITHOUT personal premium (Premium = false), but with org-granted premium access + var user = new User + { + Id = savingUserId, + Premium = false, // User does not have personal premium + MaxStorageGb = null, // No personal storage allocation + Storage = 0 // No storage used yet + }; + + sutProvider.GetDependency() + .GetByIdAsync(savingUserId) + .Returns(user); + + // User has premium access through their organization + sutProvider.GetDependency() + .CanAccessPremium(user) + .Returns(true); + + // Mock GlobalSettings to indicate cloud (not self-hosted) + sutProvider.GetDependency().SelfHosted = false; + + // Mock the PricingClient to return a premium plan with 1 GB of storage + var premiumPlan = new Plan + { + Name = "Premium", + Available = true, + Seat = new Purchasable { StripePriceId = "price_123", Price = 10, Provided = 1 }, + Storage = new Purchasable { StripePriceId = "price_456", Price = 4, Provided = 1 } + }; + sutProvider.GetDependency() + .GetAvailablePremiumPlan() + .Returns(premiumPlan); + + sutProvider.GetDependency() + .UploadNewAttachmentAsync(Arg.Any(), cipher, Arg.Any()) + .Returns(Task.CompletedTask); + + sutProvider.GetDependency() + .ValidateFileAsync(cipher, Arg.Any(), Arg.Any()) + .Returns((true, 100L)); + + sutProvider.GetDependency() + .UpdateAttachmentAsync(Arg.Any()) + .Returns(Task.CompletedTask); + + sutProvider.GetDependency() + .ReplaceAsync(Arg.Any()) + .Returns(Task.CompletedTask); + + // Act + await sutProvider.Sut.CreateAttachmentAsync(cipher, stream, fileName, key, 100, savingUserId, false, cipher.RevisionDate); + + // Assert - PricingClient was called to get the premium plan storage + await sutProvider.GetDependency().Received(1).GetAvailablePremiumPlan(); + + // Assert - Attachment was uploaded successfully + await sutProvider.GetDependency().Received(1) + .UploadNewAttachmentAsync(Arg.Any(), cipher, Arg.Any()); + } + + [Theory, BitAutoData] + public async Task CreateAttachmentAsync_UserWithOrgGrantedPremium_ExceedsStorage_ThrowsBadRequest( + SutProvider sutProvider, CipherDetails cipher, Guid savingUserId) + { + var stream = new MemoryStream(new byte[100]); + var fileName = "test.txt"; + var key = "test-key"; + + // Setup cipher with user ownership + cipher.UserId = savingUserId; + cipher.OrganizationId = null; + + // Setup user WITHOUT personal premium, with org-granted access, but storage is full + var user = new User + { + Id = savingUserId, + Premium = false, + MaxStorageGb = null, + Storage = 1073741824 // 1 GB already used (equals the provided storage) + }; + + sutProvider.GetDependency() + .GetByIdAsync(savingUserId) + .Returns(user); + + sutProvider.GetDependency() + .CanAccessPremium(user) + .Returns(true); + + sutProvider.GetDependency().SelfHosted = false; + + // Premium plan provides 1 GB of storage + var premiumPlan = new Plan + { + Name = "Premium", + Available = true, + Seat = new Purchasable { StripePriceId = "price_123", Price = 10, Provided = 1 }, + Storage = new Purchasable { StripePriceId = "price_456", Price = 4, Provided = 1 } + }; + sutProvider.GetDependency() + .GetAvailablePremiumPlan() + .Returns(premiumPlan); + + // Act & Assert - Should throw because storage is full + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.CreateAttachmentAsync(cipher, stream, fileName, key, 100, savingUserId, false, cipher.RevisionDate)); + Assert.Contains("Not enough storage available", exception.Message); + } + + [Theory, BitAutoData] + public async Task CreateAttachmentAsync_UserWithOrgGrantedPremium_SelfHosted_UsesConstantStorage( + SutProvider sutProvider, CipherDetails cipher, Guid savingUserId) + { + var stream = new MemoryStream(new byte[100]); + var fileName = "test.txt"; + var key = "test-key"; + + // Setup cipher with user ownership + cipher.UserId = savingUserId; + cipher.OrganizationId = null; + + // Setup user WITHOUT personal premium, but with org-granted premium access + var user = new User + { + Id = savingUserId, + Premium = false, + MaxStorageGb = null, + Storage = 0 + }; + + sutProvider.GetDependency() + .GetByIdAsync(savingUserId) + .Returns(user); + + sutProvider.GetDependency() + .CanAccessPremium(user) + .Returns(true); + + // Mock GlobalSettings to indicate self-hosted + sutProvider.GetDependency().SelfHosted = true; + + sutProvider.GetDependency() + .UploadNewAttachmentAsync(Arg.Any(), cipher, Arg.Any()) + .Returns(Task.CompletedTask); + + sutProvider.GetDependency() + .ValidateFileAsync(cipher, Arg.Any(), Arg.Any()) + .Returns((true, 100L)); + + sutProvider.GetDependency() + .UpdateAttachmentAsync(Arg.Any()) + .Returns(Task.CompletedTask); + + sutProvider.GetDependency() + .ReplaceAsync(Arg.Any()) + .Returns(Task.CompletedTask); + + // Act + await sutProvider.Sut.CreateAttachmentAsync(cipher, stream, fileName, key, 100, savingUserId, false, cipher.RevisionDate); + + // Assert - PricingClient should NOT be called for self-hosted + await sutProvider.GetDependency().DidNotReceive().GetAvailablePremiumPlan(); + + // Assert - Attachment was uploaded successfully + await sutProvider.GetDependency().Received(1) + .UploadNewAttachmentAsync(Arg.Any(), cipher, Arg.Any()); + } + private async Task AssertNoActionsAsync(SutProvider sutProvider) { await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetManyOrganizationDetailsByOrganizationIdAsync(default); diff --git a/test/Identity.Test/IdentityServer/BaseRequestValidatorTests.cs b/test/Identity.Test/IdentityServer/BaseRequestValidatorTests.cs index 677382b138..4b6f681096 100644 --- a/test/Identity.Test/IdentityServer/BaseRequestValidatorTests.cs +++ b/test/Identity.Test/IdentityServer/BaseRequestValidatorTests.cs @@ -18,6 +18,7 @@ using Bit.Core.Repositories; using Bit.Core.Services; using Bit.Core.Settings; using Bit.Identity.IdentityServer; +using Bit.Identity.IdentityServer.RequestValidationConstants; using Bit.Identity.IdentityServer.RequestValidators; using Bit.Identity.Test.Wrappers; using Bit.Test.Common.AutoFixture.Attributes; @@ -130,7 +131,7 @@ public class BaseRequestValidatorTests var logs = _logger.Collector.GetSnapshot(true); Assert.Contains(logs, l => l.Level == LogLevel.Warning && l.Message == "Failed login attempt. Is2FARequest: False IpAddress: "); - var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; + var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel]; Assert.Equal("Username or password is incorrect. Try again.", errorResponse.Message); } @@ -161,7 +162,11 @@ public class BaseRequestValidatorTests .ValidateRequestDeviceAsync(tokenRequest, requestContext) .Returns(Task.FromResult(false)); - // 5 -> not legacy user + // 5 -> SSO not required + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); + + // 6 -> not legacy user _userService.IsLegacyUser(Arg.Any()) .Returns(false); @@ -203,6 +208,11 @@ public class BaseRequestValidatorTests _userService.IsLegacyUser(Arg.Any()) .Returns(false); + // 6 -> SSO validation passes + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); + + // 7 -> setup user account keys _userAccountKeysQuery.Run(Arg.Any()).Returns(new UserAccountKeysData { PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData( @@ -262,6 +272,11 @@ public class BaseRequestValidatorTests _userService.IsLegacyUser(Arg.Any()) .Returns(false); + // 6 -> SSO validation passes + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); + + // 7 -> setup user account keys _userAccountKeysQuery.Run(Arg.Any()).Returns(new UserAccountKeysData { PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData( @@ -326,6 +341,9 @@ public class BaseRequestValidatorTests { "TwoFactorProviders2", new Dictionary { { "Email", null } } } })); + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); + // Act await _sut.ValidateAsync(context); @@ -368,6 +386,10 @@ public class BaseRequestValidatorTests .VerifyTwoFactorAsync(user, null, TwoFactorProviderType.Email, "invalid_token") .Returns(Task.FromResult(false)); + // 5 -> set up SSO required verification to succeed + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); + // Act await _sut.ValidateAsync(context); @@ -396,21 +418,25 @@ public class BaseRequestValidatorTests // 1 -> initial validation passes _sut.isValid = true; - // 2 -> set up 2FA as required + // 2 -> set up SSO required verification to succeed + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); + + // 3 -> set up 2FA as required _twoFactorAuthenticationValidator .RequiresTwoFactorAsync(Arg.Any(), tokenRequest) .Returns(Task.FromResult(new Tuple(true, null))); - // 3 -> provide invalid remember token (remember token expired) + // 4 -> provide invalid remember token (remember token expired) tokenRequest.Raw["TwoFactorToken"] = "expired_remember_token"; tokenRequest.Raw["TwoFactorProvider"] = "5"; // Remember provider - // 4 -> set up remember token verification to fail + // 5 -> set up remember token verification to fail _twoFactorAuthenticationValidator .VerifyTwoFactorAsync(user, null, TwoFactorProviderType.Remember, "expired_remember_token") .Returns(Task.FromResult(false)); - // 5 -> set up dummy BuildTwoFactorResultAsync + // 6 -> set up dummy BuildTwoFactorResultAsync var twoFactorResultDict = new Dictionary { { "TwoFactorProviders", new[] { "0", "1" } }, @@ -446,6 +472,19 @@ public class BaseRequestValidatorTests GrantValidationResult grantResult) { // Arrange + + // SsoRequestValidator sets custom response + requestContext.ValidationErrorResult = new ValidationResult + { + IsError = true, + Error = SsoConstants.RequestErrors.SsoRequired, + ErrorDescription = SsoConstants.RequestErrors.SsoRequiredDescription + }; + requestContext.CustomResponse = new Dictionary + { + { CustomResponseConstants.ResponseKeys.ErrorModel, new ErrorResponseModel(SsoConstants.RequestErrors.SsoRequiredDescription) }, + }; + var context = CreateContext(tokenRequest, requestContext, grantResult); _sut.isValid = true; @@ -454,13 +493,17 @@ public class BaseRequestValidatorTests Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) .Returns(Task.FromResult(true)); + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(false)); + // Act await _sut.ValidateAsync(context); // Assert Assert.True(context.GrantResult.IsError); - var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; - Assert.Equal("SSO authentication is required.", errorResponse.Message); + Assert.NotNull(context.GrantResult.CustomResponse); + var errorResponse = (ErrorResponseModel)context.CustomValidatorRequestContext.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel]; + Assert.Equal(SsoConstants.RequestErrors.SsoRequiredDescription, errorResponse.Message); } // Test grantTypes with RequireSsoPolicyRequirement when feature flag is enabled @@ -477,6 +520,20 @@ public class BaseRequestValidatorTests { // Arrange _featureService.IsEnabled(FeatureFlagKeys.PolicyRequirements).Returns(true); + + // SsoRequestValidator sets custom response with organization identifier + requestContext.ValidationErrorResult = new ValidationResult + { + IsError = true, + Error = SsoConstants.RequestErrors.SsoRequired, + ErrorDescription = SsoConstants.RequestErrors.SsoRequiredDescription + }; + requestContext.CustomResponse = new Dictionary + { + { CustomResponseConstants.ResponseKeys.ErrorModel, new ErrorResponseModel(SsoConstants.RequestErrors.SsoRequiredDescription) }, + { CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier, "test-org-identifier" } + }; + var context = CreateContext(tokenRequest, requestContext, grantResult); _sut.isValid = true; @@ -485,6 +542,10 @@ public class BaseRequestValidatorTests var requirement = new RequireSsoPolicyRequirement { SsoRequired = true }; _policyRequirementQuery.GetAsync(Arg.Any()).Returns(requirement); + // Mock the SSO validator to return false + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(false)); + // Act await _sut.ValidateAsync(context); @@ -492,8 +553,9 @@ public class BaseRequestValidatorTests await _policyService.DidNotReceive().AnyPoliciesApplicableToUserAsync( Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed); Assert.True(context.GrantResult.IsError); - var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; - Assert.Equal("SSO authentication is required.", errorResponse.Message); + Assert.NotNull(context.GrantResult.CustomResponse); + var errorResponse = (ErrorResponseModel)context.CustomValidatorRequestContext.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel]; + Assert.Equal(SsoConstants.RequestErrors.SsoRequiredDescription, errorResponse.Message); } [Theory] @@ -519,6 +581,10 @@ public class BaseRequestValidatorTests var requirement = new RequireSsoPolicyRequirement { SsoRequired = false }; _policyRequirementQuery.GetAsync(Arg.Any()).Returns(requirement); + // SSO validation passes + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); + _twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest) .Returns(Task.FromResult(new Tuple(false, null))); _deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext) @@ -561,6 +627,11 @@ public class BaseRequestValidatorTests _policyService.AnyPoliciesApplicableToUserAsync( Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) .Returns(Task.FromResult(false)); + + // SSO validation passes + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); + _twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest) .Returns(Task.FromResult(new Tuple(false, null))); _deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext) @@ -603,6 +674,10 @@ public class BaseRequestValidatorTests context.ValidatedTokenRequest.GrantType = grantType; + // SSO validation passes + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); + _twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest) .Returns(Task.FromResult(new Tuple(false, null))); _deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext) @@ -652,13 +727,15 @@ public class BaseRequestValidatorTests .Returns(Task.FromResult(new Tuple(false, null))); _deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext) .Returns(Task.FromResult(true)); + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); // Act await _sut.ValidateAsync(context); // Assert Assert.True(context.GrantResult.IsError); - var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; + var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel]; var expectedMessage = "Legacy encryption without a userkey is no longer supported. To recover your account, please contact support"; Assert.Equal(expectedMessage, errorResponse.Message); @@ -694,6 +771,10 @@ public class BaseRequestValidatorTests var context = CreateContext(tokenRequest, requestContext, grantResult); _sut.isValid = true; + // SSO validation passes + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); + _twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest) .Returns(Task.FromResult(new Tuple(false, null))); _deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext) @@ -760,6 +841,8 @@ public class BaseRequestValidatorTests .Returns(Task.FromResult(new Tuple(false, null))); _deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext) .Returns(Task.FromResult(true)); + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); // Act await _sut.ValidateAsync(context); @@ -833,6 +916,8 @@ public class BaseRequestValidatorTests .Returns(Task.FromResult(new Tuple(false, null))); _deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext) .Returns(Task.FromResult(true)); + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); // Act await _sut.ValidateAsync(context); @@ -877,6 +962,8 @@ public class BaseRequestValidatorTests .Returns(Task.FromResult(new Tuple(false, null))); _deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext) .Returns(Task.FromResult(true)); + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); // Act await _sut.ValidateAsync(context); @@ -921,6 +1008,8 @@ public class BaseRequestValidatorTests .Returns(Task.FromResult(new Tuple(false, null))); _deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext) .Returns(Task.FromResult(true)); + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); // Act await _sut.ValidateAsync(context); @@ -950,6 +1039,19 @@ public class BaseRequestValidatorTests GrantValidationResult grantResult) { // Arrange + + // SsoRequestValidator sets custom response + requestContext.ValidationErrorResult = new ValidationResult + { + IsError = true, + Error = SsoConstants.RequestErrors.SsoRequired, + ErrorDescription = SsoConstants.RequestErrors.SsoRequiredDescription + }; + requestContext.CustomResponse = new Dictionary + { + { CustomResponseConstants.ResponseKeys.ErrorModel, new ErrorResponseModel(SsoConstants.RequestErrors.SsoRequiredDescription) }, + }; + var context = CreateContext(tokenRequest, requestContext, grantResult); var user = requestContext.User; @@ -984,12 +1086,12 @@ public class BaseRequestValidatorTests // Assert Assert.True(context.GrantResult.IsError, "Authentication should fail - SSO required after recovery"); - - var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; + Assert.NotNull(context.GrantResult.CustomResponse); + var errorResponse = (ErrorResponseModel)context.CustomValidatorRequestContext.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel]; // Recovery succeeds, then SSO blocks with descriptive message Assert.Equal( - "Two-factor recovery has been performed. SSO authentication is required.", + SsoConstants.RequestErrors.SsoRequiredDescription, errorResponse.Message); // Verify recovery was marked @@ -1050,7 +1152,7 @@ public class BaseRequestValidatorTests // Assert Assert.True(context.GrantResult.IsError, "Authentication should fail - invalid recovery code"); - var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; + var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel]; // 2FA is checked first (due to recovery code request), fails with 2FA error Assert.Equal( @@ -1132,7 +1234,11 @@ public class BaseRequestValidatorTests _userService.IsLegacyUser(Arg.Any()) .Returns(false); - // 8. Setup user account keys for successful login response + // 8. SSO is not required + _ssoRequestValidator.ValidateAsync(requestContext.User, tokenRequest, requestContext) + .Returns(Task.FromResult(true)); + + // 9. Setup user account keys for successful login response _userAccountKeysQuery.Run(Arg.Any()).Returns(new UserAccountKeysData { PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData( @@ -1161,179 +1267,18 @@ public class BaseRequestValidatorTests } /// - /// Tests that when RedirectOnSsoRequired is DISABLED, the legacy SSO validation path is used. - /// This validates the deprecated RequireSsoLoginAsync method is called and SSO requirement - /// is checked using the old PolicyService.AnyPoliciesApplicableToUserAsync approach. + /// Tests that when SSO validation returns a custom response, (e.g., with organization identifier), + /// that custom response is properly propagated to the result. /// [Theory] [BitAutoData] - public async Task ValidateAsync_RedirectOnSsoRequired_Disabled_UsesLegacySsoValidation( + public async Task ValidateAsync_SsoRequired_PropagatesCustomResponse( [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - _featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(false); - - var context = CreateContext(tokenRequest, requestContext, grantResult); - _sut.isValid = true; - - tokenRequest.GrantType = OidcConstants.GrantTypes.Password; - - // SSO is required via legacy path - _policyService.AnyPoliciesApplicableToUserAsync( - Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) - .Returns(Task.FromResult(true)); - - // Act - await _sut.ValidateAsync(context); - - // Assert - Assert.True(context.GrantResult.IsError); - var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; - Assert.Equal("SSO authentication is required.", errorResponse.Message); - - // Verify legacy path was used - await _policyService.Received(1).AnyPoliciesApplicableToUserAsync( - requestContext.User.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed); - - // Verify new SsoRequestValidator was NOT called - await _ssoRequestValidator.DidNotReceive().ValidateAsync( - Arg.Any(), Arg.Any(), Arg.Any()); - } - - /// - /// Tests that when RedirectOnSsoRequired is ENABLED, the new ISsoRequestValidator is used - /// instead of the legacy RequireSsoLoginAsync method. - /// - [Theory] - [BitAutoData] - public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_UsesNewSsoRequestValidator( - [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] - CustomValidatorRequestContext requestContext, - GrantValidationResult grantResult) - { - // Arrange - _featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true); - - var context = CreateContext(tokenRequest, requestContext, grantResult); - _sut.isValid = true; - - tokenRequest.GrantType = OidcConstants.GrantTypes.Password; - - // Configure SsoRequestValidator to indicate SSO is required - _ssoRequestValidator.ValidateAsync( - Arg.Any(), - Arg.Any(), - Arg.Any()) - .Returns(Task.FromResult(false)); // false = SSO required - - // Set up the ValidationErrorResult that SsoRequestValidator would set - requestContext.ValidationErrorResult = new ValidationResult - { - IsError = true, - Error = "sso_required", - ErrorDescription = "SSO authentication is required." - }; - requestContext.CustomResponse = new Dictionary - { - { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") } - }; - - // Act - await _sut.ValidateAsync(context); - - // Assert - Assert.True(context.GrantResult.IsError); - - // Verify new SsoRequestValidator was called - await _ssoRequestValidator.Received(1).ValidateAsync( - requestContext.User, - tokenRequest, - requestContext); - - // Verify legacy path was NOT used - await _policyService.DidNotReceive().AnyPoliciesApplicableToUserAsync( - Arg.Any(), Arg.Any(), Arg.Any()); - } - - /// - /// Tests that when RedirectOnSsoRequired is ENABLED and SSO is NOT required, - /// authentication continues successfully through the new validation path. - /// - [Theory] - [BitAutoData] - public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_SsoNotRequired_SuccessfulLogin( - [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] - CustomValidatorRequestContext requestContext, - GrantValidationResult grantResult) - { - // Arrange - _featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true); - - var context = CreateContext(tokenRequest, requestContext, grantResult); - _sut.isValid = true; - - tokenRequest.GrantType = OidcConstants.GrantTypes.Password; - tokenRequest.ClientId = "web"; - - // SsoRequestValidator returns true (SSO not required) - _ssoRequestValidator.ValidateAsync( - Arg.Any(), - Arg.Any(), - Arg.Any()) - .Returns(Task.FromResult(true)); - - // No 2FA required - _twoFactorAuthenticationValidator.RequiresTwoFactorAsync(requestContext.User, tokenRequest) - .Returns(Task.FromResult(new Tuple(false, null))); - - // Device validation passes - _deviceValidator.ValidateRequestDeviceAsync(tokenRequest, requestContext) - .Returns(Task.FromResult(true)); - - // User is not legacy - _userService.IsLegacyUser(Arg.Any()).Returns(false); - - _userAccountKeysQuery.Run(Arg.Any()).Returns(new UserAccountKeysData - { - PublicKeyEncryptionKeyPairData = new PublicKeyEncryptionKeyPairData( - "test-private-key", - "test-public-key" - ) - }); - - // Act - await _sut.ValidateAsync(context); - - // Assert - Assert.False(context.GrantResult.IsError); - await _eventService.Received(1).LogUserEventAsync(requestContext.User.Id, EventType.User_LoggedIn); - - // Verify new validator was used - await _ssoRequestValidator.Received(1).ValidateAsync( - requestContext.User, - tokenRequest, - requestContext); - } - - /// - /// Tests that when RedirectOnSsoRequired is ENABLED and SSO validation returns a custom response - /// (e.g., with organization identifier), that custom response is properly propagated to the result. - /// - [Theory] - [BitAutoData] - public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_PropagatesCustomResponse( - [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] - CustomValidatorRequestContext requestContext, - GrantValidationResult grantResult) - { - // Arrange - _featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true); _sut.isValid = true; tokenRequest.GrantType = OidcConstants.GrantTypes.Password; @@ -1342,13 +1287,13 @@ public class BaseRequestValidatorTests requestContext.ValidationErrorResult = new ValidationResult { IsError = true, - Error = "sso_required", - ErrorDescription = "SSO authentication is required." + Error = SsoConstants.RequestErrors.SsoRequired, + ErrorDescription = SsoConstants.RequestErrors.SsoRequiredDescription }; requestContext.CustomResponse = new Dictionary { - { "ErrorModel", new ErrorResponseModel("SSO authentication is required.") }, - { "SsoOrganizationIdentifier", "test-org-identifier" } + { CustomResponseConstants.ResponseKeys.ErrorModel, new ErrorResponseModel(SsoConstants.RequestErrors.SsoRequiredDescription) }, + { CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier, "test-org-identifier" } }; var context = CreateContext(tokenRequest, requestContext, grantResult); @@ -1365,77 +1310,24 @@ public class BaseRequestValidatorTests // Assert Assert.True(context.GrantResult.IsError); Assert.NotNull(context.GrantResult.CustomResponse); - Assert.Contains("SsoOrganizationIdentifier", context.CustomValidatorRequestContext.CustomResponse); + Assert.Contains(CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier, context.CustomValidatorRequestContext.CustomResponse); Assert.Equal("test-org-identifier", - context.CustomValidatorRequestContext.CustomResponse["SsoOrganizationIdentifier"]); + context.CustomValidatorRequestContext.CustomResponse[CustomResponseConstants.ResponseKeys.SsoOrganizationIdentifier]); } /// - /// Tests that when RedirectOnSsoRequired is DISABLED and a user with 2FA recovery completes recovery, - /// but SSO is required, the legacy error message is returned (without the recovery-specific message). - /// - [Theory] - [BitAutoData] - public async Task ValidateAsync_RedirectOnSsoRequired_Disabled_RecoveryWithSso_LegacyMessage( - [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, - [AuthFixtures.CustomValidatorRequestContext] - CustomValidatorRequestContext requestContext, - GrantValidationResult grantResult) - { - // Arrange - _featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(false); - - var context = CreateContext(tokenRequest, requestContext, grantResult); - _sut.isValid = true; - - // Recovery code scenario - tokenRequest.Raw["TwoFactorProvider"] = ((int)TwoFactorProviderType.RecoveryCode).ToString(); - tokenRequest.Raw["TwoFactorToken"] = "valid-recovery-code"; - - // 2FA with recovery - _twoFactorAuthenticationValidator - .RequiresTwoFactorAsync(requestContext.User, tokenRequest) - .Returns(Task.FromResult(new Tuple(true, null))); - - _twoFactorAuthenticationValidator - .VerifyTwoFactorAsync(requestContext.User, null, TwoFactorProviderType.RecoveryCode, "valid-recovery-code") - .Returns(Task.FromResult(true)); - - // SSO is required (legacy check) - _policyService.AnyPoliciesApplicableToUserAsync( - Arg.Any(), PolicyType.RequireSso, OrganizationUserStatusType.Confirmed) - .Returns(Task.FromResult(true)); - - // Act - await _sut.ValidateAsync(context); - - // Assert - Assert.True(context.GrantResult.IsError); - var errorResponse = (ErrorResponseModel)context.GrantResult.CustomResponse["ErrorModel"]; - - // Legacy behavior: recovery-specific message IS shown even without RedirectOnSsoRequired - Assert.Equal("Two-factor recovery has been performed. SSO authentication is required.", errorResponse.Message); - - // But legacy validation path was used - await _policyService.Received(1).AnyPoliciesApplicableToUserAsync( - requestContext.User.Id, PolicyType.RequireSso, OrganizationUserStatusType.Confirmed); - } - - /// - /// Tests that when RedirectOnSsoRequired is ENABLED and recovery code is used for SSO-required user, + /// Tests that when a recovery code is used for SSO-required user, /// the SsoRequestValidator provides the recovery-specific error message. /// [Theory] [BitAutoData] - public async Task ValidateAsync_RedirectOnSsoRequired_Enabled_RecoveryWithSso_NewValidatorMessage( + public async Task ValidateAsync_RecoveryWithSso_CorrectValidatorMessage( [AuthFixtures.ValidatedTokenRequest] ValidatedTokenRequest tokenRequest, [AuthFixtures.CustomValidatorRequestContext] CustomValidatorRequestContext requestContext, GrantValidationResult grantResult) { // Arrange - _featureService.IsEnabled(FeatureFlagKeys.RedirectOnSsoRequired).Returns(true); - var context = CreateContext(tokenRequest, requestContext, grantResult); _sut.isValid = true; @@ -1457,14 +1349,14 @@ public class BaseRequestValidatorTests requestContext.ValidationErrorResult = new ValidationResult { IsError = true, - Error = "sso_required", - ErrorDescription = "Two-factor recovery has been performed. SSO authentication is required." + Error = SsoConstants.RequestErrors.SsoRequired, + ErrorDescription = SsoConstants.RequestErrors.SsoTwoFactorRecoveryDescription }; requestContext.CustomResponse = new Dictionary { { - "ErrorModel", - new ErrorResponseModel("Two-factor recovery has been performed. SSO authentication is required.") + CustomResponseConstants.ResponseKeys.ErrorModel, + new ErrorResponseModel(SsoConstants.RequestErrors.SsoTwoFactorRecoveryDescription) } }; @@ -1479,18 +1371,8 @@ public class BaseRequestValidatorTests // Assert Assert.True(context.GrantResult.IsError); - var errorResponse = (ErrorResponseModel)context.CustomValidatorRequestContext.CustomResponse["ErrorModel"]; - Assert.Equal("Two-factor recovery has been performed. SSO authentication is required.", errorResponse.Message); - - // Verify new validator was used - await _ssoRequestValidator.Received(1).ValidateAsync( - requestContext.User, - tokenRequest, - Arg.Is(ctx => ctx.TwoFactorRecoveryRequested)); - - // Verify legacy path was NOT used - await _policyService.DidNotReceive().AnyPoliciesApplicableToUserAsync( - Arg.Any(), Arg.Any(), Arg.Any()); + var errorResponse = (ErrorResponseModel)context.CustomValidatorRequestContext.CustomResponse[CustomResponseConstants.ResponseKeys.ErrorModel]; + Assert.Equal(SsoConstants.RequestErrors.SsoTwoFactorRecoveryDescription, errorResponse.Message); } private BaseRequestValidationContextFake CreateContext( diff --git a/test/Identity.Test/Wrappers/BaseRequestValidatorTestWrapper.cs b/test/Identity.Test/Wrappers/BaseRequestValidatorTestWrapper.cs index b336e4c3c1..ac27c55466 100644 --- a/test/Identity.Test/Wrappers/BaseRequestValidatorTestWrapper.cs +++ b/test/Identity.Test/Wrappers/BaseRequestValidatorTestWrapper.cs @@ -111,15 +111,6 @@ IBaseRequestValidatorTestWrapper context.GrantResult = new GrantValidationResult(TokenRequestErrors.InvalidGrant, customResponse: customResponse); } - [Obsolete] - protected override void SetSsoResult( - BaseRequestValidationContextFake context, - Dictionary customResponse) - { - context.GrantResult = new GrantValidationResult( - TokenRequestErrors.InvalidGrant, "Sso authentication required.", customResponse); - } - protected override Task SetSuccessResult( BaseRequestValidationContextFake context, User user, diff --git a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsBulkTests.cs b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsBulkTests.cs new file mode 100644 index 0000000000..712ad7d62e --- /dev/null +++ b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsBulkTests.cs @@ -0,0 +1,53 @@ +using Bit.Core.Repositories; +using Xunit; + +namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.CollectionRepository; + + +public class CreateDefaultCollectionsBulkAsyncTests +{ + [Theory, DatabaseData] + public async Task CreateDefaultCollectionsBulkAsync_CreatesDefaultCollections_Success( + IOrganizationRepository organizationRepository, + IUserRepository userRepository, + IOrganizationUserRepository organizationUserRepository, + ICollectionRepository collectionRepository) + { + await CreateDefaultCollectionsSharedTests.CreatesDefaultCollections_Success( + collectionRepository.CreateDefaultCollectionsBulkAsync, + organizationRepository, + userRepository, + organizationUserRepository, + collectionRepository); + } + + [Theory, DatabaseData] + public async Task CreateDefaultCollectionsBulkAsync_CreatesForNewUsersOnly_AndIgnoresExistingUsers( + IOrganizationRepository organizationRepository, + IUserRepository userRepository, + IOrganizationUserRepository organizationUserRepository, + ICollectionRepository collectionRepository) + { + await CreateDefaultCollectionsSharedTests.CreatesForNewUsersOnly_AndIgnoresExistingUsers( + collectionRepository.CreateDefaultCollectionsBulkAsync, + organizationRepository, + userRepository, + organizationUserRepository, + collectionRepository); + } + + [Theory, DatabaseData] + public async Task CreateDefaultCollectionsBulkAsync_IgnoresAllExistingUsers( + IOrganizationRepository organizationRepository, + IUserRepository userRepository, + IOrganizationUserRepository organizationUserRepository, + ICollectionRepository collectionRepository) + { + await CreateDefaultCollectionsSharedTests.IgnoresAllExistingUsers( + collectionRepository.CreateDefaultCollectionsBulkAsync, + organizationRepository, + userRepository, + organizationUserRepository, + collectionRepository); + } +} diff --git a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/UpsertDefaultCollectionsTests.cs b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsSharedTests.cs similarity index 69% rename from test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/UpsertDefaultCollectionsTests.cs rename to test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsSharedTests.cs index 64dffa473f..0fb4a5b446 100644 --- a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/UpsertDefaultCollectionsTests.cs +++ b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsSharedTests.cs @@ -6,10 +6,14 @@ using Xunit; namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.CollectionRepository; -public class UpsertDefaultCollectionsTests +/// +/// Shared tests for CreateDefaultCollections methods - both bulk and non-bulk implementations, +/// as they share the same behavior. Both test suites call the tests in this class. +/// +public static class CreateDefaultCollectionsSharedTests { - [Theory, DatabaseData] - public async Task UpsertDefaultCollectionsAsync_ShouldCreateDefaultCollection_WhenUsersDoNotHaveDefaultCollection( + public static async Task CreatesDefaultCollections_Success( + Func, string, Task> createDefaultCollectionsFunc, IOrganizationRepository organizationRepository, IUserRepository userRepository, IOrganizationUserRepository organizationUserRepository, @@ -21,14 +25,13 @@ public class UpsertDefaultCollectionsTests var resultOrganizationUsers = await Task.WhenAll( CreateUserForOrgAsync(userRepository, organizationUserRepository, organization), CreateUserForOrgAsync(userRepository, organizationUserRepository, organization) - ); + ); - - var affectedOrgUserIds = resultOrganizationUsers.Select(organizationUser => organizationUser.Id); + var affectedOrgUserIds = resultOrganizationUsers.Select(organizationUser => organizationUser.Id).ToList(); var defaultCollectionName = $"default-name-{organization.Id}"; // Act - await collectionRepository.UpsertDefaultCollectionsAsync(organization.Id, affectedOrgUserIds, defaultCollectionName); + await createDefaultCollectionsFunc(organization.Id, affectedOrgUserIds, defaultCollectionName); // Assert await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, resultOrganizationUsers, organization.Id); @@ -36,8 +39,8 @@ public class UpsertDefaultCollectionsTests await CleanupAsync(organizationRepository, userRepository, organization, resultOrganizationUsers); } - [Theory, DatabaseData] - public async Task UpsertDefaultCollectionsAsync_ShouldUpsertCreateDefaultCollection_ForUsersWithAndWithoutDefaultCollectionsExist( + public static async Task CreatesForNewUsersOnly_AndIgnoresExistingUsers( + Func, string, Task> createDefaultCollectionsFunc, IOrganizationRepository organizationRepository, IUserRepository userRepository, IOrganizationUserRepository organizationUserRepository, @@ -51,31 +54,30 @@ public class UpsertDefaultCollectionsTests CreateUserForOrgAsync(userRepository, organizationUserRepository, organization) ); - var arrangedOrgUserIds = arrangedOrganizationUsers.Select(organizationUser => organizationUser.Id); + var arrangedOrgUserIds = arrangedOrganizationUsers.Select(organizationUser => organizationUser.Id).ToList(); var defaultCollectionName = $"default-name-{organization.Id}"; + await CreateUsersWithExistingDefaultCollectionsAsync(createDefaultCollectionsFunc, collectionRepository, organization.Id, arrangedOrgUserIds, defaultCollectionName, arrangedOrganizationUsers); - await CreateUsersWithExistingDefaultCollectionsAsync(collectionRepository, organization.Id, arrangedOrgUserIds, defaultCollectionName, arrangedOrganizationUsers); - - var newOrganizationUsers = new List() + var newOrganizationUsers = new List { await CreateUserForOrgAsync(userRepository, organizationUserRepository, organization) }; - var affectedOrgUsers = newOrganizationUsers.Concat(arrangedOrganizationUsers); - var affectedOrgUserIds = affectedOrgUsers.Select(organizationUser => organizationUser.Id); + var affectedOrgUsers = newOrganizationUsers.Concat(arrangedOrganizationUsers).ToList(); + var affectedOrgUserIds = affectedOrgUsers.Select(organizationUser => organizationUser.Id).ToList(); // Act - await collectionRepository.UpsertDefaultCollectionsAsync(organization.Id, affectedOrgUserIds, defaultCollectionName); + await createDefaultCollectionsFunc(organization.Id, affectedOrgUserIds, defaultCollectionName); // Assert - await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, arrangedOrganizationUsers, organization.Id); + await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, affectedOrgUsers, organization.Id); await CleanupAsync(organizationRepository, userRepository, organization, affectedOrgUsers); } - [Theory, DatabaseData] - public async Task UpsertDefaultCollectionsAsync_ShouldNotCreateDefaultCollection_WhenUsersAlreadyHaveOne( + public static async Task IgnoresAllExistingUsers( + Func, string, Task> createDefaultCollectionsFunc, IOrganizationRepository organizationRepository, IUserRepository userRepository, IOrganizationUserRepository organizationUserRepository, @@ -89,26 +91,29 @@ public class UpsertDefaultCollectionsTests CreateUserForOrgAsync(userRepository, organizationUserRepository, organization) ); - var affectedOrgUserIds = resultOrganizationUsers.Select(organizationUser => organizationUser.Id); + var affectedOrgUserIds = resultOrganizationUsers.Select(organizationUser => organizationUser.Id).ToList(); var defaultCollectionName = $"default-name-{organization.Id}"; + await CreateUsersWithExistingDefaultCollectionsAsync(createDefaultCollectionsFunc, collectionRepository, organization.Id, affectedOrgUserIds, defaultCollectionName, resultOrganizationUsers); - await CreateUsersWithExistingDefaultCollectionsAsync(collectionRepository, organization.Id, affectedOrgUserIds, defaultCollectionName, resultOrganizationUsers); + // Act - Try to create again, should silently filter and not create duplicates + await createDefaultCollectionsFunc(organization.Id, affectedOrgUserIds, defaultCollectionName); - // Act - await collectionRepository.UpsertDefaultCollectionsAsync(organization.Id, affectedOrgUserIds, defaultCollectionName); - - // Assert + // Assert - Original collections should remain unchanged, still only one per user await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, resultOrganizationUsers, organization.Id); await CleanupAsync(organizationRepository, userRepository, organization, resultOrganizationUsers); } - private static async Task CreateUsersWithExistingDefaultCollectionsAsync(ICollectionRepository collectionRepository, - Guid organizationId, IEnumerable affectedOrgUserIds, string defaultCollectionName, + private static async Task CreateUsersWithExistingDefaultCollectionsAsync( + Func, string, Task> createDefaultCollectionsFunc, + ICollectionRepository collectionRepository, + Guid organizationId, + IEnumerable affectedOrgUserIds, + string defaultCollectionName, OrganizationUser[] resultOrganizationUsers) { - await collectionRepository.UpsertDefaultCollectionsAsync(organizationId, affectedOrgUserIds, defaultCollectionName); + await createDefaultCollectionsFunc(organizationId, affectedOrgUserIds, defaultCollectionName); await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, resultOrganizationUsers, organizationId); } @@ -131,7 +136,6 @@ public class UpsertDefaultCollectionsTests private static async Task CreateUserForOrgAsync(IUserRepository userRepository, IOrganizationUserRepository organizationUserRepository, Organization organization) { - var user = await userRepository.CreateTestUserAsync(); var orgUser = await organizationUserRepository.CreateTestOrganizationUserAsync(organization, user); diff --git a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsTests.cs b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsTests.cs new file mode 100644 index 0000000000..bd894e9ca5 --- /dev/null +++ b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsTests.cs @@ -0,0 +1,52 @@ +using Bit.Core.Repositories; +using Xunit; + +namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.CollectionRepository; + +public class CreateDefaultCollectionsAsyncTests +{ + [Theory, DatabaseData] + public async Task CreateDefaultCollectionsAsync_CreatesDefaultCollections_Success( + IOrganizationRepository organizationRepository, + IUserRepository userRepository, + IOrganizationUserRepository organizationUserRepository, + ICollectionRepository collectionRepository) + { + await CreateDefaultCollectionsSharedTests.CreatesDefaultCollections_Success( + collectionRepository.CreateDefaultCollectionsAsync, + organizationRepository, + userRepository, + organizationUserRepository, + collectionRepository); + } + + [Theory, DatabaseData] + public async Task CreateDefaultCollectionsAsync_CreatesForNewUsersOnly_AndIgnoresExistingUsers( + IOrganizationRepository organizationRepository, + IUserRepository userRepository, + IOrganizationUserRepository organizationUserRepository, + ICollectionRepository collectionRepository) + { + await CreateDefaultCollectionsSharedTests.CreatesForNewUsersOnly_AndIgnoresExistingUsers( + collectionRepository.CreateDefaultCollectionsAsync, + organizationRepository, + userRepository, + organizationUserRepository, + collectionRepository); + } + + [Theory, DatabaseData] + public async Task CreateDefaultCollectionsAsync_IgnoresAllExistingUsers( + IOrganizationRepository organizationRepository, + IUserRepository userRepository, + IOrganizationUserRepository organizationUserRepository, + ICollectionRepository collectionRepository) + { + await CreateDefaultCollectionsSharedTests.IgnoresAllExistingUsers( + collectionRepository.CreateDefaultCollectionsAsync, + organizationRepository, + userRepository, + organizationUserRepository, + collectionRepository); + } +} diff --git a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationRepository/GetByVerifiedUserEmailDomainAsyncTests.cs b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationRepository/GetByVerifiedUserEmailDomainAsyncTests.cs new file mode 100644 index 0000000000..6dd7aafca4 --- /dev/null +++ b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationRepository/GetByVerifiedUserEmailDomainAsyncTests.cs @@ -0,0 +1,335 @@ +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Repositories; +using Xunit; + +namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.OrganizationRepository; + +public class GetByVerifiedUserEmailDomainAsyncTests +{ + [Theory, DatabaseData] + public async Task GetByClaimedUserDomainAsync_WithVerifiedDomain_Success( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationDomainRepository organizationDomainRepository) + { + var id = Guid.NewGuid(); + var domainName = $"{id}.example.com"; + + var user1 = await userRepository.CreateAsync(new User + { + Name = "Test User 1", + Email = $"test+{id}@{domainName}", + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var user2 = await userRepository.CreateAsync(new User + { + Name = "Test User 2", + Email = $"test+{id}@x-{domainName}", // Different domain + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var user3 = await userRepository.CreateAsync(new User + { + Name = "Test User 2", + Email = $"test+{id}@{domainName}.example.com", // Different domain + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var organization = await organizationRepository.CreateTestOrganizationAsync(); + + var organizationDomain = new OrganizationDomain + { + OrganizationId = organization.Id, + DomainName = domainName, + Txt = "btw+12345", + }; + organizationDomain.SetVerifiedDate(); + organizationDomain.SetNextRunDate(12); + organizationDomain.SetJobRunCount(); + await organizationDomainRepository.CreateAsync(organizationDomain); + + await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, user1); + await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, user2); + await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, user3); + + var user1Response = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user1.Id); + var user2Response = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user2.Id); + var user3Response = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user3.Id); + + Assert.NotEmpty(user1Response); + Assert.Equal(organization.Id, user1Response.First().Id); + Assert.Empty(user2Response); + Assert.Empty(user3Response); + } + + [Theory, DatabaseData] + public async Task GetByVerifiedUserEmailDomainAsync_WithUnverifiedDomains_ReturnsEmpty( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationDomainRepository organizationDomainRepository) + { + var id = Guid.NewGuid(); + var domainName = $"{id}.example.com"; + + var user = await userRepository.CreateAsync(new User + { + Name = "Test User", + Email = $"test+{id}@{domainName}", + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var organization = await organizationRepository.CreateTestOrganizationAsync(); + + var organizationDomain = new OrganizationDomain + { + OrganizationId = organization.Id, + DomainName = domainName, + Txt = "btw+12345", + }; + organizationDomain.SetNextRunDate(12); + organizationDomain.SetJobRunCount(); + await organizationDomainRepository.CreateAsync(organizationDomain); + + await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, user); + + var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user.Id); + + Assert.Empty(result); + } + + [Theory, DatabaseData] + public async Task GetByVerifiedUserEmailDomainAsync_WithMultipleVerifiedDomains_ReturnsAllMatchingOrganizations( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationDomainRepository organizationDomainRepository) + { + var id = Guid.NewGuid(); + var domainName = $"{id}.example.com"; + + var user = await userRepository.CreateAsync(new User + { + Name = "Test User", + Email = $"test+{id}@{domainName}", + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var organization1 = await organizationRepository.CreateTestOrganizationAsync(); + var organization2 = await organizationRepository.CreateTestOrganizationAsync(); + + var organizationDomain1 = new OrganizationDomain + { + OrganizationId = organization1.Id, + DomainName = domainName, + Txt = "btw+12345", + }; + organizationDomain1.SetNextRunDate(12); + organizationDomain1.SetJobRunCount(); + organizationDomain1.SetVerifiedDate(); + await organizationDomainRepository.CreateAsync(organizationDomain1); + + var organizationDomain2 = new OrganizationDomain + { + OrganizationId = organization2.Id, + DomainName = domainName, + Txt = "btw+67890", + }; + organizationDomain2.SetNextRunDate(12); + organizationDomain2.SetJobRunCount(); + organizationDomain2.SetVerifiedDate(); + await organizationDomainRepository.CreateAsync(organizationDomain2); + + await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization1, user); + await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization2, user); + + var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user.Id); + + Assert.Equal(2, result.Count); + Assert.Contains(result, org => org.Id == organization1.Id); + Assert.Contains(result, org => org.Id == organization2.Id); + } + + [Theory, DatabaseData] + public async Task GetByVerifiedUserEmailDomainAsync_WithNonExistentUser_ReturnsEmpty( + IOrganizationRepository organizationRepository) + { + var nonExistentUserId = Guid.NewGuid(); + + var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(nonExistentUserId); + + Assert.Empty(result); + } + + /// + /// Tests an edge case where some invited users are created linked to a UserId. + /// This is defective behavior, but will take longer to fix - for now, we are defensive and expressly + /// exclude such users from the results without relying on the inner join only. + /// Invited-revoked users linked to a UserId remain intentionally unhandled for now as they have not caused + /// any issues to date and we want to minimize edge cases. + /// We will fix the underlying issue going forward: https://bitwarden.atlassian.net/browse/PM-22405 + /// + [Theory, DatabaseData] + public async Task GetByVerifiedUserEmailDomainAsync_WithInvitedUserWithUserId_ReturnsEmpty( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationDomainRepository organizationDomainRepository) + { + var id = Guid.NewGuid(); + var domainName = $"{id}.example.com"; + + var user = await userRepository.CreateAsync(new User + { + Name = "Test User", + Email = $"test+{id}@{domainName}", + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var organization = await organizationRepository.CreateTestOrganizationAsync(); + + var organizationDomain = new OrganizationDomain + { + OrganizationId = organization.Id, + DomainName = domainName, + Txt = "btw+12345", + }; + organizationDomain.SetVerifiedDate(); + organizationDomain.SetNextRunDate(12); + organizationDomain.SetJobRunCount(); + await organizationDomainRepository.CreateAsync(organizationDomain); + + // Create invited user with matching email domain but UserId set (edge case) + await organizationUserRepository.CreateAsync(new OrganizationUser + { + OrganizationId = organization.Id, + UserId = user.Id, + Email = user.Email, + Status = OrganizationUserStatusType.Invited, + }); + + var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user.Id); + + // Invited users should be excluded even if they have UserId set + Assert.Empty(result); + } + + [Theory, DatabaseData] + public async Task GetByVerifiedUserEmailDomainAsync_WithAcceptedUser_ReturnsOrganization( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationDomainRepository organizationDomainRepository) + { + var id = Guid.NewGuid(); + var domainName = $"{id}.example.com"; + + var user = await userRepository.CreateAsync(new User + { + Name = "Test User", + Email = $"test+{id}@{domainName}", + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var organization = await organizationRepository.CreateTestOrganizationAsync(); + + var organizationDomain = new OrganizationDomain + { + OrganizationId = organization.Id, + DomainName = domainName, + Txt = "btw+12345", + }; + organizationDomain.SetVerifiedDate(); + organizationDomain.SetNextRunDate(12); + organizationDomain.SetJobRunCount(); + await organizationDomainRepository.CreateAsync(organizationDomain); + + await organizationUserRepository.CreateAcceptedTestOrganizationUserAsync(organization, user); + + var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user.Id); + + Assert.NotEmpty(result); + Assert.Equal(organization.Id, result.First().Id); + } + + [Theory, DatabaseData] + public async Task GetByVerifiedUserEmailDomainAsync_WithRevokedUser_ReturnsOrganization( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationDomainRepository organizationDomainRepository) + { + var id = Guid.NewGuid(); + var domainName = $"{id}.example.com"; + + var user = await userRepository.CreateAsync(new User + { + Name = "Test User", + Email = $"test+{id}@{domainName}", + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var organization = await organizationRepository.CreateTestOrganizationAsync(); + + var organizationDomain = new OrganizationDomain + { + OrganizationId = organization.Id, + DomainName = domainName, + Txt = "btw+12345", + }; + organizationDomain.SetVerifiedDate(); + organizationDomain.SetNextRunDate(12); + organizationDomain.SetJobRunCount(); + await organizationDomainRepository.CreateAsync(organizationDomain); + + await organizationUserRepository.CreateRevokedTestOrganizationUserAsync(organization, user); + + var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user.Id); + + Assert.NotEmpty(result); + Assert.Equal(organization.Id, result.First().Id); + } +} diff --git a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationRepositoryTests.cs b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationRepositoryTests.cs index 67e2c1910b..52b1e7484b 100644 --- a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationRepositoryTests.cs +++ b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationRepositoryTests.cs @@ -8,254 +8,7 @@ namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories; public class OrganizationRepositoryTests { - [DatabaseTheory, DatabaseData] - public async Task GetByClaimedUserDomainAsync_WithVerifiedDomain_Success( - IUserRepository userRepository, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IOrganizationDomainRepository organizationDomainRepository) - { - var id = Guid.NewGuid(); - var domainName = $"{id}.example.com"; - - var user1 = await userRepository.CreateAsync(new User - { - Name = "Test User 1", - Email = $"test+{id}@{domainName}", - ApiKey = "TEST", - SecurityStamp = "stamp", - Kdf = KdfType.PBKDF2_SHA256, - KdfIterations = 1, - KdfMemory = 2, - KdfParallelism = 3 - }); - - var user2 = await userRepository.CreateAsync(new User - { - Name = "Test User 2", - Email = $"test+{id}@x-{domainName}", // Different domain - ApiKey = "TEST", - SecurityStamp = "stamp", - Kdf = KdfType.PBKDF2_SHA256, - KdfIterations = 1, - KdfMemory = 2, - KdfParallelism = 3 - }); - - var user3 = await userRepository.CreateAsync(new User - { - Name = "Test User 2", - Email = $"test+{id}@{domainName}.example.com", // Different domain - ApiKey = "TEST", - SecurityStamp = "stamp", - Kdf = KdfType.PBKDF2_SHA256, - KdfIterations = 1, - KdfMemory = 2, - KdfParallelism = 3 - }); - - var organization = await organizationRepository.CreateAsync(new Organization - { - Name = $"Test Org {id}", - BillingEmail = user1.Email, // TODO: EF does not enforce this being NOT NULL - Plan = "Test", // TODO: EF does not enforce this being NOT NULL - PrivateKey = "privatekey", - }); - - var organizationDomain = new OrganizationDomain - { - OrganizationId = organization.Id, - DomainName = domainName, - Txt = "btw+12345", - }; - organizationDomain.SetVerifiedDate(); - organizationDomain.SetNextRunDate(12); - organizationDomain.SetJobRunCount(); - await organizationDomainRepository.CreateAsync(organizationDomain); - - await organizationUserRepository.CreateAsync(new OrganizationUser - { - OrganizationId = organization.Id, - UserId = user1.Id, - Status = OrganizationUserStatusType.Confirmed, - ResetPasswordKey = "resetpasswordkey1", - }); - - await organizationUserRepository.CreateAsync(new OrganizationUser - { - OrganizationId = organization.Id, - UserId = user2.Id, - Status = OrganizationUserStatusType.Confirmed, - ResetPasswordKey = "resetpasswordkey1", - }); - - await organizationUserRepository.CreateAsync(new OrganizationUser - { - OrganizationId = organization.Id, - UserId = user3.Id, - Status = OrganizationUserStatusType.Confirmed, - ResetPasswordKey = "resetpasswordkey1", - }); - - var user1Response = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user1.Id); - var user2Response = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user2.Id); - var user3Response = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user3.Id); - - Assert.NotEmpty(user1Response); - Assert.Equal(organization.Id, user1Response.First().Id); - Assert.Empty(user2Response); - Assert.Empty(user3Response); - } - - [DatabaseTheory, DatabaseData] - public async Task GetByVerifiedUserEmailDomainAsync_WithUnverifiedDomains_ReturnsEmpty( - IUserRepository userRepository, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IOrganizationDomainRepository organizationDomainRepository) - { - var id = Guid.NewGuid(); - var domainName = $"{id}.example.com"; - - var user = await userRepository.CreateAsync(new User - { - Name = "Test User", - Email = $"test+{id}@{domainName}", - ApiKey = "TEST", - SecurityStamp = "stamp", - Kdf = KdfType.PBKDF2_SHA256, - KdfIterations = 1, - KdfMemory = 2, - KdfParallelism = 3 - }); - - var organization = await organizationRepository.CreateAsync(new Organization - { - Name = $"Test Org {id}", - BillingEmail = user.Email, - Plan = "Test", - PrivateKey = "privatekey", - }); - - var organizationDomain = new OrganizationDomain - { - OrganizationId = organization.Id, - DomainName = domainName, - Txt = "btw+12345", - }; - organizationDomain.SetNextRunDate(12); - organizationDomain.SetJobRunCount(); - await organizationDomainRepository.CreateAsync(organizationDomain); - - await organizationUserRepository.CreateAsync(new OrganizationUser - { - OrganizationId = organization.Id, - UserId = user.Id, - Status = OrganizationUserStatusType.Confirmed, - ResetPasswordKey = "resetpasswordkey", - }); - - var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user.Id); - - Assert.Empty(result); - } - - [DatabaseTheory, DatabaseData] - public async Task GetByVerifiedUserEmailDomainAsync_WithMultipleVerifiedDomains_ReturnsAllMatchingOrganizations( - IUserRepository userRepository, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IOrganizationDomainRepository organizationDomainRepository) - { - var id = Guid.NewGuid(); - var domainName = $"{id}.example.com"; - - var user = await userRepository.CreateAsync(new User - { - Name = "Test User", - Email = $"test+{id}@{domainName}", - ApiKey = "TEST", - SecurityStamp = "stamp", - Kdf = KdfType.PBKDF2_SHA256, - KdfIterations = 1, - KdfMemory = 2, - KdfParallelism = 3 - }); - - var organization1 = await organizationRepository.CreateAsync(new Organization - { - Name = $"Test Org 1 {id}", - BillingEmail = user.Email, - Plan = "Test", - PrivateKey = "privatekey1", - }); - - var organization2 = await organizationRepository.CreateAsync(new Organization - { - Name = $"Test Org 2 {id}", - BillingEmail = user.Email, - Plan = "Test", - PrivateKey = "privatekey2", - }); - - var organizationDomain1 = new OrganizationDomain - { - OrganizationId = organization1.Id, - DomainName = domainName, - Txt = "btw+12345", - }; - organizationDomain1.SetNextRunDate(12); - organizationDomain1.SetJobRunCount(); - organizationDomain1.SetVerifiedDate(); - await organizationDomainRepository.CreateAsync(organizationDomain1); - - var organizationDomain2 = new OrganizationDomain - { - OrganizationId = organization2.Id, - DomainName = domainName, - Txt = "btw+67890", - }; - organizationDomain2.SetNextRunDate(12); - organizationDomain2.SetJobRunCount(); - organizationDomain2.SetVerifiedDate(); - await organizationDomainRepository.CreateAsync(organizationDomain2); - - await organizationUserRepository.CreateAsync(new OrganizationUser - { - OrganizationId = organization1.Id, - UserId = user.Id, - Status = OrganizationUserStatusType.Confirmed, - ResetPasswordKey = "resetpasswordkey1", - }); - - await organizationUserRepository.CreateAsync(new OrganizationUser - { - OrganizationId = organization2.Id, - UserId = user.Id, - Status = OrganizationUserStatusType.Confirmed, - ResetPasswordKey = "resetpasswordkey2", - }); - - var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(user.Id); - - Assert.Equal(2, result.Count); - Assert.Contains(result, org => org.Id == organization1.Id); - Assert.Contains(result, org => org.Id == organization2.Id); - } - - [DatabaseTheory, DatabaseData] - public async Task GetByVerifiedUserEmailDomainAsync_WithNonExistentUser_ReturnsEmpty( - IOrganizationRepository organizationRepository) - { - var nonExistentUserId = Guid.NewGuid(); - - var result = await organizationRepository.GetByVerifiedUserEmailDomainAsync(nonExistentUserId); - - Assert.Empty(result); - } - - - [DatabaseTheory, DatabaseData] + [Theory, DatabaseData] public async Task GetManyByIdsAsync_ExistingOrganizations_ReturnsOrganizations(IOrganizationRepository organizationRepository) { var email = "test@email.com"; @@ -287,7 +40,7 @@ public class OrganizationRepositoryTests await organizationRepository.DeleteAsync(organization2); } - [DatabaseTheory, DatabaseData] + [Theory, DatabaseData] public async Task GetOccupiedSeatCountByOrganizationIdAsync_WithUsersAndSponsorships_ReturnsCorrectCounts( IUserRepository userRepository, IOrganizationRepository organizationRepository, @@ -356,7 +109,7 @@ public class OrganizationRepositoryTests Assert.Equal(4, result.Total); // Total occupied seats } - [DatabaseTheory, DatabaseData] + [Theory, DatabaseData] public async Task GetOccupiedSeatCountByOrganizationIdAsync_WithNoUsersOrSponsorships_ReturnsZero( IOrganizationRepository organizationRepository) { @@ -372,7 +125,7 @@ public class OrganizationRepositoryTests Assert.Equal(0, result.Total); } - [DatabaseTheory, DatabaseData] + [Theory, DatabaseData] public async Task GetOccupiedSeatCountByOrganizationIdAsync_WithOnlyRevokedUsers_ReturnsZero( IUserRepository userRepository, IOrganizationRepository organizationRepository, @@ -399,7 +152,7 @@ public class OrganizationRepositoryTests Assert.Equal(0, result.Total); } - [DatabaseTheory, DatabaseData] + [Theory, DatabaseData] public async Task GetOccupiedSeatCountByOrganizationIdAsync_WithOnlyExpiredSponsorships_ReturnsZero( IOrganizationRepository organizationRepository, IOrganizationSponsorshipRepository organizationSponsorshipRepository) @@ -424,7 +177,7 @@ public class OrganizationRepositoryTests Assert.Equal(0, result.Total); } - [DatabaseTheory, DatabaseData] + [Theory, DatabaseData] public async Task IncrementSeatCountAsync_IncrementsSeatCount(IOrganizationRepository organizationRepository) { var organization = await organizationRepository.CreateTestOrganizationAsync(); @@ -438,7 +191,7 @@ public class OrganizationRepositoryTests Assert.Equal(8, result.Seats); } - [DatabaseData, DatabaseTheory] + [DatabaseData, Theory] public async Task IncrementSeatCountAsync_GivenOrganizationHasNotChangedSeatCountBefore_WhenUpdatingOrgSeats_ThenSubscriptionUpdateIsSaved( IOrganizationRepository sutRepository) { @@ -462,7 +215,7 @@ public class OrganizationRepositoryTests await sutRepository.DeleteAsync(organization); } - [DatabaseData, DatabaseTheory] + [DatabaseData, Theory] public async Task IncrementSeatCountAsync_GivenOrganizationHasChangedSeatCountBeforeAndRecordExists_WhenUpdatingOrgSeats_ThenSubscriptionUpdateIsSaved( IOrganizationRepository sutRepository) { @@ -487,7 +240,7 @@ public class OrganizationRepositoryTests await sutRepository.DeleteAsync(organization); } - [DatabaseData, DatabaseTheory] + [DatabaseData, Theory] public async Task GetOrganizationsForSubscriptionSyncAsync_GivenOrganizationHasChangedSeatCount_WhenGettingOrgsToUpdate_ThenReturnsOrgSubscriptionUpdate( IOrganizationRepository sutRepository) { @@ -510,7 +263,7 @@ public class OrganizationRepositoryTests await sutRepository.DeleteAsync(organization); } - [DatabaseData, DatabaseTheory] + [DatabaseData, Theory] public async Task UpdateSuccessfulOrganizationSyncStatusAsync_GivenOrganizationHasChangedSeatCount_WhenUpdatingStatus_ThenSuccessfullyUpdatesOrgSoItDoesntSync( IOrganizationRepository sutRepository) { diff --git a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationUserRepository/GetManyByOrganizationWithClaimedDomainsAsyncTests.cs b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationUserRepository/GetManyByOrganizationWithClaimedDomainsAsyncTests.cs new file mode 100644 index 0000000000..6fa395751b --- /dev/null +++ b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationUserRepository/GetManyByOrganizationWithClaimedDomainsAsyncTests.cs @@ -0,0 +1,197 @@ +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Repositories; +using Xunit; + +namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.OrganizationUserRepository; + +public class GetManyByOrganizationWithClaimedDomainsAsyncTests +{ + [Theory, DatabaseData] + public async Task WithVerifiedDomain_WithOneMatchingEmailDomain_ReturnsSingle( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationDomainRepository organizationDomainRepository) + { + var id = Guid.NewGuid(); + var domainName = $"{id}.example.com"; + + var user1 = await userRepository.CreateAsync(new User + { + Name = "Test User 1", + Email = $"test+{id}@{domainName}", + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var user2 = await userRepository.CreateAsync(new User + { + Name = "Test User 2", + Email = $"test+{id}@x-{domainName}", // Different domain + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var user3 = await userRepository.CreateAsync(new User + { + Name = "Test User 3", + Email = $"test+{id}@{domainName}.example.com", // Different domain + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var organization = await organizationRepository.CreateTestOrganizationAsync(); + + var organizationDomain = new OrganizationDomain + { + OrganizationId = organization.Id, + DomainName = domainName, + Txt = "btw+12345", + }; + organizationDomain.SetVerifiedDate(); + organizationDomain.SetNextRunDate(12); + organizationDomain.SetJobRunCount(); + await organizationDomainRepository.CreateAsync(organizationDomain); + + var orgUser1 = await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, user1); + await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, user2); + await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, user3); + + var result = await organizationUserRepository.GetManyByOrganizationWithClaimedDomainsAsync(organization.Id); + + Assert.NotNull(result); + Assert.Single(result); + Assert.Equal(orgUser1.Id, result.Single().Id); + } + + [Theory, DatabaseData] + public async Task WithNoVerifiedDomain_ReturnsEmpty( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationDomainRepository organizationDomainRepository) + { + var id = Guid.NewGuid(); + var domainName = $"{id}.example.com"; + + var user = await userRepository.CreateAsync(new User + { + Name = "Test User", + Email = $"test+{id}@{domainName}", + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var organization = await organizationRepository.CreateTestOrganizationAsync(); + + // Create domain but do NOT verify it + var organizationDomain = new OrganizationDomain + { + OrganizationId = organization.Id, + DomainName = domainName, + Txt = "btw+12345", + }; + organizationDomain.SetNextRunDate(12); + // Note: NOT calling SetVerifiedDate() + await organizationDomainRepository.CreateAsync(organizationDomain); + + await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, user); + + var result = await organizationUserRepository.GetManyByOrganizationWithClaimedDomainsAsync(organization.Id); + + Assert.NotNull(result); + Assert.Empty(result); + } + + /// + /// Tests an edge case where some invited users are created linked to a UserId. + /// This is defective behavior, but will take longer to fix - for now, we are defensive and expressly + /// exclude such users from the results without relying on the inner join only. + /// Invited-revoked users linked to a UserId remain intentionally unhandled for now as they have not caused + /// any issues to date and we want to minimize edge cases. + /// We will fix the underlying issue going forward: https://bitwarden.atlassian.net/browse/PM-22405 + /// + [Theory, DatabaseData] + public async Task WithVerifiedDomain_ExcludesInvitedUsers( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IOrganizationDomainRepository organizationDomainRepository) + { + var id = Guid.NewGuid(); + var domainName = $"{id}.example.com"; + + var invitedUser = await userRepository.CreateAsync(new User + { + Name = "Invited User", + Email = $"invited+{id}@{domainName}", + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var confirmedUser = await userRepository.CreateAsync(new User + { + Name = "Confirmed User", + Email = $"confirmed+{id}@{domainName}", + ApiKey = "TEST", + SecurityStamp = "stamp", + Kdf = KdfType.PBKDF2_SHA256, + KdfIterations = 1, + KdfMemory = 2, + KdfParallelism = 3 + }); + + var organization = await organizationRepository.CreateTestOrganizationAsync(); + + var organizationDomain = new OrganizationDomain + { + OrganizationId = organization.Id, + DomainName = domainName, + Txt = "btw+12345", + }; + organizationDomain.SetVerifiedDate(); + organizationDomain.SetNextRunDate(12); + organizationDomain.SetJobRunCount(); + await organizationDomainRepository.CreateAsync(organizationDomain); + + // Create invited user with UserId set (edge case - should be excluded even with UserId linked) + var invitedOrgUser = await organizationUserRepository.CreateAsync(new OrganizationUser + { + OrganizationId = organization.Id, + UserId = invitedUser.Id, // Edge case: invited user with UserId set + Email = invitedUser.Email, + Status = OrganizationUserStatusType.Invited, + Type = OrganizationUserType.User + }); + + // Create confirmed user linked by UserId only (no Email field set) + var confirmedOrgUser = await organizationUserRepository.CreateConfirmedTestOrganizationUserAsync(organization, confirmedUser); + + var result = await organizationUserRepository.GetManyByOrganizationWithClaimedDomainsAsync(organization.Id); + + Assert.NotNull(result); + var claimedUser = Assert.Single(result); + Assert.Equal(confirmedOrgUser.Id, claimedUser.Id); + } +} diff --git a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationUserRepository/OrganizationUserRepositoryTests.cs b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationUserRepository/OrganizationUserRepositoryTests.cs index 1c433d0e6e..b77406abf5 100644 --- a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationUserRepository/OrganizationUserRepositoryTests.cs +++ b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/OrganizationUserRepository/OrganizationUserRepositoryTests.cs @@ -599,136 +599,6 @@ public class OrganizationUserRepositoryTests Assert.Null(orgWithoutSsoDetails.SsoConfig); } - [DatabaseTheory, DatabaseData] - public async Task GetManyByOrganizationWithClaimedDomainsAsync_WithVerifiedDomain_WithOneMatchingEmailDomain_ReturnsSingle( - IUserRepository userRepository, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IOrganizationDomainRepository organizationDomainRepository) - { - var id = Guid.NewGuid(); - var domainName = $"{id}.example.com"; - - var user1 = await userRepository.CreateAsync(new User - { - Name = "Test User 1", - Email = $"test+{id}@{domainName}", - ApiKey = "TEST", - SecurityStamp = "stamp", - Kdf = KdfType.PBKDF2_SHA256, - KdfIterations = 1, - KdfMemory = 2, - KdfParallelism = 3 - }); - - var user2 = await userRepository.CreateAsync(new User - { - Name = "Test User 2", - Email = $"test+{id}@x-{domainName}", // Different domain - ApiKey = "TEST", - SecurityStamp = "stamp", - Kdf = KdfType.PBKDF2_SHA256, - KdfIterations = 1, - KdfMemory = 2, - KdfParallelism = 3 - }); - - var user3 = await userRepository.CreateAsync(new User - { - Name = "Test User 2", - Email = $"test+{id}@{domainName}.example.com", // Different domain - ApiKey = "TEST", - SecurityStamp = "stamp", - Kdf = KdfType.PBKDF2_SHA256, - KdfIterations = 1, - KdfMemory = 2, - KdfParallelism = 3 - }); - - var organization = await organizationRepository.CreateAsync(new Organization - { - Name = $"Test Org {id}", - BillingEmail = user1.Email, // TODO: EF does not enforce this being NOT NULL - Plan = "Test", // TODO: EF does not enforce this being NOT NULL - PrivateKey = "privatekey", - UsePolicies = false, - UseSso = false, - UseKeyConnector = false, - UseScim = false, - UseGroups = false, - UseDirectory = false, - UseEvents = false, - UseTotp = false, - Use2fa = false, - UseApi = false, - UseResetPassword = false, - UseSecretsManager = false, - SelfHost = false, - UsersGetPremium = false, - UseCustomPermissions = false, - Enabled = true, - UsePasswordManager = false, - LimitCollectionCreation = false, - LimitCollectionDeletion = false, - LimitItemDeletion = false, - AllowAdminAccessToAllCollectionItems = false, - UseRiskInsights = false, - UseAdminSponsoredFamilies = false, - UsePhishingBlocker = false, - UseDisableSmAdsForUsers = false, - }); - - var organizationDomain = new OrganizationDomain - { - OrganizationId = organization.Id, - DomainName = domainName, - Txt = "btw+12345", - }; - organizationDomain.SetVerifiedDate(); - organizationDomain.SetNextRunDate(12); - organizationDomain.SetJobRunCount(); - await organizationDomainRepository.CreateAsync(organizationDomain); - - var orgUser1 = await organizationUserRepository.CreateAsync(new OrganizationUser - { - Id = CoreHelpers.GenerateComb(), - OrganizationId = organization.Id, - UserId = user1.Id, - Status = OrganizationUserStatusType.Confirmed, - Type = OrganizationUserType.Owner, - ResetPasswordKey = "resetpasswordkey1", - AccessSecretsManager = false - }); - - await organizationUserRepository.CreateAsync(new OrganizationUser - { - Id = CoreHelpers.GenerateComb(), - OrganizationId = organization.Id, - UserId = user2.Id, - Status = OrganizationUserStatusType.Confirmed, - Type = OrganizationUserType.User, - ResetPasswordKey = "resetpasswordkey1", - AccessSecretsManager = false - }); - - await organizationUserRepository.CreateAsync(new OrganizationUser - { - Id = CoreHelpers.GenerateComb(), - OrganizationId = organization.Id, - UserId = user3.Id, - Status = OrganizationUserStatusType.Confirmed, - Type = OrganizationUserType.User, - ResetPasswordKey = "resetpasswordkey1", - AccessSecretsManager = false - }); - - var responseModel = await organizationUserRepository.GetManyByOrganizationWithClaimedDomainsAsync(organization.Id); - - Assert.NotNull(responseModel); - Assert.Single(responseModel); - Assert.Equal(orgUser1.Id, responseModel.Single().Id); - } - [DatabaseTheory, DatabaseData] public async Task CreateManyAsync_NoId_Works(IOrganizationRepository organizationRepository, IUserRepository userRepository, @@ -1237,70 +1107,6 @@ public class OrganizationUserRepositoryTests Assert.DoesNotContain(user1Result.Collections, c => c.Id == defaultUserCollection.Id); } - [DatabaseTheory, DatabaseData] - public async Task GetManyByOrganizationWithClaimedDomainsAsync_WithNoVerifiedDomain_ReturnsEmpty( - IUserRepository userRepository, - IOrganizationRepository organizationRepository, - IOrganizationUserRepository organizationUserRepository, - IOrganizationDomainRepository organizationDomainRepository) - { - var id = Guid.NewGuid(); - var domainName = $"{id}.example.com"; - var requestTime = DateTime.UtcNow; - - var user1 = await userRepository.CreateAsync(new User - { - Id = CoreHelpers.GenerateComb(), - Name = "Test User 1", - Email = $"test+{id}@{domainName}", - ApiKey = "TEST", - SecurityStamp = "stamp", - CreationDate = requestTime, - RevisionDate = requestTime, - AccountRevisionDate = requestTime - }); - - var organization = await organizationRepository.CreateAsync(new Organization - { - Id = CoreHelpers.GenerateComb(), - Name = $"Test Org {id}", - BillingEmail = user1.Email, - Plan = "Test", - Enabled = true, - CreationDate = requestTime, - RevisionDate = requestTime - }); - - // Create domain but do NOT verify it - var organizationDomain = new OrganizationDomain - { - Id = CoreHelpers.GenerateComb(), - OrganizationId = organization.Id, - DomainName = domainName, - Txt = "btw+12345", - CreationDate = requestTime - }; - organizationDomain.SetNextRunDate(12); - // Note: NOT calling SetVerifiedDate() - await organizationDomainRepository.CreateAsync(organizationDomain); - - await organizationUserRepository.CreateAsync(new OrganizationUser - { - Id = CoreHelpers.GenerateComb(), - OrganizationId = organization.Id, - UserId = user1.Id, - Status = OrganizationUserStatusType.Confirmed, - Type = OrganizationUserType.Owner, - CreationDate = requestTime, - RevisionDate = requestTime - }); - - var responseModel = await organizationUserRepository.GetManyByOrganizationWithClaimedDomainsAsync(organization.Id); - - Assert.NotNull(responseModel); - Assert.Empty(responseModel); - } - [DatabaseTheory, DatabaseData] public async Task DeleteAsync_WithNullEmail_DoesNotSetDefaultUserCollectionEmail(IUserRepository userRepository, ICollectionRepository collectionRepository, diff --git a/util/Migrator/DbScripts/2026-01-13_00_Collection_CreateDefaultCollections.sql b/util/Migrator/DbScripts/2026-01-13_00_Collection_CreateDefaultCollections.sql new file mode 100644 index 0000000000..c7935db5e8 --- /dev/null +++ b/util/Migrator/DbScripts/2026-01-13_00_Collection_CreateDefaultCollections.sql @@ -0,0 +1,70 @@ +-- Creates default user collections for organization users +-- Filters out existing default collections at database level +CREATE OR ALTER PROCEDURE [dbo].[Collection_CreateDefaultCollections] + @OrganizationId UNIQUEIDENTIFIER, + @DefaultCollectionName VARCHAR(MAX), + @OrganizationUserCollectionIds AS [dbo].[TwoGuidIdArray] READONLY -- OrganizationUserId, CollectionId +AS +BEGIN + SET NOCOUNT ON + + DECLARE @Now DATETIME2(7) = GETUTCDATE() + + -- Filter to only users who don't have default collections + SELECT ids.Id1, ids.Id2 + INTO #FilteredIds + FROM @OrganizationUserCollectionIds ids + WHERE NOT EXISTS ( + SELECT 1 + FROM [dbo].[CollectionUser] cu + INNER JOIN [dbo].[Collection] c ON c.Id = cu.CollectionId + WHERE c.OrganizationId = @OrganizationId + AND c.[Type] = 1 -- CollectionType.DefaultUserCollection + AND cu.OrganizationUserId = ids.Id1 + ); + + -- Insert collections only for users who don't have default collections yet + INSERT INTO [dbo].[Collection] + ( + [Id], + [OrganizationId], + [Name], + [CreationDate], + [RevisionDate], + [Type], + [ExternalId], + [DefaultUserCollectionEmail] + ) + SELECT + ids.Id2, -- CollectionId + @OrganizationId, + @DefaultCollectionName, + @Now, + @Now, + 1, -- CollectionType.DefaultUserCollection + NULL, + NULL + FROM + #FilteredIds ids; + + -- Insert collection user mappings + INSERT INTO [dbo].[CollectionUser] + ( + [CollectionId], + [OrganizationUserId], + [ReadOnly], + [HidePasswords], + [Manage] + ) + SELECT + ids.Id2, -- CollectionId + ids.Id1, -- OrganizationUserId + 0, -- ReadOnly = false + 0, -- HidePasswords = false + 1 -- Manage = true + FROM + #FilteredIds ids; + + DROP TABLE #FilteredIds; +END +GO diff --git a/util/Migrator/DbScripts/2026-01-14_00_ExcludeInvitedUsersFromClaimedDomain.sql b/util/Migrator/DbScripts/2026-01-14_00_ExcludeInvitedUsersFromClaimedDomain.sql new file mode 100644 index 0000000000..788fa02b7c --- /dev/null +++ b/util/Migrator/DbScripts/2026-01-14_00_ExcludeInvitedUsersFromClaimedDomain.sql @@ -0,0 +1,24 @@ +CREATE OR ALTER PROCEDURE [dbo].[Organization_ReadByClaimedUserEmailDomain] + @UserId UNIQUEIDENTIFIER +AS +BEGIN + SET NOCOUNT ON; + + WITH CTE_User AS ( + SELECT + U.[Id], + SUBSTRING(U.Email, CHARINDEX('@', U.Email) + 1, LEN(U.Email)) AS EmailDomain + FROM dbo.[UserView] U + WHERE U.[Id] = @UserId + ) + SELECT O.* + FROM CTE_User CU + INNER JOIN dbo.[OrganizationUserView] OU ON CU.[Id] = OU.[UserId] + INNER JOIN dbo.[OrganizationView] O ON OU.[OrganizationId] = O.[Id] + INNER JOIN dbo.[OrganizationDomainView] OD ON OU.[OrganizationId] = OD.[OrganizationId] + WHERE OD.[VerifiedDate] IS NOT NULL + AND CU.EmailDomain = OD.[DomainName] + AND O.[Enabled] = 1 + AND OU.[Status] != 0 -- Exclude invited users +END +GO diff --git a/util/Migrator/DbScripts/2026-01-14_01_ExcludeInvitedUsersFromClaimedDomains_V2.sql b/util/Migrator/DbScripts/2026-01-14_01_ExcludeInvitedUsersFromClaimedDomains_V2.sql new file mode 100644 index 0000000000..b7be5fd0e0 --- /dev/null +++ b/util/Migrator/DbScripts/2026-01-14_01_ExcludeInvitedUsersFromClaimedDomains_V2.sql @@ -0,0 +1,29 @@ +CREATE OR ALTER PROCEDURE [dbo].[OrganizationUser_ReadByOrganizationIdWithClaimedDomains_V2] + @OrganizationId UNIQUEIDENTIFIER +AS +BEGIN + SET NOCOUNT ON; + + WITH OrgUsers AS ( + SELECT * + FROM [dbo].[OrganizationUserView] + WHERE [OrganizationId] = @OrganizationId + AND [Status] != 0 -- Exclude invited users + ), + UserDomains AS ( + SELECT U.[Id], U.[EmailDomain] + FROM [dbo].[UserEmailDomainView] U + WHERE EXISTS ( + SELECT 1 + FROM [dbo].[OrganizationDomainView] OD + WHERE OD.[OrganizationId] = @OrganizationId + AND OD.[VerifiedDate] IS NOT NULL + AND OD.[DomainName] = U.[EmailDomain] + ) + ) + SELECT OU.* + FROM OrgUsers OU + JOIN UserDomains UD ON OU.[UserId] = UD.[Id] + OPTION (RECOMPILE); +END +GO
+ - +
-

+

© 2025 Bitwarden Inc. 1 N. Calle Cesar Chavez, Suite 102, Santa Barbara, CA, USA

Always confirm you are on a trusted Bitwarden domain before logging in:
- bitwarden.com | - Learn why we include this + bitwarden.com | + Learn why we include this