diff --git a/src/Core/Repositories/ICollectionRepository.cs b/src/Core/Repositories/ICollectionRepository.cs index a0bf2fc98d..fc3a6715ac 100644 --- a/src/Core/Repositories/ICollectionRepository.cs +++ b/src/Core/Repositories/ICollectionRepository.cs @@ -47,7 +47,7 @@ public interface ICollectionRepository : IRepository /// Task GetByIdWithPermissionsAsync(Guid collectionId, Guid? userId, bool includeAccessRelationships); - Task CreateAsync(Collection obj, IEnumerable groups, IEnumerable users); + Task CreateAsync(Collection obj, IEnumerable? groups, IEnumerable? users); Task ReplaceAsync(Collection obj, IEnumerable groups, IEnumerable users); Task DeleteUserAsync(Guid collectionId, Guid organizationUserId); Task UpdateUsersAsync(Guid id, IEnumerable users); diff --git a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs index cdc5caf366..0bb6fb9925 100644 --- a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs @@ -50,7 +50,7 @@ public class CollectionRepository : Repository groups, IEnumerable users) + public async Task CreateAsync(Core.Entities.Collection obj, IEnumerable? groups, IEnumerable? users) { await CreateAsync(obj); using (var scope = ServiceScopeFactory.CreateScope()) @@ -523,6 +523,7 @@ public class CollectionRepository : Repository groups) + private static async Task ReplaceCollectionGroupsAsync(DatabaseContext dbContext, Core.Entities.Collection collection, IEnumerable groups) { - var groupsInOrg = dbContext.Groups.Where(g => g.OrganizationId == collection.OrganizationId); - var modifiedGroupEntities = dbContext.Groups.Where(x => groups.Select(x => x.Id).Contains(x.Id)); - var target = (from cg in dbContext.CollectionGroups - join g in modifiedGroupEntities - on cg.CollectionId equals collection.Id into s_g - from g in s_g.DefaultIfEmpty() - where g == null || cg.GroupId == g.Id - select new { cg, g }).AsNoTracking(); - var source = (from g in modifiedGroupEntities - from cg in dbContext.CollectionGroups - .Where(cg => cg.CollectionId == collection.Id && cg.GroupId == g.Id).DefaultIfEmpty() - select new { cg, g }).AsNoTracking(); - var union = await target - .Union(source) - .Where(x => - x.cg == null || - ((x.g == null || x.g.Id == x.cg.GroupId) && - (x.cg.CollectionId == collection.Id))) - .AsNoTracking() - .ToListAsync(); - var insert = union.Where(x => x.cg == null && groupsInOrg.Any(c => x.g.Id == c.Id)) - .Select(x => new CollectionGroup - { - CollectionId = collection.Id, - GroupId = x.g.Id, - ReadOnly = groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly, - HidePasswords = groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords, - Manage = groups.FirstOrDefault(g => g.Id == x.g.Id).Manage - }).ToList(); - var update = union - .Where( - x => x.g != null && - x.cg != null && - (x.cg.ReadOnly != groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly || - x.cg.HidePasswords != groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords || - x.cg.Manage != groups.FirstOrDefault(g => g.Id == x.g.Id).Manage) - ) - .Select(x => new CollectionGroup - { - CollectionId = collection.Id, - GroupId = x.g.Id, - ReadOnly = groups.FirstOrDefault(g => g.Id == x.g.Id).ReadOnly, - HidePasswords = groups.FirstOrDefault(g => g.Id == x.g.Id).HidePasswords, - Manage = groups.FirstOrDefault(g => g.Id == x.g.Id).Manage, - }); - var delete = union - .Where( - x => x.g == null && - x.cg.CollectionId == collection.Id - ) - .Select(x => new CollectionGroup - { - CollectionId = collection.Id, - GroupId = x.cg.GroupId, - }) - .ToList(); + var existingCollectionGroups = await dbContext.CollectionGroups + .Where(cg => cg.CollectionId == collection.Id) + .ToDictionaryAsync(cg => cg.GroupId); - await dbContext.AddRangeAsync(insert); - dbContext.UpdateRange(update); - dbContext.RemoveRange(delete); - await dbContext.SaveChangesAsync(); + foreach (var group in groups) + { + if (existingCollectionGroups.TryGetValue(group.Id, out var existingCollectionGroup)) + { + // It already exists, update it + existingCollectionGroup.HidePasswords = group.HidePasswords; + existingCollectionGroup.ReadOnly = group.ReadOnly; + existingCollectionGroup.Manage = group.Manage; + dbContext.CollectionGroups.Update(existingCollectionGroup); + } + else + { + // This is a brand new entry, add it + dbContext.CollectionGroups.Add(new CollectionGroup + { + GroupId = group.Id, + CollectionId = collection.Id, + HidePasswords = group.HidePasswords, + ReadOnly = group.ReadOnly, + Manage = group.Manage, + }); + } + } + + var requestedGroupIds = groups.Select(g => g.Id).ToArray(); + var toDelete = existingCollectionGroups.Values.Where(cg => !requestedGroupIds.Contains(cg.GroupId)); + dbContext.CollectionGroups.RemoveRange(toDelete); + // SaveChangesAsync is expected to be called outside this method } - private async Task ReplaceCollectionUsersAsync(DatabaseContext dbContext, Core.Entities.Collection collection, IEnumerable users) + private static async Task ReplaceCollectionUsersAsync(DatabaseContext dbContext, Core.Entities.Collection collection, IEnumerable users) { - var usersInOrg = dbContext.OrganizationUsers.Where(u => u.OrganizationId == collection.OrganizationId); - var modifiedUserEntities = dbContext.OrganizationUsers.Where(x => users.Select(x => x.Id).Contains(x.Id)); - var target = (from cu in dbContext.CollectionUsers - join u in modifiedUserEntities - on cu.CollectionId equals collection.Id into s_g - from u in s_g.DefaultIfEmpty() - where u == null || cu.OrganizationUserId == u.Id - select new { cu, u }).AsNoTracking(); - var source = (from u in modifiedUserEntities - from cu in dbContext.CollectionUsers - .Where(cu => cu.CollectionId == collection.Id && cu.OrganizationUserId == u.Id).DefaultIfEmpty() - select new { cu, u }).AsNoTracking(); - var union = await target - .Union(source) - .Where(x => - x.cu == null || - ((x.u == null || x.u.Id == x.cu.OrganizationUserId) && - (x.cu.CollectionId == collection.Id))) - .AsNoTracking() - .ToListAsync(); - var insert = union.Where(x => x.u == null && usersInOrg.Any(c => x.u.Id == c.Id)) - .Select(x => new CollectionUser - { - CollectionId = collection.Id, - OrganizationUserId = x.u.Id, - ReadOnly = users.FirstOrDefault(u => u.Id == x.u.Id).ReadOnly, - HidePasswords = users.FirstOrDefault(u => u.Id == x.u.Id).HidePasswords, - Manage = users.FirstOrDefault(u => u.Id == x.u.Id).Manage, - }).ToList(); - var update = union - .Where( - x => x.u != null && - x.cu != null && - (x.cu.ReadOnly != users.FirstOrDefault(u => u.Id == x.u.Id).ReadOnly || - x.cu.HidePasswords != users.FirstOrDefault(u => u.Id == x.u.Id).HidePasswords || - x.cu.Manage != users.FirstOrDefault(u => u.Id == x.u.Id).Manage) - ) - .Select(x => new CollectionUser - { - CollectionId = collection.Id, - OrganizationUserId = x.u.Id, - ReadOnly = users.FirstOrDefault(u => u.Id == x.u.Id).ReadOnly, - HidePasswords = users.FirstOrDefault(u => u.Id == x.u.Id).HidePasswords, - Manage = users.FirstOrDefault(u => u.Id == x.u.Id).Manage, - }); - var delete = union - .Where( - x => x.u == null && - x.cu.CollectionId == collection.Id - ) - .Select(x => new CollectionUser - { - CollectionId = collection.Id, - OrganizationUserId = x.cu.OrganizationUserId, - }) - .ToList(); + var existingCollectionUsers = await dbContext.CollectionUsers + .Where(cu => cu.CollectionId == collection.Id) + .ToDictionaryAsync(cu => cu.OrganizationUserId); - await dbContext.AddRangeAsync(insert); - dbContext.UpdateRange(update); - dbContext.RemoveRange(delete); - await dbContext.SaveChangesAsync(); + foreach (var user in users) + { + if (existingCollectionUsers.TryGetValue(user.Id, out var existingCollectionUser)) + { + // This is an existing entry, update it. + existingCollectionUser.HidePasswords = user.HidePasswords; + existingCollectionUser.ReadOnly = user.ReadOnly; + existingCollectionUser.Manage = user.Manage; + dbContext.CollectionUsers.Update(existingCollectionUser); + } + else + { + // This is a brand new entry, add it + dbContext.CollectionUsers.Add(new CollectionUser + { + OrganizationUserId = user.Id, + CollectionId = collection.Id, + HidePasswords = user.HidePasswords, + ReadOnly = user.ReadOnly, + Manage = user.Manage, + }); + } + } + + var requestedUserIds = users.Select(u => u.Id).ToArray(); + var toDelete = existingCollectionUsers.Values.Where(cu => !requestedUserIds.Contains(cu.OrganizationUserId)); + dbContext.CollectionUsers.RemoveRange(toDelete); + // SaveChangesAsync is expected to be called outside this method } } diff --git a/test/Infrastructure.IntegrationTest/Infrastructure.IntegrationTest.csproj b/test/Infrastructure.IntegrationTest/Infrastructure.IntegrationTest.csproj index b8feda50af..6aafe44ca8 100644 --- a/test/Infrastructure.IntegrationTest/Infrastructure.IntegrationTest.csproj +++ b/test/Infrastructure.IntegrationTest/Infrastructure.IntegrationTest.csproj @@ -13,8 +13,8 @@ - - + + runtime; build; native; contentfiles; analyzers; buildtransitive all diff --git a/test/Infrastructure.IntegrationTest/Vault/Repositories/CollectionRepositoryTests.cs b/test/Infrastructure.IntegrationTest/Vault/Repositories/CollectionRepositoryTests.cs index e984c8326f..fa7197ff61 100644 --- a/test/Infrastructure.IntegrationTest/Vault/Repositories/CollectionRepositoryTests.cs +++ b/test/Infrastructure.IntegrationTest/Vault/Repositories/CollectionRepositoryTests.cs @@ -463,4 +463,141 @@ public class CollectionRepositoryTests Assert.False(c3.Unmanaged); }); } + + [DatabaseTheory, DatabaseData] + public async Task ReplaceAsync_Works( + IUserRepository userRepository, + IOrganizationRepository organizationRepository, + IOrganizationUserRepository organizationUserRepository, + IGroupRepository groupRepository, + ICollectionRepository collectionRepository) + { + var user = await userRepository.CreateAsync(new User + { + Name = "Test User", + Email = $"test+{Guid.NewGuid()}@email.com", + ApiKey = "TEST", + SecurityStamp = "stamp", + }); + + var organization = await organizationRepository.CreateAsync(new Organization + { + Name = "Test Org", + PlanType = PlanType.EnterpriseAnnually, + Plan = "Test Plan", + BillingEmail = "billing@email.com" + }); + + var orgUser1 = await organizationUserRepository.CreateAsync(new OrganizationUser + { + OrganizationId = organization.Id, + UserId = user.Id, + Status = OrganizationUserStatusType.Confirmed, + }); + + var orgUser2 = await organizationUserRepository.CreateAsync(new OrganizationUser + { + OrganizationId = organization.Id, + UserId = user.Id, + Status = OrganizationUserStatusType.Confirmed, + }); + + var orgUser3 = await organizationUserRepository.CreateAsync(new OrganizationUser + { + OrganizationId = organization.Id, + UserId = user.Id, + Status = OrganizationUserStatusType.Confirmed, + }); + + var group1 = await groupRepository.CreateAsync(new Group + { + Name = "Test Group #1", + OrganizationId = organization.Id, + }); + + var group2 = await groupRepository.CreateAsync(new Group + { + Name = "Test Group #2", + OrganizationId = organization.Id, + }); + + var group3 = await groupRepository.CreateAsync(new Group + { + Name = "Test Group #3", + OrganizationId = organization.Id, + }); + + var collection = new Collection + { + Name = "Test Collection Name", + OrganizationId = organization.Id, + }; + + await collectionRepository.CreateAsync(collection, + [ + new CollectionAccessSelection { Id = group1.Id, Manage = true, HidePasswords = true, ReadOnly = false, }, + new CollectionAccessSelection { Id = group2.Id, Manage = false, HidePasswords = false, ReadOnly = true, }, + ], + [ + new CollectionAccessSelection { Id = orgUser1.Id, Manage = true, HidePasswords = false, ReadOnly = true }, + new CollectionAccessSelection { Id = orgUser2.Id, Manage = false, HidePasswords = true, ReadOnly = false }, + ] + ); + + collection.Name = "Updated Collection Name"; + + await collectionRepository.ReplaceAsync(collection, + [ + // Should delete group1 + new CollectionAccessSelection { Id = group2.Id, Manage = true, HidePasswords = true, ReadOnly = false, }, + // Should add group3 + new CollectionAccessSelection { Id = group3.Id, Manage = false, HidePasswords = false, ReadOnly = true, }, + ], + [ + // Should delete orgUser1 + new CollectionAccessSelection { Id = orgUser2.Id, Manage = false, HidePasswords = false, ReadOnly = true }, + // Should add orgUser3 + new CollectionAccessSelection { Id = orgUser3.Id, Manage = true, HidePasswords = false, ReadOnly = true }, + ] + ); + + // Assert it + var info = await collectionRepository.GetByIdWithPermissionsAsync(collection.Id, user.Id, true); + + Assert.NotNull(info); + + Assert.Equal("Updated Collection Name", info.Name); + + var groups = info.Groups.ToArray(); + + Assert.Equal(2, groups.Length); + + var actualGroup2 = Assert.Single(groups.Where(g => g.Id == group2.Id)); + + Assert.True(actualGroup2.Manage); + Assert.True(actualGroup2.HidePasswords); + Assert.False(actualGroup2.ReadOnly); + + var actualGroup3 = Assert.Single(groups.Where(g => g.Id == group3.Id)); + + Assert.False(actualGroup3.Manage); + Assert.False(actualGroup3.HidePasswords); + Assert.True(actualGroup3.ReadOnly); + + var users = info.Users.ToArray(); + + Assert.Equal(2, users.Length); + + var actualOrgUser2 = Assert.Single(users.Where(u => u.Id == orgUser2.Id)); + + Assert.False(actualOrgUser2.Manage); + Assert.False(actualOrgUser2.HidePasswords); + Assert.True(actualOrgUser2.ReadOnly); + + var actualOrgUser3 = Assert.Single(users.Where(u => u.Id == orgUser3.Id)); + + Assert.True(actualOrgUser3.Manage); + Assert.False(actualOrgUser3.HidePasswords); + Assert.True(actualOrgUser3.ReadOnly); + } }