1
0
mirror of https://github.com/bitwarden/server synced 2026-01-26 14:23:21 +00:00

Revert Bulk implementation

This commit is contained in:
Thomas Rittson
2026-01-06 14:16:47 +10:00
parent 91c51fd47c
commit dc5c06ff83
4 changed files with 66 additions and 20 deletions

View File

@@ -13,8 +13,7 @@ public static class CollectionUtils
/// <param name="organizationUserIds">The IDs for organization users who need default collections.</param>
/// <param name="defaultCollectionName">The encrypted string to use as the default collection name.</param>
/// <returns>A tuple containing the collections and collection users.</returns>
public static (IEnumerable<Collection> collections,
IEnumerable<CollectionUser> collectionUsers)
public static (ICollection<Collection> collections, ICollection<CollectionUser> collectionUsers)
BuildDefaultUserCollections(Guid organizationId, IEnumerable<Guid> organizationUserIds,
string defaultCollectionName)
{

View File

@@ -3,6 +3,7 @@ using System.Diagnostics.CodeAnalysis;
using System.Text.Json;
using Bit.Core.AdminConsole.OrganizationFeatures.Collections;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Models.Data;
using Bit.Core.Repositories;
using Bit.Core.Settings;
@@ -400,8 +401,64 @@ public class CollectionRepository : Repository<Collection, Guid>, ICollectionRep
public async Task CreateDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable<Guid> organizationUserIds, string defaultCollectionName)
{
// Use the stored procedure approach which handles filtering internally
await CreateDefaultCollectionsAsync(organizationId, organizationUserIds, defaultCollectionName);
organizationUserIds = organizationUserIds.ToList();
if (!organizationUserIds.Any())
{
return;
}
await using var connection = new SqlConnection(ConnectionString);
connection.Open();
await using var transaction = connection.BeginTransaction();
try
{
var orgUserIdWithDefaultCollection = await GetOrgUserIdsWithDefaultCollectionAsync(connection, transaction, organizationId);
var missingDefaultCollectionUserIds = organizationUserIds.Except(orgUserIdWithDefaultCollection);
var (collections, collectionUsers) =
CollectionUtils.BuildDefaultUserCollections(organizationId, missingDefaultCollectionUserIds, defaultCollectionName);
if (!collectionUsers.Any() || !collections.Any())
{
return;
}
await BulkResourceCreationService.CreateCollectionsAsync(connection, transaction, collections);
await BulkResourceCreationService.CreateCollectionsUsersAsync(connection, transaction, collectionUsers);
transaction.Commit();
}
catch
{
transaction.Rollback();
throw;
}
}
private async Task<HashSet<Guid>> GetOrgUserIdsWithDefaultCollectionAsync(SqlConnection connection, SqlTransaction transaction, Guid organizationId)
{
const string sql = @"
SELECT
ou.Id AS OrganizationUserId
FROM
OrganizationUser ou
INNER JOIN
CollectionUser cu ON cu.OrganizationUserId = ou.Id
INNER JOIN
Collection c ON c.Id = cu.CollectionId
WHERE
ou.OrganizationId = @OrganizationId
AND c.Type = @CollectionType;
";
var organizationUserIds = await connection.QueryAsync<Guid>(
sql,
new { OrganizationId = organizationId, CollectionType = CollectionType.DefaultUserCollection },
transaction: transaction
);
return organizationUserIds.ToHashSet();
}
public class CollectionWithGroupsAndUsers : Collection

View File

@@ -797,8 +797,8 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
public async Task CreateDefaultCollectionsAsync(Guid organizationId, IEnumerable<Guid> organizationUserIds, string defaultCollectionName)
{
organizationUserIds = organizationUserIds.ToList();
if (!organizationUserIds.Any())
var organizationUserIdsHashSet = organizationUserIds.ToHashSet();
if (organizationUserIdsHashSet.Count == 0)
{
return;
}
@@ -807,18 +807,15 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
var dbContext = GetDatabaseContext(scope);
// Query for users who already have default collections
var organizationUserIdsHashSet = organizationUserIds.ToHashSet();
var existingOrgUserIds = await dbContext.CollectionUsers
var existingOrgUserIds = dbContext.CollectionUsers
.Where(cu => organizationUserIdsHashSet.Contains(cu.OrganizationUserId))
.Where(cu => cu.Collection.Type == CollectionType.DefaultUserCollection)
.Where(cu => cu.Collection.OrganizationId == organizationId)
.Select(cu => cu.OrganizationUserId)
.ToListAsync();
.Select(cu => cu.OrganizationUserId);
// Filter to only users who need collections
var filteredOrgUserIds = organizationUserIds.Except(existingOrgUserIds).ToList();
if (!filteredOrgUserIds.Any())
var filteredOrgUserIds = organizationUserIdsHashSet.Except(existingOrgUserIds).ToList();
if (filteredOrgUserIds.Count == 0)
{
return;
}

View File

@@ -6,10 +6,6 @@ namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.Collectio
public class CreateDefaultCollectionsTests
{
/// <summary>
/// Test that CreateDefaultCollectionsAsync successfully creates default collections for new users
/// with correct permissions
/// </summary>
[Theory, DatabaseData]
public async Task CreateDefaultCollectionsAsync_CreatesDefaultCollections_Success(
IUserRepository userRepository,
@@ -63,9 +59,6 @@ public class CreateDefaultCollectionsTests
Assert.True(orgUser2CollectionUser.Manage);
}
/// <summary>
/// Test that calling CreateDefaultCollectionsAsync multiple times does NOT create duplicates
/// </summary>
[Theory, DatabaseData]
public async Task CreateDefaultCollectionsAsync_CalledMultipleTimesForSameOrganizationUser_DoesNotCreateDuplicates(
IUserRepository userRepository,