1
0
mirror of https://github.com/bitwarden/server synced 2026-02-18 10:23:27 +00:00

Move filtering out of db layer

This commit is contained in:
Thomas Rittson
2026-01-01 13:51:10 +10:00
parent a6fcee7094
commit 21d7276625
6 changed files with 196 additions and 51 deletions

View File

@@ -64,16 +64,26 @@ public class OrganizationDataOwnershipPolicyValidator(
var userOrgIds = requirements
.Select(requirement => requirement.GetDefaultCollectionRequestOnPolicyEnable(policyUpdate.OrganizationId))
.Where(request => request.ShouldCreateDefaultCollection)
.Select(request => request.OrganizationUserId);
.Select(request => request.OrganizationUserId)
.ToList();
if (!userOrgIds.Any())
{
return;
}
await collectionRepository.UpsertDefaultCollectionsBulkAsync(
// Filter out users who already have default collections
var existingSemaphores = await collectionRepository.GetDefaultCollectionSemaphoresAsync(userOrgIds);
var usersNeedingDefaultCollections = userOrgIds.Except(existingSemaphores).ToList();
if (!usersNeedingDefaultCollections.Any())
{
return;
}
await collectionRepository.CreateDefaultCollectionsBulkAsync(
policyUpdate.OrganizationId,
userOrgIds,
usersNeedingDefaultCollections,
defaultCollectionName);
}
}

View File

@@ -74,12 +74,17 @@ public interface ICollectionRepository : IRepository<Collection, Guid>
/// <summary>
/// Creates default user collections for the specified organization users using bulk insert operations.
/// Gracefully skips users who already have a default collection for the organization.
/// Use this if you need to create collections for > ~1k users.
/// Throws an exception if any user already has a default collection for the organization.
/// </summary>
/// <param name="organizationId">The Organization ID.</param>
/// <param name="organizationUserIds">The Organization User IDs to create default collections for.</param>
/// <param name="defaultCollectionName">The encrypted string to use as the default collection name.</param>
Task UpsertDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable<Guid> organizationUserIds, string defaultCollectionName);
/// <remarks>
/// If any of the OrganizationUsers may already have default collections, the caller should first filter out these
/// users using GetDefaultCollectionSemaphoresAsync before calling this method.
/// </remarks>
Task CreateDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable<Guid> organizationUserIds, string defaultCollectionName);
/// <summary>
/// Gets default collection semaphores for the given organizationUserIds.

View File

