diff --git a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs index d8071c35bd..197caa5f46 100644 --- a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs @@ -3,6 +3,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Collections; using Bit.Core.Enums; using Bit.Core.Models.Data; using Bit.Core.Repositories; +using Bit.Core.Utilities; using Bit.Infrastructure.EntityFramework.Models; using Bit.Infrastructure.EntityFramework.Repositories.Queries; using LinqToDB.EntityFrameworkCore; @@ -796,8 +797,8 @@ public class CollectionRepository : Repository organizationUserIds, string defaultCollectionName) { - var organizationUserIdsHashSet = organizationUserIds.ToHashSet(); - if (organizationUserIdsHashSet.Count == 0) + organizationUserIds = organizationUserIds.ToList(); + if (!organizationUserIds.Any()) { return; } @@ -805,43 +806,46 @@ public class CollectionRepository : Repository organizationUserIdsHashSet.Contains(cu.OrganizationUserId)) - .Where(cu => cu.Collection.Type == CollectionType.DefaultUserCollection) - .Where(cu => cu.Collection.OrganizationId == organizationId) - .Select(cu => cu.OrganizationUserId); + var orgUserIdWithDefaultCollection = await GetOrgUserIdsWithDefaultCollectionAsync(dbContext, organizationId); + var missingDefaultCollectionUserIds = organizationUserIds.Except(orgUserIdWithDefaultCollection); - // Filter to only users who need collections - var filteredOrgUserIds = organizationUserIdsHashSet.Except(existingOrgUserIds).ToList(); - if (filteredOrgUserIds.Count == 0) + var (collectionUsers, collections) = CollectionUtils.BuildDefaultUserCollections(organizationId, missingDefaultCollectionUserIds, defaultCollectionName); + + if (!collectionUsers.Any() || !collections.Any()) { return; } - var (collections, collectionUsers) = - CollectionUtils.BuildDefaultUserCollections(organizationId, filteredOrgUserIds, defaultCollectionName); + await dbContext.BulkCopyAsync(collections); + await dbContext.BulkCopyAsync(collectionUsers); - await using var transaction = await dbContext.Database.BeginTransactionAsync(); - - try - { - await dbContext.BulkCopyAsync(Mapper.Map>(collections)); - await dbContext.BulkCopyAsync(Mapper.Map>(collectionUsers)); - - await transaction.CommitAsync(); - } - catch - { - await transaction.RollbackAsync(); - throw; - } + await dbContext.SaveChangesAsync(); } - public async Task CreateDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable organizationUserIds, string defaultCollectionName) + private async Task> GetOrgUserIdsWithDefaultCollectionAsync(DatabaseContext dbContext, Guid organizationId) { - // EF uses the same bulk copy approach as the main method - await CreateDefaultCollectionsAsync(organizationId, organizationUserIds, defaultCollectionName); + var results = await dbContext.OrganizationUsers + .Where(ou => ou.OrganizationId == organizationId) + .Join( + dbContext.CollectionUsers, + ou => ou.Id, + cu => cu.OrganizationUserId, + (ou, cu) => new { ou, cu } + ) + .Join( + dbContext.Collections, + temp => temp.cu.CollectionId, + c => c.Id, + (temp, c) => new { temp.ou, Collection = c } + ) + .Where(x => x.Collection.Type == CollectionType.DefaultUserCollection) + .Select(x => x.ou.Id) + .ToListAsync(); + + return results.ToHashSet(); } + public Task CreateDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable organizationUserIds, + string defaultCollectionName) => + CreateDefaultCollectionsAsync(organizationId, organizationUserIds, defaultCollectionName); }