diff --git a/src/Infrastructure.EntityFramework/Repositories/OrganizationUserRepository.cs b/src/Infrastructure.EntityFramework/Repositories/OrganizationUserRepository.cs index 70f4401eed..3f6565d27b 100644 --- a/src/Infrastructure.EntityFramework/Repositories/OrganizationUserRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/OrganizationUserRepository.cs @@ -309,22 +309,42 @@ public class OrganizationUserRepository : Repository collections) + public async Task ReplaceAsync(Core.Entities.OrganizationUser obj, IEnumerable requestedCollections) { await base.ReplaceAsync(obj); using (var scope = ServiceScopeFactory.CreateScope()) { var dbContext = GetDatabaseContext(scope); - var procedure = new OrganizationUserUpdateWithCollectionsQuery(obj, collections); + var existingCollectionUsers = await dbContext.CollectionUsers + .Where(cu => cu.OrganizationUserId == obj.Id) + .ToListAsync(); - var update = procedure.Update.Run(dbContext); - dbContext.UpdateRange(await update.ToListAsync()); + foreach (var requestedCollection in requestedCollections) + { + var existingCollectionUser = existingCollectionUsers.FirstOrDefault(cu => cu.CollectionId == requestedCollection.Id); + if (existingCollectionUser == null) + { + // This is a brand new entry + dbContext.CollectionUsers.Add(new CollectionUser + { + CollectionId = requestedCollection.Id, + OrganizationUserId = obj.Id, + HidePasswords = requestedCollection.HidePasswords, + ReadOnly = requestedCollection.ReadOnly, + }); + break; + } - var insert = procedure.Insert.Run(dbContext); - await dbContext.AddRangeAsync(await insert.ToListAsync()); + // It already exists, update it + existingCollectionUser.HidePasswords = requestedCollection.HidePasswords; + existingCollectionUser.ReadOnly = requestedCollection.ReadOnly; + dbContext.CollectionUsers.Update(existingCollectionUser); + } - dbContext.RemoveRange(await procedure.Delete.Run(dbContext).ToListAsync()); + // Remove all existing ones that are no longer requested + var requestedCollectionIds = requestedCollections.Select(c => c.Id).ToList(); + dbContext.CollectionUsers.RemoveRange(existingCollectionUsers.Where(cu => !requestedCollectionIds.Contains(cu.CollectionId))); await dbContext.SaveChangesAsync(); } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserUpdateWithCollectionsQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserUpdateWithCollectionsQuery.cs deleted file mode 100644 index 0a21514d62..0000000000 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/OrganizationUserUpdateWithCollectionsQuery.cs +++ /dev/null @@ -1,105 +0,0 @@ -using Bit.Core.Entities; -using Bit.Core.Models.Data; -using CollectionUser = Bit.Infrastructure.EntityFramework.Models.CollectionUser; - -namespace Bit.Infrastructure.EntityFramework.Repositories.Queries; - -public class OrganizationUserUpdateWithCollectionsQuery -{ - public OrganizationUserUpdateWithCollectionsInsertQuery Insert { get; set; } - public OrganizationUserUpdateWithCollectionsUpdateQuery Update { get; set; } - public OrganizationUserUpdateWithCollectionsDeleteQuery Delete { get; set; } - - public OrganizationUserUpdateWithCollectionsQuery(OrganizationUser organizationUser, - IEnumerable collections) - { - Insert = new OrganizationUserUpdateWithCollectionsInsertQuery(organizationUser, collections); - Update = new OrganizationUserUpdateWithCollectionsUpdateQuery(organizationUser, collections); - Delete = new OrganizationUserUpdateWithCollectionsDeleteQuery(organizationUser, collections); - } -} - -public class OrganizationUserUpdateWithCollectionsInsertQuery : IQuery -{ - private readonly OrganizationUser _organizationUser; - private readonly IEnumerable _collections; - - public OrganizationUserUpdateWithCollectionsInsertQuery(OrganizationUser organizationUser, IEnumerable collections) - { - _organizationUser = organizationUser; - _collections = collections; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var collectionIds = _collections.Select(c => c.Id).ToArray(); - var t = (from cu in dbContext.CollectionUsers - where collectionIds.Contains(cu.CollectionId) && - cu.OrganizationUserId == _organizationUser.Id - select cu).AsEnumerable(); - var insertQuery = (from c in dbContext.Collections - where collectionIds.Contains(c.Id) && - c.OrganizationId == _organizationUser.OrganizationId && - !t.Any() - select c).AsEnumerable(); - return insertQuery.Select(x => new CollectionUser - { - CollectionId = x.Id, - OrganizationUserId = _organizationUser.Id, - ReadOnly = _collections.FirstOrDefault(c => c.Id == x.Id).ReadOnly, - HidePasswords = _collections.FirstOrDefault(c => c.Id == x.Id).HidePasswords, - }).AsQueryable(); - } -} - -public class OrganizationUserUpdateWithCollectionsUpdateQuery : IQuery -{ - private readonly OrganizationUser _organizationUser; - private readonly IEnumerable _collections; - - public OrganizationUserUpdateWithCollectionsUpdateQuery(OrganizationUser organizationUser, IEnumerable collections) - { - _organizationUser = organizationUser; - _collections = collections; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var collectionIds = _collections.Select(c => c.Id).ToArray(); - var updateQuery = (from target in dbContext.CollectionUsers - where collectionIds.Contains(target.CollectionId) && - target.OrganizationUserId == _organizationUser.Id - select new { target }).AsEnumerable(); - updateQuery = updateQuery.Where(cu => - cu.target.ReadOnly == _collections.FirstOrDefault(u => u.Id == cu.target.CollectionId).ReadOnly && - cu.target.HidePasswords == _collections.FirstOrDefault(u => u.Id == cu.target.CollectionId).HidePasswords); - return updateQuery.Select(x => new CollectionUser - { - CollectionId = x.target.CollectionId, - OrganizationUserId = _organizationUser.Id, - ReadOnly = x.target.ReadOnly, - HidePasswords = x.target.HidePasswords, - }).AsQueryable(); - } -} - -public class OrganizationUserUpdateWithCollectionsDeleteQuery : IQuery -{ - private readonly OrganizationUser _organizationUser; - private readonly IEnumerable _collections; - - public OrganizationUserUpdateWithCollectionsDeleteQuery(OrganizationUser organizationUser, IEnumerable collections) - { - _organizationUser = organizationUser; - _collections = collections; - } - - public IQueryable Run(DatabaseContext dbContext) - { - var deleteQuery = from cu in dbContext.CollectionUsers - where !_collections.Any( - c => c.Id == cu.CollectionId) - select cu; - return deleteQuery; - } -}