diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs
index 83ec244c47..2fbe6be5c6 100644
--- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs
+++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs
@@ -11,6 +11,7 @@ using Bit.Core.Billing.Enums;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Exceptions;
+using Bit.Core.Models.Data;
using Bit.Core.Platform.Push;
using Bit.Core.Repositories;
using Bit.Core.Services;
@@ -82,7 +83,7 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand
throw new BadRequestException(error);
}
- await HandleConfirmationSideEffectsAsync(organizationId, confirmedOrganizationUsers: [orgUser], defaultUserCollectionName);
+ await CreateDefaultCollectionAsync(orgUser, defaultUserCollectionName);
return orgUser;
}
@@ -97,9 +98,13 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand
.Select(r => r.Item1)
.ToList();
- if (confirmedOrganizationUsers.Count > 0)
+ if (confirmedOrganizationUsers.Count == 1)
{
- await HandleConfirmationSideEffectsAsync(organizationId, confirmedOrganizationUsers, defaultUserCollectionName);
+ await CreateDefaultCollectionAsync(confirmedOrganizationUsers.Single(), defaultUserCollectionName);
+ }
+ else if (confirmedOrganizationUsers.Count > 1)
+ {
+ await CreateManyDefaultCollectionsAsync(organizationId, confirmedOrganizationUsers, defaultUserCollectionName);
}
return result;
@@ -245,14 +250,54 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand
}
///
- /// Handles the side effects of confirming an organization user.
- /// Creates a default collection for the user if the organization
- /// has the OrganizationDataOwnership policy enabled.
+ /// Creates a default collection for a single user if required by the Organization Data Ownership policy.
+ ///
+ /// The organization user who has just been confirmed.
+ /// The encrypted default user collection name.
+ private async Task CreateDefaultCollectionAsync(OrganizationUser organizationUser, string defaultUserCollectionName)
+ {
+ if (!_featureService.IsEnabled(FeatureFlagKeys.CreateDefaultLocation))
+ {
+ return;
+ }
+
+ // Skip if no collection name provided (backwards compatibility)
+ if (string.IsNullOrWhiteSpace(defaultUserCollectionName))
+ {
+ return;
+ }
+
+ var organizationDataOwnershipPolicy =
+ await _policyRequirementQuery.GetAsync(organizationUser.UserId!.Value);
+ if (!organizationDataOwnershipPolicy.RequiresDefaultCollectionOnConfirm(organizationUser.OrganizationId))
+ {
+ return;
+ }
+
+ var defaultCollection = new Collection
+ {
+ OrganizationId = organizationUser.OrganizationId,
+ Name = defaultUserCollectionName,
+ Type = CollectionType.DefaultUserCollection
+ };
+ var collectionUser = new CollectionAccessSelection
+ {
+ Id = organizationUser.Id,
+ ReadOnly = false,
+ HidePasswords = false,
+ Manage = true
+ };
+
+ await _collectionRepository.CreateAsync(defaultCollection, groups: null, users: [collectionUser]);
+ }
+
+ ///
+ /// Creates default collections for multiple users if required by the Organization Data Ownership policy.
///
/// The organization ID.
/// The confirmed organization users.
/// The encrypted default user collection name.
- private async Task HandleConfirmationSideEffectsAsync(Guid organizationId,
+ private async Task CreateManyDefaultCollectionsAsync(Guid organizationId,
IEnumerable confirmedOrganizationUsers, string defaultUserCollectionName)
{
if (!_featureService.IsEnabled(FeatureFlagKeys.CreateDefaultLocation))
@@ -266,7 +311,8 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand
return;
}
- var policyEligibleOrganizationUserIds = await _policyRequirementQuery.GetManyByOrganizationIdAsync(organizationId);
+ var policyEligibleOrganizationUserIds =
+ await _policyRequirementQuery.GetManyByOrganizationIdAsync(organizationId);
var eligibleOrganizationUserIds = confirmedOrganizationUsers
.Where(ou => policyEligibleOrganizationUserIds.Contains(ou.Id))
diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/OrganizationDataOwnershipPolicyRequirement.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/OrganizationDataOwnershipPolicyRequirement.cs
index cb72a51850..28d6614dcb 100644
--- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/OrganizationDataOwnershipPolicyRequirement.cs
+++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/OrganizationDataOwnershipPolicyRequirement.cs
@@ -67,6 +67,11 @@ public class OrganizationDataOwnershipPolicyRequirement : IPolicyRequirement
var noCollectionNeeded = new DefaultCollectionRequest(Guid.Empty, false);
return noCollectionNeeded;
}
+
+ public bool RequiresDefaultCollectionOnConfirm(Guid organizationId)
+ {
+ return _policyDetails.Any(p => p.OrganizationId == organizationId);
+ }
}
public record DefaultCollectionRequest(Guid OrganizationUserId, bool ShouldCreateDefaultCollection)
diff --git a/src/Core/Repositories/ICollectionRepository.cs b/src/Core/Repositories/ICollectionRepository.cs
index ca3e52751c..f86147ca7d 100644
--- a/src/Core/Repositories/ICollectionRepository.cs
+++ b/src/Core/Repositories/ICollectionRepository.cs
@@ -63,5 +63,12 @@ public interface ICollectionRepository : IRepository
Task CreateOrUpdateAccessForManyAsync(Guid organizationId, IEnumerable collectionIds,
IEnumerable users, IEnumerable groups);
- Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable affectedOrgUserIds, string defaultCollectionName);
+ ///
+ /// Creates default user collections for the specified organization users if they do not already have one.
+ ///
+ /// The Organization ID.
+ /// The Organization User IDs to create default collections for.
+ /// The encrypted string to use as the default collection name.
+ ///
+ Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName);
}
diff --git a/src/Core/Utilities/CoreHelpers.cs b/src/Core/Utilities/CoreHelpers.cs
index 813eb6d1aa..5acdc63489 100644
--- a/src/Core/Utilities/CoreHelpers.cs
+++ b/src/Core/Utilities/CoreHelpers.cs
@@ -41,9 +41,12 @@ public static class CoreHelpers
};
///
- /// Generate sequential Guid for Sql Server.
- /// ref: https://github.com/nhibernate/nhibernate-core/blob/master/src/NHibernate/Id/GuidCombGenerator.cs
+ /// Generate a sequential Guid for Sql Server. This prevents SQL Server index fragmentation by incorporating timestamp
+ /// information for sequential ordering. This should be preferred to for any database IDs.
///
+ ///
+ /// ref: https://github.com/nhibernate/nhibernate-core/blob/master/src/NHibernate/Id/GuidCombGenerator.cs
+ ///
/// A comb Guid.
public static Guid GenerateComb()
=> GenerateComb(Guid.NewGuid(), DateTime.UtcNow);
diff --git a/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs b/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs
index 3610c1c484..5a743ba028 100644
--- a/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs
+++ b/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs
@@ -1,6 +1,7 @@
using System.Data;
using Bit.Core.Entities;
using Bit.Core.Vault.Entities;
+using Bit.Infrastructure.Dapper.Utilities;
using Microsoft.Data.SqlClient;
namespace Bit.Infrastructure.Dapper.AdminConsole.Helpers;
@@ -8,11 +9,25 @@ namespace Bit.Infrastructure.Dapper.AdminConsole.Helpers;
public static class BulkResourceCreationService
{
private const string _defaultErrorMessage = "Must have at least one record for bulk creation.";
- public static async Task CreateCollectionsUsersAsync(SqlConnection connection, SqlTransaction transaction, IEnumerable collectionUsers, string errorMessage = _defaultErrorMessage)
+ public static async Task CreateCollectionsUsersAsync(SqlConnection connection, SqlTransaction transaction,
+ IEnumerable collectionUsers, string errorMessage = _defaultErrorMessage)
{
+ // Offload some work from SQL Server by pre-sorting before insert.
+ // This lets us use the SqlBulkCopy.ColumnOrderHints to improve performance and reduce deadlocks.
+ var sortedCollectionUsers = collectionUsers
+ .OrderBySqlGuid(cu => cu.CollectionId)
+ .ThenBySqlGuid(cu => cu.OrganizationUserId)
+ .ToList();
+
using var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction);
bulkCopy.DestinationTableName = "[dbo].[CollectionUser]";
- var dataTable = BuildCollectionsUsersTable(bulkCopy, collectionUsers, errorMessage);
+ bulkCopy.BatchSize = 500;
+ bulkCopy.BulkCopyTimeout = 120;
+ bulkCopy.EnableStreaming = true;
+ bulkCopy.ColumnOrderHints.Add("CollectionId", SortOrder.Ascending);
+ bulkCopy.ColumnOrderHints.Add("OrganizationUserId", SortOrder.Ascending);
+
+ var dataTable = BuildCollectionsUsersTable(bulkCopy, sortedCollectionUsers, errorMessage);
await bulkCopy.WriteToServerAsync(dataTable);
}
@@ -96,11 +111,21 @@ public static class BulkResourceCreationService
return table;
}
- public static async Task CreateCollectionsAsync(SqlConnection connection, SqlTransaction transaction, IEnumerable collections, string errorMessage = _defaultErrorMessage)
+ public static async Task CreateCollectionsAsync(SqlConnection connection, SqlTransaction transaction,
+ IEnumerable collections, string errorMessage = _defaultErrorMessage)
{
+ // Offload some work from SQL Server by pre-sorting before insert.
+ // This lets us use the SqlBulkCopy.ColumnOrderHints to improve performance and reduce deadlocks.
+ var sortedCollections = collections.OrderBySqlGuid(c => c.Id).ToList();
+
using var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction);
bulkCopy.DestinationTableName = "[dbo].[Collection]";
- var dataTable = BuildCollectionsTable(bulkCopy, collections, errorMessage);
+ bulkCopy.BatchSize = 500;
+ bulkCopy.BulkCopyTimeout = 120;
+ bulkCopy.EnableStreaming = true;
+ bulkCopy.ColumnOrderHints.Add("Id", SortOrder.Ascending);
+
+ var dataTable = BuildCollectionsTable(bulkCopy, sortedCollections, errorMessage);
await bulkCopy.WriteToServerAsync(dataTable);
}
diff --git a/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs b/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs
index ad00ac7086..c2a59f75aa 100644
--- a/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs
+++ b/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs
@@ -6,6 +6,7 @@ using Bit.Core.Enums;
using Bit.Core.Models.Data;
using Bit.Core.Repositories;
using Bit.Core.Settings;
+using Bit.Core.Utilities;
using Bit.Infrastructure.Dapper.AdminConsole.Helpers;
using Dapper;
using Microsoft.Data.SqlClient;
@@ -326,9 +327,10 @@ public class CollectionRepository : Repository, ICollectionRep
}
}
- public async Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable affectedOrgUserIds, string defaultCollectionName)
+ public async Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName)
{
- if (!affectedOrgUserIds.Any())
+ organizationUserIds = organizationUserIds.ToList();
+ if (!organizationUserIds.Any())
{
return;
}
@@ -340,7 +342,7 @@ public class CollectionRepository : Repository, ICollectionRep
{
var orgUserIdWithDefaultCollection = await GetOrgUserIdsWithDefaultCollectionAsync(connection, transaction, organizationId);
- var missingDefaultCollectionUserIds = affectedOrgUserIds.Except(orgUserIdWithDefaultCollection);
+ var missingDefaultCollectionUserIds = organizationUserIds.Except(orgUserIdWithDefaultCollection);
var (collectionUsers, collections) = BuildDefaultCollectionForUsers(organizationId, missingDefaultCollectionUserIds, defaultCollectionName);
@@ -393,7 +395,7 @@ public class CollectionRepository : Repository, ICollectionRep
foreach (var orgUserId in missingDefaultCollectionUserIds)
{
- var collectionId = Guid.NewGuid();
+ var collectionId = CoreHelpers.GenerateComb();
collections.Add(new Collection
{
diff --git a/src/Infrastructure.Dapper/Utilities/SqlGuidHelpers.cs b/src/Infrastructure.Dapper/Utilities/SqlGuidHelpers.cs
new file mode 100644
index 0000000000..fc548e2ff0
--- /dev/null
+++ b/src/Infrastructure.Dapper/Utilities/SqlGuidHelpers.cs
@@ -0,0 +1,26 @@
+using System.Data.SqlTypes;
+
+namespace Bit.Infrastructure.Dapper.Utilities;
+
+public static class SqlGuidHelpers
+{
+ ///
+ /// Sorts the source IEnumerable by the specified Guid property using the comparison logic.
+ /// This is required because MSSQL server compares (and therefore sorts) Guids differently to C#.
+ /// Ref: https://learn.microsoft.com/en-us/sql/connect/ado-net/sql/compare-guid-uniqueidentifier-values
+ ///
+ public static IOrderedEnumerable OrderBySqlGuid(
+ this IEnumerable source,
+ Func keySelector)
+ {
+ return source.OrderBy(x => new SqlGuid(keySelector(x)));
+ }
+
+ ///
+ public static IOrderedEnumerable ThenBySqlGuid(
+ this IOrderedEnumerable source,
+ Func keySelector)
+ {
+ return source.ThenBy(x => new SqlGuid(keySelector(x)));
+ }
+}
diff --git a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs
index 021b5bcf16..5aa156d1f8 100644
--- a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs
+++ b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs
@@ -2,6 +2,7 @@
using Bit.Core.Enums;
using Bit.Core.Models.Data;
using Bit.Core.Repositories;
+using Bit.Core.Utilities;
using Bit.Infrastructure.EntityFramework.Models;
using Bit.Infrastructure.EntityFramework.Repositories.Queries;
using LinqToDB.EntityFrameworkCore;
@@ -793,9 +794,10 @@ public class CollectionRepository : Repository affectedOrgUserIds, string defaultCollectionName)
+ public async Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName)
{
- if (!affectedOrgUserIds.Any())
+ organizationUserIds = organizationUserIds.ToList();
+ if (!organizationUserIds.Any())
{
return;
}
@@ -804,8 +806,7 @@ public class CollectionRepository : Repository().IsEnabled(FeatureFlagKeys.CreateDefaultLocation).Returns(true);
+ var policyDetails = new PolicyDetails
+ {
+ OrganizationId = organization.Id,
+ OrganizationUserId = orgUser.Id,
+ IsProvider = false,
+ OrganizationUserStatus = orgUser.Status,
+ OrganizationUserType = orgUser.Type,
+ PolicyType = PolicyType.OrganizationDataOwnership
+ };
sutProvider.GetDependency()
- .GetManyByOrganizationIdAsync(organization.Id)
- .Returns(new List { orgUser.Id });
+ .GetAsync(orgUser.UserId!.Value)
+ .Returns(new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Enabled, [policyDetails]));
await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, collectionName);
await sutProvider.GetDependency()
.Received(1)
- .UpsertDefaultCollectionsAsync(
- organization.Id,
- Arg.Is>(ids => ids.Contains(orgUser.Id)),
- collectionName);
+ .CreateAsync(
+ Arg.Is(c =>
+ c.Name == collectionName &&
+ c.OrganizationId == organization.Id &&
+ c.Type == CollectionType.DefaultUserCollection),
+ Arg.Any>(),
+ Arg.Is>(cu =>
+ cu.Single().Id == orgUser.Id &&
+ cu.Single().Manage));
}
[Theory, BitAutoData]
@@ -511,7 +526,7 @@ public class ConfirmOrganizationUserCommandTests
[Theory, BitAutoData]
public async Task ConfirmUserAsync_WithCreateDefaultLocationEnabled_WithOrganizationDataOwnershipPolicyNotApplicable_DoesNotCreateDefaultCollection(
Organization org, OrganizationUser confirmingUser,
- [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user,
+ [OrganizationUser(OrganizationUserStatusType.Accepted, OrganizationUserType.Owner)] OrganizationUser orgUser, User user,
string key, string collectionName, SutProvider sutProvider)
{
org.PlanType = PlanType.EnterpriseAnnually;
@@ -523,9 +538,18 @@ public class ConfirmOrganizationUserCommandTests
sutProvider.GetDependency().GetManyAsync(default).ReturnsForAnyArgs(new[] { user });
sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.CreateDefaultLocation).Returns(true);
+ var policyDetails = new PolicyDetails
+ {
+ OrganizationId = org.Id,
+ OrganizationUserId = orgUser.Id,
+ IsProvider = false,
+ OrganizationUserStatus = orgUser.Status,
+ OrganizationUserType = orgUser.Type,
+ PolicyType = PolicyType.OrganizationDataOwnership
+ };
sutProvider.GetDependency()
- .GetManyByOrganizationIdAsync(org.Id)
- .Returns(new List { orgUser.UserId!.Value });
+ .GetAsync(orgUser.UserId!.Value)
+ .Returns(new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Disabled, [policyDetails]));
await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, collectionName);