@@ -383,7 +383,7 @@ public class CollectionRepository : Repository<Collection, Guid>, ICollectionRep
commandType: CommandType.StoredProcedure);
}
public async Task UpsertDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable<Guid> organizationUserIds, string defaultCollectionName)
public async Task CreateDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable<Guid> organizationUserIds, string defaultCollectionName)
{
organizationUserIds = organizationUserIds.ToList();
if (!organizationUserIds.Any())
@@ -391,8 +391,7 @@ public class CollectionRepository : Repository<Collection, Guid>, ICollectionRep
return;
}
var orgUserIdWithDefaultCollection = await GetDefaultCollectionSemaphoresAsync(organizationUserIds);
var missingDefaultCollectionUserIds = organizationUserIds.Except(orgUserIdWithDefaultCollection);
var (collectionUsers, collections) = BuildDefaultCollectionForUsers(organizationId, organizationUserIds, defaultCollectionName);
await using var connection = new SqlConnection(ConnectionString);
connection.Open();
@@ -400,15 +399,9 @@ public class CollectionRepository : Repository<Collection, Guid>, ICollectionRep
try
{
var (collectionUsers, collections) = BuildDefaultCollectionForUsers(organizationId, missingDefaultCollectionUserIds, defaultCollectionName);
if (!collectionUsers.Any() || !collections.Any())
{
return;
}
// CRITICAL: Insert semaphore entries BEFORE collections
// TODO: this will result in a creation date of the semaphore AFTER that of the collection, which is weird
// Database will throw on duplicate primary key (OrganizationUserId)
var now = DateTime.UtcNow;
var semaphores = collectionUsers.Select(c => new DefaultCollectionSemaphore
{

View File

@@ -803,21 +803,13 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
return;
}
var orgUserIdWithDefaultCollection = await GetDefaultCollectionSemaphoresAsync(organizationUserIds);
var missingDefaultCollectionUserIds = organizationUserIds.Except(orgUserIdWithDefaultCollection);
var (collectionUsers, collections) = BuildDefaultCollectionForUsers(organizationId, organizationUserIds, defaultCollectionName);
using var scope = ServiceScopeFactory.CreateScope();
var dbContext = GetDatabaseContext(scope);
var (collectionUsers, collections) = BuildDefaultCollectionForUsers(organizationId, missingDefaultCollectionUserIds, defaultCollectionName);
if (!collectionUsers.Any() || !collections.Any())
{
return;
}
// CRITICAL: Insert semaphore entries BEFORE collections
// TODO: this will result in a creation date of the semaphore AFTER that of the collection, which is weird
// Database will throw on duplicate primary key (OrganizationUserId)
var now = DateTime.UtcNow;
var semaphores = collectionUsers.Select(c => new DefaultCollectionSemaphore
{
@@ -832,7 +824,7 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
await dbContext.SaveChangesAsync();
}
public async Task UpsertDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable<Guid> organizationUserIds, string defaultCollectionName)
public async Task CreateDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable<Guid> organizationUserIds, string defaultCollectionName)
{
// EF uses the same bulk copy approach as the main method
await CreateDefaultCollectionsAsync(organizationId, organizationUserIds, defaultCollectionName);

View File

@@ -40,7 +40,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceive()
.UpsertDefaultCollectionsBulkAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
.CreateDefaultCollectionsBulkAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
}
[Theory, BitAutoData]
@@ -66,7 +66,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceive()
.UpsertDefaultCollectionsBulkAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
.CreateDefaultCollectionsBulkAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
}
[Theory, BitAutoData]
@@ -92,7 +92,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceive()
.UpsertDefaultCollectionsBulkAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
.CreateDefaultCollectionsBulkAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
}
[Theory, BitAutoData]
@@ -118,7 +118,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await collectionRepository
.DidNotReceive()
.UpsertDefaultCollectionsBulkAsync(
.CreateDefaultCollectionsBulkAsync(
Arg.Any<Guid>(),
Arg.Any<IEnumerable<Guid>>(),
Arg.Any<string>());
@@ -198,6 +198,11 @@ public class OrganizationDataOwnershipPolicyValidatorTests
var policyRepository = ArrangePolicyRepository(orgPolicyDetailsList);
var collectionRepository = Substitute.For<ICollectionRepository>();
// Mock GetDefaultCollectionSemaphoresAsync to return empty set (no existing collections)
collectionRepository
.GetDefaultCollectionSemaphoresAsync(Arg.Any<IEnumerable<Guid>>())
.Returns(new HashSet<Guid>());
var sut = ArrangeSut(factory, policyRepository, collectionRepository);
var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
@@ -207,12 +212,101 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await collectionRepository
.Received(1)
.UpsertDefaultCollectionsBulkAsync(
.GetDefaultCollectionSemaphoresAsync(Arg.Is<IEnumerable<Guid>>(ids => ids.Count() == 3));
await collectionRepository
.Received(1)
.CreateDefaultCollectionsBulkAsync(
policyUpdate.OrganizationId,
Arg.Is<IEnumerable<Guid>>(ids => ids.Count() == 3),
_defaultUserCollectionName);
}
[Theory]
[BitMemberAutoData(nameof(ShouldUpsertDefaultCollectionsTestCases))]
public async Task ExecuteSideEffectsAsync_FiltersOutUsersWithExistingCollections(
Policy postUpdatedPolicy,
Policy? previousPolicyState,
[PolicyUpdate(PolicyType.OrganizationDataOwnership)] PolicyUpdate policyUpdate,
[OrganizationPolicyDetails(PolicyType.OrganizationDataOwnership)] IEnumerable<OrganizationPolicyDetails> orgPolicyDetails,
OrganizationDataOwnershipPolicyRequirementFactory factory)
{
// Arrange
var orgPolicyDetailsList = orgPolicyDetails.ToList();
foreach (var policyDetail in orgPolicyDetailsList)
{
policyDetail.OrganizationId = policyUpdate.OrganizationId;
}
var policyRepository = ArrangePolicyRepository(orgPolicyDetailsList);
var collectionRepository = Substitute.For<ICollectionRepository>();
// Mock GetDefaultCollectionSemaphoresAsync to return one existing user
var existingUserId = orgPolicyDetailsList[0].OrganizationUserId;
collectionRepository
.GetDefaultCollectionSemaphoresAsync(Arg.Any<IEnumerable<Guid>>())
.Returns([existingUserId]);
var sut = ArrangeSut(factory, policyRepository, collectionRepository);
var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
// Act
await sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
// Assert - Should filter out the existing user
await collectionRepository
.Received(1)
.GetDefaultCollectionSemaphoresAsync(Arg.Is<IEnumerable<Guid>>(ids => ids.Count() == 3));
await collectionRepository
.Received(1)
.CreateDefaultCollectionsBulkAsync(
policyUpdate.OrganizationId,
Arg.Is<IEnumerable<Guid>>(ids => ids.Count() == 2 && !ids.Contains(existingUserId)),
_defaultUserCollectionName);
}
[Theory]
[BitMemberAutoData(nameof(ShouldUpsertDefaultCollectionsTestCases))]
public async Task ExecuteSideEffectsAsync_DoesNotCallRepository_WhenAllUsersHaveExistingCollections(
Policy postUpdatedPolicy,
Policy? previousPolicyState,
[PolicyUpdate(PolicyType.OrganizationDataOwnership)] PolicyUpdate policyUpdate,
[OrganizationPolicyDetails(PolicyType.OrganizationDataOwnership)] IEnumerable<OrganizationPolicyDetails> orgPolicyDetails,
OrganizationDataOwnershipPolicyRequirementFactory factory)
{
// Arrange
var orgPolicyDetailsList = orgPolicyDetails.ToList();
foreach (var policyDetail in orgPolicyDetailsList)
{
policyDetail.OrganizationId = policyUpdate.OrganizationId;
}
var policyRepository = ArrangePolicyRepository(orgPolicyDetailsList);
var collectionRepository = Substitute.For<ICollectionRepository>();
// Mock GetDefaultCollectionSemaphoresAsync to return all users
var allUserIds = orgPolicyDetailsList.Select(p => p.OrganizationUserId).ToHashSet();
collectionRepository
.GetDefaultCollectionSemaphoresAsync(Arg.Any<IEnumerable<Guid>>())
.Returns(allUserIds);
var sut = ArrangeSut(factory, policyRepository, collectionRepository);
var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
// Act
await sut.ExecuteSideEffectsAsync(policyRequest, postUpdatedPolicy, previousPolicyState);
// Assert - Should not call CreateDefaultCollectionsBulkAsync when all users already have collections
await collectionRepository
.Received(1)
.GetDefaultCollectionSemaphoresAsync(Arg.Is<IEnumerable<Guid>>(ids => ids.Count() == 3));
await collectionRepository
.DidNotReceive()
.CreateDefaultCollectionsBulkAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
}
private static IEnumerable<object?[]> WhenDefaultCollectionsDoesNotExistTestCases()
{
yield return [new OrganizationModelOwnershipPolicyModel(null)];
@@ -246,7 +340,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceive()
.UpsertDefaultCollectionsBulkAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
.CreateDefaultCollectionsBulkAsync(Arg.Any<Guid>(), Arg.Any<IEnumerable<Guid>>(), Arg.Any<string>());
}
private static IPolicyRepository ArrangePolicyRepository(IEnumerable<OrganizationPolicyDetails> policyDetails)
@@ -294,7 +388,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceiveWithAnyArgs()
.UpsertDefaultCollectionsBulkAsync(default, default, default);
.CreateDefaultCollectionsBulkAsync(default, default, default);
}
[Theory, BitAutoData]
@@ -320,7 +414,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceiveWithAnyArgs()
.UpsertDefaultCollectionsBulkAsync(default, default, default);
.CreateDefaultCollectionsBulkAsync(default, default, default);
}
[Theory, BitAutoData]
@@ -346,7 +440,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceiveWithAnyArgs()
.UpsertDefaultCollectionsBulkAsync(default, default, default);
.CreateDefaultCollectionsBulkAsync(default, default, default);
}
[Theory, BitAutoData]
@@ -372,7 +466,7 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await collectionRepository
.DidNotReceiveWithAnyArgs()
.UpsertDefaultCollectionsBulkAsync(
.CreateDefaultCollectionsBulkAsync(
default,
default,
default);
@@ -403,6 +497,11 @@ public class OrganizationDataOwnershipPolicyValidatorTests
var policyRepository = ArrangePolicyRepository(orgPolicyDetailsList);
var collectionRepository = Substitute.For<ICollectionRepository>();
// Mock GetDefaultCollectionSemaphoresAsync to return empty set (no existing collections)
collectionRepository
.GetDefaultCollectionSemaphoresAsync(Arg.Any<IEnumerable<Guid>>())
.Returns(new HashSet<Guid>());
var sut = ArrangeSut(factory, policyRepository, collectionRepository);
var policyRequest = new SavePolicyModel(policyUpdate, new OrganizationModelOwnershipPolicyModel(_defaultUserCollectionName));
@@ -412,7 +511,11 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await collectionRepository
.Received(1)
.UpsertDefaultCollectionsBulkAsync(
.GetDefaultCollectionSemaphoresAsync(Arg.Is<IEnumerable<Guid>>(ids => ids.Count() == 3));
await collectionRepository
.Received(1)
.CreateDefaultCollectionsBulkAsync(
policyUpdate.OrganizationId,
Arg.Is<IEnumerable<Guid>>(ids => ids.Count() == 3),
_defaultUserCollectionName);
@@ -444,6 +547,6 @@ public class OrganizationDataOwnershipPolicyValidatorTests
// Assert
await sutProvider.GetDependency<ICollectionRepository>()
.DidNotReceiveWithAnyArgs()
.UpsertDefaultCollectionsBulkAsync(default, default, default);
.CreateDefaultCollectionsBulkAsync(default, default, default);
}
}

