mirror of
https://github.com/bitwarden/server
synced 2025-12-18 17:23:28 +00:00
[PM-25138] Reduce db locking when creating default collections (#6308)
* Use single method for default collection creation * Use GenerateComb to create sequential guids * Pre-sort data for SqlBulkCopy * Add SqlBulkCopy options per dbops recommendations
This commit is contained in:
@@ -11,6 +11,7 @@ using Bit.Core.Billing.Enums;
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.Exceptions;
|
||||
using Bit.Core.Models.Data;
|
||||
using Bit.Core.Platform.Push;
|
||||
using Bit.Core.Repositories;
|
||||
using Bit.Core.Services;
|
||||
@@ -82,7 +83,7 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand
|
||||
throw new BadRequestException(error);
|
||||
}
|
||||
|
||||
await HandleConfirmationSideEffectsAsync(organizationId, confirmedOrganizationUsers: [orgUser], defaultUserCollectionName);
|
||||
await CreateDefaultCollectionAsync(orgUser, defaultUserCollectionName);
|
||||
|
||||
return orgUser;
|
||||
}
|
||||
@@ -97,9 +98,13 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand
|
||||
.Select(r => r.Item1)
|
||||
.ToList();
|
||||
|
||||
if (confirmedOrganizationUsers.Count > 0)
|
||||
if (confirmedOrganizationUsers.Count == 1)
|
||||
{
|
||||
await HandleConfirmationSideEffectsAsync(organizationId, confirmedOrganizationUsers, defaultUserCollectionName);
|
||||
await CreateDefaultCollectionAsync(confirmedOrganizationUsers.Single(), defaultUserCollectionName);
|
||||
}
|
||||
else if (confirmedOrganizationUsers.Count > 1)
|
||||
{
|
||||
await CreateManyDefaultCollectionsAsync(organizationId, confirmedOrganizationUsers, defaultUserCollectionName);
|
||||
}
|
||||
|
||||
return result;
|
||||
@@ -245,14 +250,54 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Handles the side effects of confirming an organization user.
|
||||
/// Creates a default collection for the user if the organization
|
||||
/// has the OrganizationDataOwnership policy enabled.
|
||||
/// Creates a default collection for a single user if required by the Organization Data Ownership policy.
|
||||
/// </summary>
|
||||
/// <param name="organizationUser">The organization user who has just been confirmed.</param>
|
||||
/// <param name="defaultUserCollectionName">The encrypted default user collection name.</param>
|
||||
private async Task CreateDefaultCollectionAsync(OrganizationUser organizationUser, string defaultUserCollectionName)
|
||||
{
|
||||
if (!_featureService.IsEnabled(FeatureFlagKeys.CreateDefaultLocation))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// Skip if no collection name provided (backwards compatibility)
|
||||
if (string.IsNullOrWhiteSpace(defaultUserCollectionName))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
var organizationDataOwnershipPolicy =
|
||||
await _policyRequirementQuery.GetAsync<OrganizationDataOwnershipPolicyRequirement>(organizationUser.UserId!.Value);
|
||||
if (!organizationDataOwnershipPolicy.RequiresDefaultCollectionOnConfirm(organizationUser.OrganizationId))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
var defaultCollection = new Collection
|
||||
{
|
||||
OrganizationId = organizationUser.OrganizationId,
|
||||
Name = defaultUserCollectionName,
|
||||
Type = CollectionType.DefaultUserCollection
|
||||
};
|
||||
var collectionUser = new CollectionAccessSelection
|
||||
{
|
||||
Id = organizationUser.Id,
|
||||
ReadOnly = false,
|
||||
HidePasswords = false,
|
||||
Manage = true
|
||||
};
|
||||
|
||||
await _collectionRepository.CreateAsync(defaultCollection, groups: null, users: [collectionUser]);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Creates default collections for multiple users if required by the Organization Data Ownership policy.
|
||||
/// </summary>
|
||||
/// <param name="organizationId">The organization ID.</param>
|
||||
/// <param name="confirmedOrganizationUsers">The confirmed organization users.</param>
|
||||
/// <param name="defaultUserCollectionName">The encrypted default user collection name.</param>
|
||||
private async Task HandleConfirmationSideEffectsAsync(Guid organizationId,
|
||||
private async Task CreateManyDefaultCollectionsAsync(Guid organizationId,
|
||||
IEnumerable<OrganizationUser> confirmedOrganizationUsers, string defaultUserCollectionName)
|
||||
{
|
||||
if (!_featureService.IsEnabled(FeatureFlagKeys.CreateDefaultLocation))
|
||||
@@ -266,7 +311,8 @@ public class ConfirmOrganizationUserCommand : IConfirmOrganizationUserCommand
|
||||
return;
|
||||
}
|
||||
|
||||
var policyEligibleOrganizationUserIds = await _policyRequirementQuery.GetManyByOrganizationIdAsync<OrganizationDataOwnershipPolicyRequirement>(organizationId);
|
||||
var policyEligibleOrganizationUserIds =
|
||||
await _policyRequirementQuery.GetManyByOrganizationIdAsync<OrganizationDataOwnershipPolicyRequirement>(organizationId);
|
||||
|
||||
var eligibleOrganizationUserIds = confirmedOrganizationUsers
|
||||
.Where(ou => policyEligibleOrganizationUserIds.Contains(ou.Id))
|
||||
|
||||
@@ -67,6 +67,11 @@ public class OrganizationDataOwnershipPolicyRequirement : IPolicyRequirement
|
||||
var noCollectionNeeded = new DefaultCollectionRequest(Guid.Empty, false);
|
||||
return noCollectionNeeded;
|
||||
}
|
||||
|
||||
public bool RequiresDefaultCollectionOnConfirm(Guid organizationId)
|
||||
{
|
||||
return _policyDetails.Any(p => p.OrganizationId == organizationId);
|
||||
}
|
||||
}
|
||||
|
||||
public record DefaultCollectionRequest(Guid OrganizationUserId, bool ShouldCreateDefaultCollection)
|
||||
|
||||
@@ -63,5 +63,12 @@ public interface ICollectionRepository : IRepository<Collection, Guid>
|
||||
Task CreateOrUpdateAccessForManyAsync(Guid organizationId, IEnumerable<Guid> collectionIds,
|
||||
IEnumerable<CollectionAccessSelection> users, IEnumerable<CollectionAccessSelection> groups);
|
||||
|
||||
Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable<Guid> affectedOrgUserIds, string defaultCollectionName);
|
||||
/// <summary>
|
||||
/// Creates default user collections for the specified organization users if they do not already have one.
|
||||
/// </summary>
|
||||
/// <param name="organizationId">The Organization ID.</param>
|
||||
/// <param name="organizationUserIds">The Organization User IDs to create default collections for.</param>
|
||||
/// <param name="defaultCollectionName">The encrypted string to use as the default collection name.</param>
|
||||
/// <returns></returns>
|
||||
Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable<Guid> organizationUserIds, string defaultCollectionName);
|
||||
}
|
||||
|
||||
@@ -41,9 +41,12 @@ public static class CoreHelpers
|
||||
};
|
||||
|
||||
/// <summary>
|
||||
/// Generate sequential Guid for Sql Server.
|
||||
/// ref: https://github.com/nhibernate/nhibernate-core/blob/master/src/NHibernate/Id/GuidCombGenerator.cs
|
||||
/// Generate a sequential Guid for Sql Server. This prevents SQL Server index fragmentation by incorporating timestamp
|
||||
/// information for sequential ordering. This should be preferred to <see cref="Guid.NewGuid"/> for any database IDs.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// ref: https://github.com/nhibernate/nhibernate-core/blob/master/src/NHibernate/Id/GuidCombGenerator.cs
|
||||
/// </remarks>
|
||||
/// <returns>A comb Guid.</returns>
|
||||
public static Guid GenerateComb()
|
||||
=> GenerateComb(Guid.NewGuid(), DateTime.UtcNow);
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
using System.Data;
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Vault.Entities;
|
||||
using Bit.Infrastructure.Dapper.Utilities;
|
||||
using Microsoft.Data.SqlClient;
|
||||
|
||||
namespace Bit.Infrastructure.Dapper.AdminConsole.Helpers;
|
||||
@@ -8,11 +9,25 @@ namespace Bit.Infrastructure.Dapper.AdminConsole.Helpers;
|
||||
public static class BulkResourceCreationService
|
||||
{
|
||||
private const string _defaultErrorMessage = "Must have at least one record for bulk creation.";
|
||||
public static async Task CreateCollectionsUsersAsync(SqlConnection connection, SqlTransaction transaction, IEnumerable<CollectionUser> collectionUsers, string errorMessage = _defaultErrorMessage)
|
||||
public static async Task CreateCollectionsUsersAsync(SqlConnection connection, SqlTransaction transaction,
|
||||
IEnumerable<CollectionUser> collectionUsers, string errorMessage = _defaultErrorMessage)
|
||||
{
|
||||
// Offload some work from SQL Server by pre-sorting before insert.
|
||||
// This lets us use the SqlBulkCopy.ColumnOrderHints to improve performance and reduce deadlocks.
|
||||
var sortedCollectionUsers = collectionUsers
|
||||
.OrderBySqlGuid(cu => cu.CollectionId)
|
||||
.ThenBySqlGuid(cu => cu.OrganizationUserId)
|
||||
.ToList();
|
||||
|
||||
using var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction);
|
||||
bulkCopy.DestinationTableName = "[dbo].[CollectionUser]";
|
||||
var dataTable = BuildCollectionsUsersTable(bulkCopy, collectionUsers, errorMessage);
|
||||
bulkCopy.BatchSize = 500;
|
||||
bulkCopy.BulkCopyTimeout = 120;
|
||||
bulkCopy.EnableStreaming = true;
|
||||
bulkCopy.ColumnOrderHints.Add("CollectionId", SortOrder.Ascending);
|
||||
bulkCopy.ColumnOrderHints.Add("OrganizationUserId", SortOrder.Ascending);
|
||||
|
||||
var dataTable = BuildCollectionsUsersTable(bulkCopy, sortedCollectionUsers, errorMessage);
|
||||
await bulkCopy.WriteToServerAsync(dataTable);
|
||||
}
|
||||
|
||||
@@ -96,11 +111,21 @@ public static class BulkResourceCreationService
|
||||
return table;
|
||||
}
|
||||
|
||||
public static async Task CreateCollectionsAsync(SqlConnection connection, SqlTransaction transaction, IEnumerable<Collection> collections, string errorMessage = _defaultErrorMessage)
|
||||
public static async Task CreateCollectionsAsync(SqlConnection connection, SqlTransaction transaction,
|
||||
IEnumerable<Collection> collections, string errorMessage = _defaultErrorMessage)
|
||||
{
|
||||
// Offload some work from SQL Server by pre-sorting before insert.
|
||||
// This lets us use the SqlBulkCopy.ColumnOrderHints to improve performance and reduce deadlocks.
|
||||
var sortedCollections = collections.OrderBySqlGuid(c => c.Id).ToList();
|
||||
|
||||
using var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction);
|
||||
bulkCopy.DestinationTableName = "[dbo].[Collection]";
|
||||
var dataTable = BuildCollectionsTable(bulkCopy, collections, errorMessage);
|
||||
bulkCopy.BatchSize = 500;
|
||||
bulkCopy.BulkCopyTimeout = 120;
|
||||
bulkCopy.EnableStreaming = true;
|
||||
bulkCopy.ColumnOrderHints.Add("Id", SortOrder.Ascending);
|
||||
|
||||
var dataTable = BuildCollectionsTable(bulkCopy, sortedCollections, errorMessage);
|
||||
await bulkCopy.WriteToServerAsync(dataTable);
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ using Bit.Core.Enums;
|
||||
using Bit.Core.Models.Data;
|
||||
using Bit.Core.Repositories;
|
||||
using Bit.Core.Settings;
|
||||
using Bit.Core.Utilities;
|
||||
using Bit.Infrastructure.Dapper.AdminConsole.Helpers;
|
||||
using Dapper;
|
||||
using Microsoft.Data.SqlClient;
|
||||
@@ -326,9 +327,10 @@ public class CollectionRepository : Repository<Collection, Guid>, ICollectionRep
|
||||
}
|
||||
}
|
||||
|
||||
public async Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable<Guid> affectedOrgUserIds, string defaultCollectionName)
|
||||
public async Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable<Guid> organizationUserIds, string defaultCollectionName)
|
||||
{
|
||||
if (!affectedOrgUserIds.Any())
|
||||
organizationUserIds = organizationUserIds.ToList();
|
||||
if (!organizationUserIds.Any())
|
||||
{
|
||||
return;
|
||||
}
|
||||
@@ -340,7 +342,7 @@ public class CollectionRepository : Repository<Collection, Guid>, ICollectionRep
|
||||
{
|
||||
var orgUserIdWithDefaultCollection = await GetOrgUserIdsWithDefaultCollectionAsync(connection, transaction, organizationId);
|
||||
|
||||
var missingDefaultCollectionUserIds = affectedOrgUserIds.Except(orgUserIdWithDefaultCollection);
|
||||
var missingDefaultCollectionUserIds = organizationUserIds.Except(orgUserIdWithDefaultCollection);
|
||||
|
||||
var (collectionUsers, collections) = BuildDefaultCollectionForUsers(organizationId, missingDefaultCollectionUserIds, defaultCollectionName);
|
||||
|
||||
@@ -393,7 +395,7 @@ public class CollectionRepository : Repository<Collection, Guid>, ICollectionRep
|
||||
|
||||
foreach (var orgUserId in missingDefaultCollectionUserIds)
|
||||
{
|
||||
var collectionId = Guid.NewGuid();
|
||||
var collectionId = CoreHelpers.GenerateComb();
|
||||
|
||||
collections.Add(new Collection
|
||||
{
|
||||
|
||||
26
src/Infrastructure.Dapper/Utilities/SqlGuidHelpers.cs
Normal file
26
src/Infrastructure.Dapper/Utilities/SqlGuidHelpers.cs
Normal file
@@ -0,0 +1,26 @@
|
||||
using System.Data.SqlTypes;
|
||||
|
||||
namespace Bit.Infrastructure.Dapper.Utilities;
|
||||
|
||||
public static class SqlGuidHelpers
|
||||
{
|
||||
/// <summary>
|
||||
/// Sorts the source IEnumerable by the specified Guid property using the <see cref="SqlGuid"/> comparison logic.
|
||||
/// This is required because MSSQL server compares (and therefore sorts) Guids differently to C#.
|
||||
/// Ref: https://learn.microsoft.com/en-us/sql/connect/ado-net/sql/compare-guid-uniqueidentifier-values
|
||||
/// </summary>
|
||||
public static IOrderedEnumerable<T> OrderBySqlGuid<T>(
|
||||
this IEnumerable<T> source,
|
||||
Func<T, Guid> keySelector)
|
||||
{
|
||||
return source.OrderBy(x => new SqlGuid(keySelector(x)));
|
||||
}
|
||||
|
||||
/// <inheritdoc cref="OrderBySqlGuid"/>
|
||||
public static IOrderedEnumerable<T> ThenBySqlGuid<T>(
|
||||
this IOrderedEnumerable<T> source,
|
||||
Func<T, Guid> keySelector)
|
||||
{
|
||||
return source.ThenBy(x => new SqlGuid(keySelector(x)));
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.Models.Data;
|
||||
using Bit.Core.Repositories;
|
||||
using Bit.Core.Utilities;
|
||||
using Bit.Infrastructure.EntityFramework.Models;
|
||||
using Bit.Infrastructure.EntityFramework.Repositories.Queries;
|
||||
using LinqToDB.EntityFrameworkCore;
|
||||
@@ -793,9 +794,10 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
|
||||
// SaveChangesAsync is expected to be called outside this method
|
||||
}
|
||||
|
||||
public async Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable<Guid> affectedOrgUserIds, string defaultCollectionName)
|
||||
public async Task UpsertDefaultCollectionsAsync(Guid organizationId, IEnumerable<Guid> organizationUserIds, string defaultCollectionName)
|
||||
{
|
||||
if (!affectedOrgUserIds.Any())
|
||||
organizationUserIds = organizationUserIds.ToList();
|
||||
if (!organizationUserIds.Any())
|
||||
{
|
||||
return;
|
||||
}
|
||||
@@ -804,8 +806,7 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
|
||||
var dbContext = GetDatabaseContext(scope);
|
||||
|
||||
var orgUserIdWithDefaultCollection = await GetOrgUserIdsWithDefaultCollectionAsync(dbContext, organizationId);
|
||||
|
||||
var missingDefaultCollectionUserIds = affectedOrgUserIds.Except(orgUserIdWithDefaultCollection);
|
||||
var missingDefaultCollectionUserIds = organizationUserIds.Except(orgUserIdWithDefaultCollection);
|
||||
|
||||
var (collectionUsers, collections) = BuildDefaultCollectionForUsers(organizationId, missingDefaultCollectionUserIds, defaultCollectionName);
|
||||
|
||||
@@ -850,7 +851,7 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
|
||||
|
||||
foreach (var orgUserId in missingDefaultCollectionUserIds)
|
||||
{
|
||||
var collectionId = Guid.NewGuid();
|
||||
var collectionId = CoreHelpers.GenerateComb();
|
||||
|
||||
collections.Add(new Collection
|
||||
{
|
||||
|
||||
@@ -10,6 +10,7 @@ using Bit.Core.Billing.Enums;
|
||||
using Bit.Core.Entities;
|
||||
using Bit.Core.Enums;
|
||||
using Bit.Core.Exceptions;
|
||||
using Bit.Core.Models.Data;
|
||||
using Bit.Core.Models.Data.Organizations.OrganizationUsers;
|
||||
using Bit.Core.Platform.Push;
|
||||
using Bit.Core.Repositories;
|
||||
@@ -471,18 +472,32 @@ public class ConfirmOrganizationUserCommandTests
|
||||
|
||||
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.CreateDefaultLocation).Returns(true);
|
||||
|
||||
var policyDetails = new PolicyDetails
|
||||
{
|
||||
OrganizationId = organization.Id,
|
||||
OrganizationUserId = orgUser.Id,
|
||||
IsProvider = false,
|
||||
OrganizationUserStatus = orgUser.Status,
|
||||
OrganizationUserType = orgUser.Type,
|
||||
PolicyType = PolicyType.OrganizationDataOwnership
|
||||
};
|
||||
sutProvider.GetDependency<IPolicyRequirementQuery>()
|
||||
.GetManyByOrganizationIdAsync<OrganizationDataOwnershipPolicyRequirement>(organization.Id)
|
||||
.Returns(new List<Guid> { orgUser.Id });
|
||||
.GetAsync<OrganizationDataOwnershipPolicyRequirement>(orgUser.UserId!.Value)
|
||||
.Returns(new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Enabled, [policyDetails]));
|
||||
|
||||
await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, collectionName);
|
||||
|
||||
await sutProvider.GetDependency<ICollectionRepository>()
|
||||
.Received(1)
|
||||
.UpsertDefaultCollectionsAsync(
|
||||
organization.Id,
|
||||
Arg.Is<IEnumerable<Guid>>(ids => ids.Contains(orgUser.Id)),
|
||||
collectionName);
|
||||
.CreateAsync(
|
||||
Arg.Is<Collection>(c =>
|
||||
c.Name == collectionName &&
|
||||
c.OrganizationId == organization.Id &&
|
||||
c.Type == CollectionType.DefaultUserCollection),
|
||||
Arg.Any<IEnumerable<CollectionAccessSelection>>(),
|
||||
Arg.Is<IEnumerable<CollectionAccessSelection>>(cu =>
|
||||
cu.Single().Id == orgUser.Id &&
|
||||
cu.Single().Manage));
|
||||
}
|
||||
|
||||
[Theory, BitAutoData]
|
||||
@@ -511,7 +526,7 @@ public class ConfirmOrganizationUserCommandTests
|
||||
[Theory, BitAutoData]
|
||||
public async Task ConfirmUserAsync_WithCreateDefaultLocationEnabled_WithOrganizationDataOwnershipPolicyNotApplicable_DoesNotCreateDefaultCollection(
|
||||
Organization org, OrganizationUser confirmingUser,
|
||||
[OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser orgUser, User user,
|
||||
[OrganizationUser(OrganizationUserStatusType.Accepted, OrganizationUserType.Owner)] OrganizationUser orgUser, User user,
|
||||
string key, string collectionName, SutProvider<ConfirmOrganizationUserCommand> sutProvider)
|
||||
{
|
||||
org.PlanType = PlanType.EnterpriseAnnually;
|
||||
@@ -523,9 +538,18 @@ public class ConfirmOrganizationUserCommandTests
|
||||
sutProvider.GetDependency<IUserRepository>().GetManyAsync(default).ReturnsForAnyArgs(new[] { user });
|
||||
sutProvider.GetDependency<IFeatureService>().IsEnabled(FeatureFlagKeys.CreateDefaultLocation).Returns(true);
|
||||
|
||||
var policyDetails = new PolicyDetails
|
||||
{
|
||||
OrganizationId = org.Id,
|
||||
OrganizationUserId = orgUser.Id,
|
||||
IsProvider = false,
|
||||
OrganizationUserStatus = orgUser.Status,
|
||||
OrganizationUserType = orgUser.Type,
|
||||
PolicyType = PolicyType.OrganizationDataOwnership
|
||||
};
|
||||
sutProvider.GetDependency<IPolicyRequirementQuery>()
|
||||
.GetManyByOrganizationIdAsync<OrganizationDataOwnershipPolicyRequirement>(org.Id)
|
||||
.Returns(new List<Guid> { orgUser.UserId!.Value });
|
||||
.GetAsync<OrganizationDataOwnershipPolicyRequirement>(orgUser.UserId!.Value)
|
||||
.Returns(new OrganizationDataOwnershipPolicyRequirement(OrganizationDataOwnershipState.Disabled, [policyDetails]));
|
||||
|
||||
await sutProvider.Sut.ConfirmUserAsync(orgUser.OrganizationId, orgUser.Id, key, confirmingUser.Id, collectionName);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user