1
0
mirror of https://github.com/bitwarden/server synced 2026-01-27 14:53:21 +00:00

Revert unnecessary ef changes

This commit is contained in:
Thomas Rittson
2026-01-06 14:25:24 +10:00
parent 1f4fc9b017
commit 4c32e20fd7

View File

@@ -3,6 +3,7 @@ using Bit.Core.AdminConsole.OrganizationFeatures.Collections;
using Bit.Core.Enums;
using Bit.Core.Models.Data;
using Bit.Core.Repositories;
using Bit.Core.Utilities;
using Bit.Infrastructure.EntityFramework.Models;
using Bit.Infrastructure.EntityFramework.Repositories.Queries;
using LinqToDB.EntityFrameworkCore;
@@ -796,8 +797,8 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
public async Task CreateDefaultCollectionsAsync(Guid organizationId, IEnumerable<Guid> organizationUserIds, string defaultCollectionName)
{
var organizationUserIdsHashSet = organizationUserIds.ToHashSet();
if (organizationUserIdsHashSet.Count == 0)
organizationUserIds = organizationUserIds.ToList();
if (!organizationUserIds.Any())
{
return;
}
@@ -805,43 +806,46 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
using var scope = ServiceScopeFactory.CreateScope();
var dbContext = GetDatabaseContext(scope);
// Query for users who already have default collections
var existingOrgUserIds = dbContext.CollectionUsers
.Where(cu => organizationUserIdsHashSet.Contains(cu.OrganizationUserId))
.Where(cu => cu.Collection.Type == CollectionType.DefaultUserCollection)
.Where(cu => cu.Collection.OrganizationId == organizationId)
.Select(cu => cu.OrganizationUserId);
var orgUserIdWithDefaultCollection = await GetOrgUserIdsWithDefaultCollectionAsync(dbContext, organizationId);
var missingDefaultCollectionUserIds = organizationUserIds.Except(orgUserIdWithDefaultCollection);
// Filter to only users who need collections
var filteredOrgUserIds = organizationUserIdsHashSet.Except(existingOrgUserIds).ToList();
if (filteredOrgUserIds.Count == 0)
var (collectionUsers, collections) = CollectionUtils.BuildDefaultUserCollections(organizationId, missingDefaultCollectionUserIds, defaultCollectionName);
if (!collectionUsers.Any() || !collections.Any())
{
return;
}
var (collections, collectionUsers) =
CollectionUtils.BuildDefaultUserCollections(organizationId, filteredOrgUserIds, defaultCollectionName);
await dbContext.BulkCopyAsync(collections);
await dbContext.BulkCopyAsync(collectionUsers);
await using var transaction = await dbContext.Database.BeginTransactionAsync();
try
{
await dbContext.BulkCopyAsync(Mapper.Map<IEnumerable<Collection>>(collections));
await dbContext.BulkCopyAsync(Mapper.Map<IEnumerable<CollectionUser>>(collectionUsers));
await transaction.CommitAsync();
}
catch
{
await transaction.RollbackAsync();
throw;
}
await dbContext.SaveChangesAsync();
}
public async Task CreateDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable<Guid> organizationUserIds, string defaultCollectionName)
private async Task<HashSet<Guid>> GetOrgUserIdsWithDefaultCollectionAsync(DatabaseContext dbContext, Guid organizationId)
{
// EF uses the same bulk copy approach as the main method
await CreateDefaultCollectionsAsync(organizationId, organizationUserIds, defaultCollectionName);
var results = await dbContext.OrganizationUsers
.Where(ou => ou.OrganizationId == organizationId)
.Join(
dbContext.CollectionUsers,
ou => ou.Id,
cu => cu.OrganizationUserId,
(ou, cu) => new { ou, cu }
)
.Join(
dbContext.Collections,
temp => temp.cu.CollectionId,
c => c.Id,
(temp, c) => new { temp.ou, Collection = c }
)
.Where(x => x.Collection.Type == CollectionType.DefaultUserCollection)
.Select(x => x.ou.Id)
.ToListAsync();
return results.ToHashSet();
}
public Task CreateDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable<Guid> organizationUserIds,
string defaultCollectionName) =>
CreateDefaultCollectionsAsync(organizationId, organizationUserIds, defaultCollectionName);
}