View File

@@ -6,10 +6,10 @@ using Xunit;
namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.CollectionRepository;
public class UpsertDefaultCollectionsBulkTests
public class CreateDefaultCollectionsBulkTests
{
[Theory, DatabaseData]
public async Task UpsertDefaultCollectionsBulkAsync_ShouldCreateDefaultCollection_WhenUsersDoNotHaveDefaultCollection(
public async Task CreateDefaultCollectionsBulkAsync_ShouldCreateDefaultCollection_WhenUsersDoNotHaveDefaultCollection(
IOrganizationRepository organizationRepository,
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
@@ -27,7 +27,7 @@ public class UpsertDefaultCollectionsBulkTests
var defaultCollectionName = $"default-name-{organization.Id}";
// Act
await collectionRepository.UpsertDefaultCollectionsBulkAsync(organization.Id, affectedOrgUserIds, defaultCollectionName);
await collectionRepository.CreateDefaultCollectionsBulkAsync(organization.Id, affectedOrgUserIds, defaultCollectionName);
// Assert
await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, resultOrganizationUsers, organization.Id);
@@ -37,7 +37,7 @@ public class UpsertDefaultCollectionsBulkTests
}
[Theory, DatabaseData]
public async Task UpsertDefaultCollectionsBulkAsync_ShouldUpsertCreateDefaultCollection_ForUsersWithAndWithoutDefaultCollectionsExist(
public async Task CreateDefaultCollectionsBulkAsync_CreatesForNewUsersOnly_WhenCallerFiltersExisting(
IOrganizationRepository organizationRepository,
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
@@ -64,18 +64,20 @@ public class UpsertDefaultCollectionsBulkTests
var affectedOrgUsers = newOrganizationUsers.Concat(arrangedOrganizationUsers);
var affectedOrgUserIds = affectedOrgUsers.Select(organizationUser => organizationUser.Id).ToList();
// Act
await collectionRepository.UpsertDefaultCollectionsBulkAsync(organization.Id, affectedOrgUserIds, defaultCollectionName);
// Act - Caller filters out existing users (new pattern)
var existingSemaphores = await collectionRepository.GetDefaultCollectionSemaphoresAsync(affectedOrgUserIds);
var usersNeedingCollections = affectedOrgUserIds.Except(existingSemaphores).ToList();
await collectionRepository.CreateDefaultCollectionsBulkAsync(organization.Id, usersNeedingCollections, defaultCollectionName);
// Assert
await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, arrangedOrganizationUsers, organization.Id);
// Assert - All users now have exactly one collection
await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, affectedOrgUsers, organization.Id);
await AssertSempahoresCreatedAsync(collectionRepository, affectedOrgUserIds);
await CleanupAsync(organizationRepository, userRepository, organization, affectedOrgUsers);
}
[Theory, DatabaseData]
public async Task UpsertDefaultCollectionsBulkAsync_ShouldNotCreateDefaultCollection_WhenUsersAlreadyHaveOne(
public async Task CreateDefaultCollectionsBulkAsync_ThrowsException_WhenUsersAlreadyHaveOne(
IOrganizationRepository organizationRepository,
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
@@ -94,21 +96,61 @@ public class UpsertDefaultCollectionsBulkTests
await CreateUsersWithExistingDefaultCollectionsAsync(collectionRepository, organization.Id, affectedOrgUserIds, defaultCollectionName, resultOrganizationUsers);
// Act
await collectionRepository.UpsertDefaultCollectionsBulkAsync(organization.Id, affectedOrgUserIds, defaultCollectionName);
// Act - Try to create again, should throw database constraint exception
await Assert.ThrowsAnyAsync<Exception>(() =>
collectionRepository.CreateDefaultCollectionsBulkAsync(organization.Id, affectedOrgUserIds, defaultCollectionName));
// Assert
// Assert - Original collections should remain unchanged
await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, resultOrganizationUsers, organization.Id);
await AssertSempahoresCreatedAsync(collectionRepository, affectedOrgUserIds);
await CleanupAsync(organizationRepository, userRepository, organization, resultOrganizationUsers);
}
[Theory, DatabaseData]
public async Task CreateDefaultCollectionsBulkAsync_ThrowsException_WhenDuplicatesNotFiltered(
IOrganizationRepository organizationRepository,
IUserRepository userRepository,
IOrganizationUserRepository organizationUserRepository,
ICollectionRepository collectionRepository)
{
// Arrange
var organization = await organizationRepository.CreateTestOrganizationAsync();
var existingUser = await CreateUserForOrgAsync(userRepository, organizationUserRepository, organization);
var newUser = await CreateUserForOrgAsync(userRepository, organizationUserRepository, organization);
var defaultCollectionName = $"default-name-{organization.Id}";
// Create collection for existing user
await collectionRepository.CreateDefaultCollectionsBulkAsync(organization.Id, [existingUser.Id], defaultCollectionName);
// Act - Try to create for both without filtering (incorrect usage)
await Assert.ThrowsAnyAsync<Exception>(() =>
collectionRepository.CreateDefaultCollectionsBulkAsync(
organization.Id,
[existingUser.Id, newUser.Id],
defaultCollectionName));
// Assert - Verify existing user still has collection
var existingUserCollections = await collectionRepository.GetManyByUserIdAsync(existingUser.UserId!.Value);
var existingUserDefaultCollection = existingUserCollections
.SingleOrDefault(c => c.OrganizationId == organization.Id && c.Type == CollectionType.DefaultUserCollection);
Assert.NotNull(existingUserDefaultCollection);
// Verify new user does NOT have collection (transaction rolled back)
var newUserCollections = await collectionRepository.GetManyByUserIdAsync(newUser.UserId!.Value);
var newUserDefaultCollection = newUserCollections
.FirstOrDefault(c => c.OrganizationId == organization.Id && c.Type == CollectionType.DefaultUserCollection);
Assert.Null(newUserDefaultCollection);
await CleanupAsync(organizationRepository, userRepository, organization, [existingUser, newUser]);
}
private static async Task CreateUsersWithExistingDefaultCollectionsAsync(ICollectionRepository collectionRepository,
Guid organizationId, IEnumerable<Guid> affectedOrgUserIds, string defaultCollectionName,
OrganizationUser[] resultOrganizationUsers)
{
await collectionRepository.UpsertDefaultCollectionsBulkAsync(organizationId, affectedOrgUserIds, defaultCollectionName);
await collectionRepository.CreateDefaultCollectionsBulkAsync(organizationId, affectedOrgUserIds, defaultCollectionName);
await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, resultOrganizationUsers, organizationId);
}