diff --git a/src/Core/AdminConsole/OrganizationFeatures/Collections/CollectionUtils.cs b/src/Core/AdminConsole/OrganizationFeatures/Collections/CollectionUtils.cs
index 6b2da70d3e..116992146f 100644
--- a/src/Core/AdminConsole/OrganizationFeatures/Collections/CollectionUtils.cs
+++ b/src/Core/AdminConsole/OrganizationFeatures/Collections/CollectionUtils.cs
@@ -13,8 +13,7 @@ public static class CollectionUtils
/// The IDs for organization users who need default collections.
/// The encrypted string to use as the default collection name.
/// A tuple containing the collections and collection users.
- public static (IEnumerable collections,
- IEnumerable collectionUsers)
+ public static (ICollection collections, ICollection collectionUsers)
BuildDefaultUserCollections(Guid organizationId, IEnumerable organizationUserIds,
string defaultCollectionName)
{
diff --git a/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs b/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs
index a78a699b10..4b8cd3d371 100644
--- a/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs
+++ b/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs
@@ -3,6 +3,7 @@ using System.Diagnostics.CodeAnalysis;
using System.Text.Json;
using Bit.Core.AdminConsole.OrganizationFeatures.Collections;
using Bit.Core.Entities;
+using Bit.Core.Enums;
using Bit.Core.Models.Data;
using Bit.Core.Repositories;
using Bit.Core.Settings;
@@ -400,8 +401,64 @@ public class CollectionRepository : Repository, ICollectionRep
public async Task CreateDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName)
{
- // Use the stored procedure approach which handles filtering internally
- await CreateDefaultCollectionsAsync(organizationId, organizationUserIds, defaultCollectionName);
+ organizationUserIds = organizationUserIds.ToList();
+ if (!organizationUserIds.Any())
+ {
+ return;
+ }
+
+ await using var connection = new SqlConnection(ConnectionString);
+ connection.Open();
+ await using var transaction = connection.BeginTransaction();
+ try
+ {
+ var orgUserIdWithDefaultCollection = await GetOrgUserIdsWithDefaultCollectionAsync(connection, transaction, organizationId);
+
+ var missingDefaultCollectionUserIds = organizationUserIds.Except(orgUserIdWithDefaultCollection);
+
+ var (collections, collectionUsers) =
+ CollectionUtils.BuildDefaultUserCollections(organizationId, missingDefaultCollectionUserIds, defaultCollectionName);
+
+ if (!collectionUsers.Any() || !collections.Any())
+ {
+ return;
+ }
+
+ await BulkResourceCreationService.CreateCollectionsAsync(connection, transaction, collections);
+ await BulkResourceCreationService.CreateCollectionsUsersAsync(connection, transaction, collectionUsers);
+
+ transaction.Commit();
+ }
+ catch
+ {
+ transaction.Rollback();
+ throw;
+ }
+ }
+
+ private async Task> GetOrgUserIdsWithDefaultCollectionAsync(SqlConnection connection, SqlTransaction transaction, Guid organizationId)
+ {
+ const string sql = @"
+ SELECT
+ ou.Id AS OrganizationUserId
+ FROM
+ OrganizationUser ou
+ INNER JOIN
+ CollectionUser cu ON cu.OrganizationUserId = ou.Id
+ INNER JOIN
+ Collection c ON c.Id = cu.CollectionId
+ WHERE
+ ou.OrganizationId = @OrganizationId
+ AND c.Type = @CollectionType;
+ ";
+
+ var organizationUserIds = await connection.QueryAsync(
+ sql,
+ new { OrganizationId = organizationId, CollectionType = CollectionType.DefaultUserCollection },
+ transaction: transaction
+ );
+
+ return organizationUserIds.ToHashSet();
}
public class CollectionWithGroupsAndUsers : Collection
diff --git a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs
index 6211f71873..8146e7c53d 100644
--- a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs
+++ b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs
@@ -797,8 +797,8 @@ public class CollectionRepository : Repository organizationUserIds, string defaultCollectionName)
{
- organizationUserIds = organizationUserIds.ToList();
- if (!organizationUserIds.Any())
+ var organizationUserIdsHashSet = organizationUserIds.ToHashSet();
+ if (organizationUserIdsHashSet.Count == 0)
{
return;
}
@@ -807,18 +807,15 @@ public class CollectionRepository : Repository organizationUserIdsHashSet.Contains(cu.OrganizationUserId))
.Where(cu => cu.Collection.Type == CollectionType.DefaultUserCollection)
.Where(cu => cu.Collection.OrganizationId == organizationId)
- .Select(cu => cu.OrganizationUserId)
- .ToListAsync();
+ .Select(cu => cu.OrganizationUserId);
// Filter to only users who need collections
- var filteredOrgUserIds = organizationUserIds.Except(existingOrgUserIds).ToList();
-
- if (!filteredOrgUserIds.Any())
+ var filteredOrgUserIds = organizationUserIdsHashSet.Except(existingOrgUserIds).ToList();
+ if (filteredOrgUserIds.Count == 0)
{
return;
}
diff --git a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsTests.cs b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsTests.cs
index c778b24fd4..283760dbec 100644
--- a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsTests.cs
+++ b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsTests.cs
@@ -6,10 +6,6 @@ namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.Collectio
public class CreateDefaultCollectionsTests
{
- ///
- /// Test that CreateDefaultCollectionsAsync successfully creates default collections for new users
- /// with correct permissions
- ///
[Theory, DatabaseData]
public async Task CreateDefaultCollectionsAsync_CreatesDefaultCollections_Success(
IUserRepository userRepository,
@@ -63,9 +59,6 @@ public class CreateDefaultCollectionsTests
Assert.True(orgUser2CollectionUser.Manage);
}
- ///
- /// Test that calling CreateDefaultCollectionsAsync multiple times does NOT create duplicates
- ///
[Theory, DatabaseData]
public async Task CreateDefaultCollectionsAsync_CalledMultipleTimesForSameOrganizationUser_DoesNotCreateDuplicates(
IUserRepository userRepository,