diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs index 83ec244c47..2fbe6be5c6 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/ConfirmOrganizationUserCommand.cs @@ -11,6 +11,7 @@ using Bit.Core.Billing.Enums; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; +using Bit.Core.Models.Data; using Bit.Core.Platform.Push; using Bit.Core.Repositories; using Bit.Core.Services; @@ -82,7 +83,7 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand throw new BadRequestException(error); } - await HandleConfirmationSideEffectsAsync(organizationId, confirmedOrganizationUsers: [orgUser], defaultUserCollectionName); + await CreateDefaultCollectionAsync(orgUser, defaultUserCollectionName); return orgUser; } @@ -97,9 +98,13 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand .Select(r => r.Item1) .ToList(); - if (confirmedOrganizationUsers.Count > 0) + if (confirmedOrganizationUsers.Count == 1) { - await HandleConfirmationSideEffectsAsync(organizationId, confirmedOrganizationUsers, defaultUserCollectionName); + await CreateDefaultCollectionAsync(confirmedOrganizationUsers.Single(), defaultUserCollectionName); + } + else if (confirmedOrganizationUsers.Count > 1) + { + await CreateManyDefaultCollectionsAsync(organizationId, confirmedOrganizationUsers, defaultUserCollectionName); } return result; @@ -245,14 +250,54 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand } /// - /// Handles the side effects of confirming an organization user. - /// Creates a default collection for the user if the organization - /// has the OrganizationDataOwnership policy enabled. + /// Creates a default collection for a single user if required by the Organization Data Ownership policy. + /// + /// The organization user who has just been confirmed. + /// The encrypted default user collection name. + private async Task CreateDefaultCollectionAsync(OrganizationUser organizationUser, string defaultUserCollectionName) + { + if (!_featureService.IsEnabled(FeatureFlagKeys.CreateDefaultLocation)) + { + return; + } + + // Skip if no collection name provided (backwards compatibility) + if (string.IsNullOrWhiteSpace(defaultUserCollectionName)) + { + return; + } + + var organizationDataOwnershipPolicy = + await _policyRequirementQuery.GetAsync(organizationUser.UserId!.Value); + if (!organizationDataOwnershipPolicy.RequiresDefaultCollectionOnConfirm(organizationUser.OrganizationId)) + { + return; + } + + var defaultCollection = new Collection + { + OrganizationId = organizationUser.OrganizationId, + Name = defaultUserCollectionName, + Type = CollectionType.DefaultUserCollection + }; + var collectionUser = new CollectionAccessSelection + { + Id = organizationUser.Id, + ReadOnly = false, + HidePasswords = false, + Manage = true + }; + + await _collectionRepository.CreateAsync(defaultCollection, groups: null, users: [collectionUser]); + } + + /// + /// Creates default collections for multiple users if required by the Organization Data Ownership policy. /// /// The organization ID. /// The confirmed organization users. /// The encrypted default user collection name. - private async Task HandleConfirmationSideEffectsAsync(Guid organizationId, + private async Task CreateManyDefaultCollectionsAsync(Guid organizationId, IEnumerable confirmedOrganizationUsers, string defaultUserCollectionName) { if (!_featureService.IsEnabled(FeatureFlagKeys.CreateDefaultLocation)) @@ -266,7 +311,8 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand return; } - var policyEligibleOrganizationUserIds = await _policyRequirementQuery.GetManyByOrganizationIdAsync(organizationId); + var policyEligibleOrganizationUserIds = + await _policyRequirementQuery.GetManyByOrganizationIdAsync(organizationId); var eligibleOrganizationUserIds = confirmedOrganizationUsers .Where(ou => policyEligibleOrganizationUserIds.Contains(ou.Id)) diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/OrganizationDataOwnershipPolicyRequirement.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/OrganizationDataOwnershipPolicyRequirement.cs index cb72a51850..28d6614dcb 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/OrganizationDataOwnershipPolicyRequirement.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyRequirements/OrganizationDataOwnershipPolicyRequirement.cs @@ -67,6 +67,11 @@ public class OrganizationDataOwnershipPolicyRequirement : IPolicyRequirement var noCollectionNeeded = new DefaultCollectionRequest(Guid.Empty, false); return noCollectionNeeded; } + + public bool RequiresDefaultCollectionOnConfirm(Guid organizationId) + { + return _policyDetails.Any(p => p.OrganizationId == organizationId); + } } public record DefaultCollectionRequest(Guid OrganizationUserId, bool ShouldCreateDefaultCollection) diff --git a/src/Core/Repositories/ICollectionRepository.cs b/src/Core/Repositories/ICollectionRepository.cs index ca3e52751c..f86147ca7d 100644 --- a/src/Core/Repositories/ICollectionRepository.cs +++ b/src/Core/Repositories/ICollectionRepository.cs @@ -63,5 +63,12 @@ public interface ICollectionRepository : IRepository Task CreateOrUpdateAccessForManyAsync(Guid organizationId, IEnumerable collectionIds, IEnumerable users, IEnumerable groups); - Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable affectedOrgUserIds, string defaultCollectionName); + /// + /// Creates default user collections for the specified organization users if they do not already have one. + /// + /// The Organization ID. + /// The Organization User IDs to create default collections for. + /// The encrypted string to use as the default collection name. + /// + Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName); } diff --git a/src/Core/Utilities/CoreHelpers.cs b/src/Core/Utilities/CoreHelpers.cs index 813eb6d1aa..5acdc63489 100644 --- a/src/Core/Utilities/CoreHelpers.cs +++ b/src/Core/Utilities/CoreHelpers.cs @@ -41,9 +41,12 @@ public static class CoreHelpers }; /// - /// Generate sequential Guid for Sql Server. - /// ref: https://github.com/nhibernate/nhibernate-core/blob/master/src/NHibernate/Id/GuidCombGenerator.cs + /// Generate a sequential Guid for Sql Server. This prevents SQL Server index fragmentation by incorporating timestamp + /// information for sequential ordering. This should be preferred to for any database IDs. /// + /// + /// ref: https://github.com/nhibernate/nhibernate-core/blob/master/src/NHibernate/Id/GuidCombGenerator.cs + /// /// A comb Guid. public static Guid GenerateComb() => GenerateComb(Guid.NewGuid(), DateTime.UtcNow); diff --git a/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs b/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs index 3610c1c484..5a743ba028 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs +++ b/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs @@ -1,6 +1,7 @@ using System.Data; using Bit.Core.Entities; using Bit.Core.Vault.Entities; +using Bit.Infrastructure.Dapper.Utilities; using Microsoft.Data.SqlClient; namespace Bit.Infrastructure.Dapper.AdminConsole.Helpers; @@ -8,11 +9,25 @@ namespace Bit.Infrastructure.Dapper.AdminConsole.Helpers; public static class BulkResourceCreationService { private const string _defaultErrorMessage = "Must have at least one record for bulk creation."; - public static async Task CreateCollectionsUsersAsync(SqlConnection connection, SqlTransaction transaction, IEnumerable collectionUsers, string errorMessage = _defaultErrorMessage) + public static async Task CreateCollectionsUsersAsync(SqlConnection connection, SqlTransaction transaction, + IEnumerable collectionUsers, string errorMessage = _defaultErrorMessage) { + // Offload some work from SQL Server by pre-sorting before insert. + // This lets us use the SqlBulkCopy.ColumnOrderHints to improve performance and reduce deadlocks. + var sortedCollectionUsers = collectionUsers + .OrderBySqlGuid(cu => cu.CollectionId) + .ThenBySqlGuid(cu => cu.OrganizationUserId) + .ToList(); + using var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction); bulkCopy.DestinationTableName = "[dbo].[CollectionUser]"; - var dataTable = BuildCollectionsUsersTable(bulkCopy, collectionUsers, errorMessage); + bulkCopy.BatchSize = 500; + bulkCopy.BulkCopyTimeout = 120; + bulkCopy.EnableStreaming = true; + bulkCopy.ColumnOrderHints.Add("CollectionId", SortOrder.Ascending); + bulkCopy.ColumnOrderHints.Add("OrganizationUserId", SortOrder.Ascending); + + var dataTable = BuildCollectionsUsersTable(bulkCopy, sortedCollectionUsers, errorMessage); await bulkCopy.WriteToServerAsync(dataTable); } @@ -96,11 +111,21 @@ public static class BulkResourceCreationService return table; } - public static async Task CreateCollectionsAsync(SqlConnection connection, SqlTransaction transaction, IEnumerable collections, string errorMessage = _defaultErrorMessage) + public static async Task CreateCollectionsAsync(SqlConnection connection, SqlTransaction transaction, + IEnumerable collections, string errorMessage = _defaultErrorMessage) { + // Offload some work from SQL Server by pre-sorting before insert. + // This lets us use the SqlBulkCopy.ColumnOrderHints to improve performance and reduce deadlocks. + var sortedCollections = collections.OrderBySqlGuid(c => c.Id).ToList(); + using var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction); bulkCopy.DestinationTableName = "[dbo].[Collection]"; - var dataTable = BuildCollectionsTable(bulkCopy, collections, errorMessage); + bulkCopy.BatchSize = 500; + bulkCopy.BulkCopyTimeout = 120; + bulkCopy.EnableStreaming = true; + bulkCopy.ColumnOrderHints.Add("Id", SortOrder.Ascending); + + var dataTable = BuildCollectionsTable(bulkCopy, sortedCollections, errorMessage); await bulkCopy.WriteToServerAsync(dataTable); } diff --git a/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs b/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs index ad00ac7086..c2a59f75aa 100644 --- a/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs @@ -6,6 +6,7 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; using Bit.Core.Repositories; using Bit.Core.Settings; +using Bit.Core.Utilities; using Bit.Infrastructure.Dapper.AdminConsole.Helpers; using Dapper; using Microsoft.Data.SqlClient; @@ -326,9 +327,10 @@ public class CollectionRepository : Repository, ICollectionRep } } - public async Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable affectedOrgUserIds, string defaultCollectionName) + public async Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName) { - if (!affectedOrgUserIds.Any()) + organizationUserIds = organizationUserIds.ToList(); + if (!organizationUserIds.Any()) { return; } @@ -340,7 +342,7 @@ public class CollectionRepository : Repository, ICollectionRep { var orgUserIdWithDefaultCollection = await GetOrgUserIdsWithDefaultCollectionAsync(connection, transaction, organizationId); - var missingDefaultCollectionUserIds = affectedOrgUserIds.Except(orgUserIdWithDefaultCollection); + var missingDefaultCollectionUserIds = organizationUserIds.Except(orgUserIdWithDefaultCollection); var (collectionUsers, collections) = BuildDefaultCollectionForUsers(organizationId, missingDefaultCollectionUserIds, defaultCollectionName); @@ -393,7 +395,7 @@ public class CollectionRepository : Repository, ICollectionRep foreach (var orgUserId in missingDefaultCollectionUserIds) { - var collectionId = Guid.NewGuid(); + var collectionId = CoreHelpers.GenerateComb(); collections.Add(new Collection { diff --git a/src/Infrastructure.Dapper/Utilities/SqlGuidHelpers.cs b/src/Infrastructure.Dapper/Utilities/SqlGuidHelpers.cs new file mode 100644 index 0000000000..fc548e2ff0 --- /dev/null +++ b/src/Infrastructure.Dapper/Utilities/SqlGuidHelpers.cs @@ -0,0 +1,26 @@ +using System.Data.SqlTypes; + +namespace Bit.Infrastructure.Dapper.Utilities; + +public static class SqlGuidHelpers +{ + /// + /// Sorts the source IEnumerable by the specified Guid property using the comparison logic. + /// This is required because MSSQL server compares (and therefore sorts) Guids differently to C#. + /// Ref: https://learn.microsoft.com/en-us/sql/connect/ado-net/sql/compare-guid-uniqueidentifier-values + /// + public static IOrderedEnumerable OrderBySqlGuid( + this IEnumerable source, + Func keySelector) + { + return source.OrderBy(x => new SqlGuid(keySelector(x))); + } + + /// + public static IOrderedEnumerable ThenBySqlGuid( + this IOrderedEnumerable source, + Func keySelector) + { + return source.ThenBy(x => new SqlGuid(keySelector(x))); + } +} diff --git a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs index 021b5bcf16..5aa156d1f8 100644 --- a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs @@ -2,6 +2,7 @@ using Bit.Core.Enums; using Bit.Core.Models.Data; using Bit.Core.Repositories; +using Bit.Core.Utilities; using Bit.Infrastructure.EntityFramework.Models; using Bit.Infrastructure.EntityFramework.Repositories.Queries; using LinqToDB.EntityFrameworkCore; @@ -793,9 +794,10 @@ public class CollectionRepository : Repository affectedOrgUserIds, string defaultCollectionName) + public async Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName) { - if (!affectedOrgUserIds.Any()) + organizationUserIds = organizationUserIds.ToList(); + if (!organizationUserIds.Any()) { return; } @@ -804,8 +806,7 @@ public class CollectionRepository : Repository().IsEnabled(FeatureFlagKeys.CreateDefaultLocation).Returns(true); + var policyDetails = new PolicyDetails + { + OrganizationId = organization.Id, + OrganizationUserId = orgUser.Id, + IsProvider = false, + OrganizationUserStatus = orgUser.Status, + OrganizationUserType = orgUser.Type, + PolicyType = PolicyType.OrganizationDataOwnership + }; sutProvider.GetDependency() - .GetManyByOrganizationIdAsync(organization.Id) - .Returns(new List { orgUser.Id }); + .GetAsync(orgUser.UserId!.Value) + .Returns(new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Enabled, [policyDetails])); await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, collectionName); await sutProvider.GetDependency() .Received(1) - .UpsertDefaultCollectionsAsync( - organization.Id, - Arg.Is>(ids => ids.Contains(orgUser.Id)), - collectionName); + .CreateAsync( + Arg.Is(c => + c.Name == collectionName && + c.OrganizationId == organization.Id && + c.Type == CollectionType.DefaultUserCollection), + Arg.Any>(), + Arg.Is>(cu => + cu.Single().Id == orgUser.Id && + cu.Single().Manage)); } [Theory, BitAutoData] @@ -511,7 +526,7 @@ public class ConfirmOrganizationUserCommandTests [Theory, BitAutoData] public async Task ConfirmUserAsync_WithCreateDefaultLocationEnabled_WithOrganizationDataOwnershipPolicyNotApplicable_DoesNotCreateDefaultCollection( Organization org, OrganizationUser confirmingUser, - [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user, + [OrganizationUser(OrganizationUserStatusType.Accepted, OrganizationUserType.Owner)] OrganizationUser orgUser, User user, string key, string collectionName, SutProvider sutProvider) { org.PlanType = PlanType.EnterpriseAnnually; @@ -523,9 +538,18 @@ public class ConfirmOrganizationUserCommandTests sutProvider.GetDependency().GetManyAsync(default).ReturnsForAnyArgs(new[] { user }); sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.CreateDefaultLocation).Returns(true); + var policyDetails = new PolicyDetails + { + OrganizationId = org.Id, + OrganizationUserId = orgUser.Id, + IsProvider = false, + OrganizationUserStatus = orgUser.Status, + OrganizationUserType = orgUser.Type, + PolicyType = PolicyType.OrganizationDataOwnership + }; sutProvider.GetDependency() - .GetManyByOrganizationIdAsync(org.Id) - .Returns(new List { orgUser.UserId!.Value }); + .GetAsync(orgUser.UserId!.Value) + .Returns(new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Disabled, [policyDetails])); await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, collectionName);