diff --git a/src/Core/Repositories/ICollectionRepository.cs b/src/Core/Repositories/ICollectionRepository.cs index 9e2f253c9f..70bda3eb13 100644 --- a/src/Core/Repositories/ICollectionRepository.cs +++ b/src/Core/Repositories/ICollectionRepository.cs @@ -62,4 +62,6 @@ public interface ICollectionRepository : IRepository Task DeleteManyAsync(IEnumerable collectionIds); Task CreateOrUpdateAccessForManyAsync(Guid organizationId, IEnumerable collectionIds, IEnumerable users, IEnumerable groups); + + Task CreateDefaultCollectionsAsync(Guid organizationId, IEnumerable affectedOrgUserIds, string defaultCollectionName); } diff --git a/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs b/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs new file mode 100644 index 0000000000..139960ceba --- /dev/null +++ b/src/Infrastructure.Dapper/AdminConsole/Helpers/BulkResourceCreationService.cs @@ -0,0 +1,129 @@ +using System.Data; +using Bit.Core.Entities; +using Microsoft.Data.SqlClient; + +namespace Bit.Infrastructure.Dapper.AdminConsole.Helpers; + +public static class BulkResourceCreationService +{ + private const string _defaultErrorMessage = "Must have at least one record for bulk creation."; + public static async Task CreateCollectionsUsersAsync(SqlConnection connection, SqlTransaction transaction, IEnumerable collectionUsers, string errorMessage = _defaultErrorMessage) + { + using var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction); + bulkCopy.DestinationTableName = "[dbo].[CollectionUser]"; + var dataTable = BuildCollectionsUsersTable(bulkCopy, collectionUsers, errorMessage); + await bulkCopy.WriteToServerAsync(dataTable); + } + + private static DataTable BuildCollectionsUsersTable(SqlBulkCopy bulkCopy, IEnumerable collectionUsers, string errorMessage) + { + var collectionUser = collectionUsers.FirstOrDefault(); + + if (collectionUser == null) + { + throw new ApplicationException(errorMessage); + } + + var table = new DataTable("CollectionUserDataTable"); + + var collectionIdColumn = new DataColumn(nameof(collectionUser.CollectionId), collectionUser.CollectionId.GetType()); + table.Columns.Add(collectionIdColumn); + var orgUserIdColumn = new DataColumn(nameof(collectionUser.OrganizationUserId), collectionUser.OrganizationUserId.GetType()); + table.Columns.Add(orgUserIdColumn); + var readOnlyColumn = new DataColumn(nameof(collectionUser.ReadOnly), collectionUser.ReadOnly.GetType()); + table.Columns.Add(readOnlyColumn); + var hidePasswordsColumn = new DataColumn(nameof(collectionUser.HidePasswords), collectionUser.HidePasswords.GetType()); + table.Columns.Add(hidePasswordsColumn); + var manageColumn = new DataColumn(nameof(collectionUser.Manage), collectionUser.Manage.GetType()); + table.Columns.Add(manageColumn); + + foreach (DataColumn col in table.Columns) + { + bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + } + + var keys = new DataColumn[2]; + keys[0] = collectionIdColumn; + keys[1] = orgUserIdColumn; + table.PrimaryKey = keys; + + foreach (var collectionUserRecord in collectionUsers) + { + var row = table.NewRow(); + + row[collectionIdColumn] = collectionUserRecord.CollectionId; + row[orgUserIdColumn] = collectionUserRecord.OrganizationUserId; + row[readOnlyColumn] = collectionUserRecord.ReadOnly; + row[hidePasswordsColumn] = collectionUserRecord.HidePasswords; + row[manageColumn] = collectionUserRecord.Manage; + + table.Rows.Add(row); + } + + return table; + } + + public static async Task CreateCollectionsAsync(SqlConnection connection, SqlTransaction transaction, IEnumerable collections, string errorMessage = _defaultErrorMessage) + { + using var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction); + bulkCopy.DestinationTableName = "[dbo].[Collection]"; + var dataTable = BuildCollectionsTable(bulkCopy, collections, errorMessage); + await bulkCopy.WriteToServerAsync(dataTable); + } + + private static DataTable BuildCollectionsTable(SqlBulkCopy bulkCopy, IEnumerable collections, string errorMessage) + { + var collection = collections.FirstOrDefault(); + + if (collection == null) + { + throw new ApplicationException(errorMessage); + } + + var collectionsTable = new DataTable("CollectionDataTable"); + + var idColumn = new DataColumn(nameof(collection.Id), collection.Id.GetType()); + collectionsTable.Columns.Add(idColumn); + var organizationIdColumn = new DataColumn(nameof(collection.OrganizationId), collection.OrganizationId.GetType()); + collectionsTable.Columns.Add(organizationIdColumn); + var nameColumn = new DataColumn(nameof(collection.Name), collection.Name.GetType()); + collectionsTable.Columns.Add(nameColumn); + var creationDateColumn = new DataColumn(nameof(collection.CreationDate), collection.CreationDate.GetType()); + collectionsTable.Columns.Add(creationDateColumn); + var revisionDateColumn = new DataColumn(nameof(collection.RevisionDate), collection.RevisionDate.GetType()); + collectionsTable.Columns.Add(revisionDateColumn); + var externalIdColumn = new DataColumn(nameof(collection.ExternalId), typeof(string)); + collectionsTable.Columns.Add(externalIdColumn); + var typeColumn = new DataColumn(nameof(collection.Type), collection.Type.GetType()); + collectionsTable.Columns.Add(typeColumn); + var defaultUserCollectionEmailColumn = new DataColumn(nameof(collection.DefaultUserCollectionEmail), typeof(string)); + collectionsTable.Columns.Add(defaultUserCollectionEmailColumn); + + foreach (DataColumn col in collectionsTable.Columns) + { + bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName); + } + + var keys = new DataColumn[1]; + keys[0] = idColumn; + collectionsTable.PrimaryKey = keys; + + foreach (var collectionRecord in collections) + { + var row = collectionsTable.NewRow(); + + row[idColumn] = collectionRecord.Id; + row[organizationIdColumn] = collectionRecord.OrganizationId; + row[nameColumn] = collectionRecord.Name; + row[creationDateColumn] = collectionRecord.CreationDate; + row[revisionDateColumn] = collectionRecord.RevisionDate; + row[externalIdColumn] = collectionRecord.ExternalId; + row[typeColumn] = collectionRecord.Type; + row[defaultUserCollectionEmailColumn] = collectionRecord.DefaultUserCollectionEmail; + + collectionsTable.Rows.Add(row); + } + + return collectionsTable; + } +} diff --git a/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs b/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs index 6b71b57e3d..77fbdff3ae 100644 --- a/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/CollectionRepository.cs @@ -2,9 +2,11 @@ using System.Diagnostics.CodeAnalysis; using System.Text.Json; using Bit.Core.Entities; +using Bit.Core.Enums; using Bit.Core.Models.Data; using Bit.Core.Repositories; using Bit.Core.Settings; +using Bit.Infrastructure.Dapper.AdminConsole.Helpers; using Dapper; using Microsoft.Data.SqlClient; @@ -222,6 +224,8 @@ public class CollectionRepository : Repository, ICollectionRep public async Task CreateAsync(Collection obj, IEnumerable? groups, IEnumerable? users) { obj.SetNewId(); + + var objWithGroupsAndUsers = JsonSerializer.Deserialize(JsonSerializer.Serialize(obj))!; objWithGroupsAndUsers.Groups = groups != null ? groups.ToArrayTVP() : Enumerable.Empty().ToArrayTVP(); @@ -322,6 +326,100 @@ public class CollectionRepository : Repository, ICollectionRep } } + public async Task CreateDefaultCollectionsAsync(Guid organizationId, IEnumerable affectedOrgUserIds, string defaultCollectionName) + { + if (!affectedOrgUserIds.Any()) + { + return; + } + + await using var connection = new SqlConnection(ConnectionString); + connection.Open(); + await using var transaction = connection.BeginTransaction(); + try + { + var orgUserIdWithDefaultCollection = await GetOrgUserIdsWithDefaultCollectionAsync(connection, transaction, organizationId); + + var missingDefaultCollectionUserIds = affectedOrgUserIds.Except(orgUserIdWithDefaultCollection); + + var (collectionUsers, collections) = BuildDefaultCollectionForUsers(organizationId, missingDefaultCollectionUserIds, defaultCollectionName); + + if (!collectionUsers.Any() || !collections.Any()) + { + return; + } + + await BulkResourceCreationService.CreateCollectionsAsync(connection, transaction, collections); + await BulkResourceCreationService.CreateCollectionsUsersAsync(connection, transaction, collectionUsers); + + transaction.Commit(); + } + catch + { + transaction.Rollback(); + throw; + } + } + + private async Task> GetOrgUserIdsWithDefaultCollectionAsync(SqlConnection connection, SqlTransaction transaction, Guid organizationId) + { + const string sql = @" + SELECT + ou.Id AS OrganizationUserId + FROM + OrganizationUser ou + INNER JOIN + CollectionUser cu ON cu.OrganizationUserId = ou.Id + INNER JOIN + Collection c ON c.Id = cu.CollectionId + WHERE + ou.OrganizationId = @OrganizationId + AND c.Type = @CollectionType; + "; + + var organizationUserIds = await connection.QueryAsync( + sql, + new { OrganizationId = organizationId, CollectionType = CollectionType.DefaultUserCollection }, + transaction: transaction + ); + + return organizationUserIds.ToHashSet(); + } + + private (List collectionUser, List collection) BuildDefaultCollectionForUsers(Guid organizationId, IEnumerable missingDefaultCollectionUserIds, string defaultCollectionName) + { + var collectionUsers = new List(); + var collections = new List(); + + foreach (var orgUserId in missingDefaultCollectionUserIds) + { + var collectionId = Guid.NewGuid(); + + collections.Add(new Collection + { + Id = collectionId, + OrganizationId = organizationId, + Name = defaultCollectionName, + CreationDate = DateTime.UtcNow, + RevisionDate = DateTime.UtcNow, + Type = CollectionType.DefaultUserCollection, + DefaultUserCollectionEmail = null + + }); + + collectionUsers.Add(new CollectionUser + { + CollectionId = collectionId, + OrganizationUserId = orgUserId, + ReadOnly = false, + HidePasswords = false, + Manage = true, + }); + } + + return (collectionUsers, collections); + } + public class CollectionWithGroupsAndUsers : Collection { [DisallowNull] diff --git a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs index 3169f86420..9f047e4653 100644 --- a/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/CollectionRepository.cs @@ -4,6 +4,7 @@ using Bit.Core.Models.Data; using Bit.Core.Repositories; using Bit.Infrastructure.EntityFramework.Models; using Bit.Infrastructure.EntityFramework.Repositories.Queries; +using LinqToDB.EntityFrameworkCore; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; @@ -256,7 +257,8 @@ public class CollectionRepository : Repository new CollectionDetails { @@ -269,6 +271,7 @@ public class CollectionRepository : Repository Convert.ToInt32(c.ReadOnly))), HidePasswords = Convert.ToBoolean(collectionGroup.Min(c => Convert.ToInt32(c.HidePasswords))), Manage = Convert.ToBoolean(collectionGroup.Max(c => Convert.ToInt32(c.Manage))), + Type = collectionGroup.Key.Type, }) .ToList(); } @@ -281,7 +284,8 @@ public class CollectionRepository : Repository Convert.ToInt32(c.ReadOnly))), HidePasswords = Convert.ToBoolean(collectionGroup.Min(c => Convert.ToInt32(c.HidePasswords))), Manage = Convert.ToBoolean(collectionGroup.Max(c => Convert.ToInt32(c.Manage))), + Type = collectionGroup.Key.Type, }).ToListAsync(); } } @@ -711,6 +716,7 @@ public class CollectionRepository : Repository groups) { var existingCollectionGroups = await dbContext.CollectionGroups @@ -782,4 +788,88 @@ public class CollectionRepository : Repository affectedOrgUserIds, string defaultCollectionName) + { + if (!affectedOrgUserIds.Any()) + { + return; + } + + using var scope = ServiceScopeFactory.CreateScope(); + var dbContext = GetDatabaseContext(scope); + + var orgUserIdWithDefaultCollection = await GetOrgUserIdsWithDefaultCollectionAsync(dbContext, organizationId); + + var missingDefaultCollectionUserIds = affectedOrgUserIds.Except(orgUserIdWithDefaultCollection); + + var (collectionUsers, collections) = BuildDefaultCollectionForUsers(organizationId, missingDefaultCollectionUserIds, defaultCollectionName); + + if (!collectionUsers.Any() || !collections.Any()) + { + return; + } + + await dbContext.BulkCopyAsync(collections); + await dbContext.BulkCopyAsync(collectionUsers); + + await dbContext.SaveChangesAsync(); + } + + private async Task> GetOrgUserIdsWithDefaultCollectionAsync(DatabaseContext dbContext, Guid organizationId) + { + var results = await dbContext.OrganizationUsers + .Where(ou => ou.OrganizationId == organizationId) + .Join( + dbContext.CollectionUsers, + ou => ou.Id, + cu => cu.OrganizationUserId, + (ou, cu) => new { ou, cu } + ) + .Join( + dbContext.Collections, + temp => temp.cu.CollectionId, + c => c.Id, + (temp, c) => new { temp.ou, Collection = c } + ) + .Where(x => x.Collection.Type == CollectionType.DefaultUserCollection) + .Select(x => x.ou.Id) + .ToListAsync(); + + return results.ToHashSet(); + } + + private (List collectionUser, List collection) BuildDefaultCollectionForUsers(Guid organizationId, IEnumerable missingDefaultCollectionUserIds, string defaultCollectionName) + { + var collectionUsers = new List(); + var collections = new List(); + + foreach (var orgUserId in missingDefaultCollectionUserIds) + { + var collectionId = Guid.NewGuid(); + + collections.Add(new Collection + { + Id = collectionId, + OrganizationId = organizationId, + Name = defaultCollectionName, + CreationDate = DateTime.UtcNow, + RevisionDate = DateTime.UtcNow, + Type = CollectionType.DefaultUserCollection, + DefaultUserCollectionEmail = null + + }); + + collectionUsers.Add(new CollectionUser + { + CollectionId = collectionId, + OrganizationUserId = orgUserId, + ReadOnly = false, + HidePasswords = false, + Manage = true, + }); + } + + return (collectionUsers, collections); + } } diff --git a/src/Infrastructure.EntityFramework/Repositories/Queries/UserCollectionDetailsQuery.cs b/src/Infrastructure.EntityFramework/Repositories/Queries/UserCollectionDetailsQuery.cs index 6e513e8098..14dd8c876c 100644 --- a/src/Infrastructure.EntityFramework/Repositories/Queries/UserCollectionDetailsQuery.cs +++ b/src/Infrastructure.EntityFramework/Repositories/Queries/UserCollectionDetailsQuery.cs @@ -47,17 +47,18 @@ public class UserCollectionDetailsQuery : IQuery ((cu == null ? (Guid?)null : cu.CollectionId) != null || (cg == null ? (Guid?)null : cg.CollectionId) != null) select new { c, ou, o, cu, gu, g, cg }; - return query.Select(x => new CollectionDetails + return query.Select(row => new CollectionDetails { - Id = x.c.Id, - OrganizationId = x.c.OrganizationId, - Name = x.c.Name, - ExternalId = x.c.ExternalId, - CreationDate = x.c.CreationDate, - RevisionDate = x.c.RevisionDate, - ReadOnly = (bool?)x.cu.ReadOnly ?? (bool?)x.cg.ReadOnly ?? false, - HidePasswords = (bool?)x.cu.HidePasswords ?? (bool?)x.cg.HidePasswords ?? false, - Manage = (bool?)x.cu.Manage ?? (bool?)x.cg.Manage ?? false, + Id = row.c.Id, + OrganizationId = row.c.OrganizationId, + Name = row.c.Name, + ExternalId = row.c.ExternalId, + CreationDate = row.c.CreationDate, + RevisionDate = row.c.RevisionDate, + ReadOnly = (bool?)row.cu.ReadOnly ?? (bool?)row.cg.ReadOnly ?? false, + HidePasswords = (bool?)row.cu.HidePasswords ?? (bool?)row.cg.HidePasswords ?? false, + Manage = (bool?)row.cu.Manage ?? (bool?)row.cg.Manage ?? false, + Type = row.c.Type }); } } diff --git a/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsTests.cs b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsTests.cs new file mode 100644 index 0000000000..d85cc1e813 --- /dev/null +++ b/test/Infrastructure.IntegrationTest/AdminConsole/Repositories/CollectionRepository/CreateDefaultCollectionsTests.cs @@ -0,0 +1,154 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Entities; +using Bit.Core.Enums; +using Bit.Core.Repositories; +using Xunit; + +namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.CollectionRepository; + +public class CreateDefaultCollectionsTests +{ + [DatabaseTheory, DatabaseData] + public async Task CreateDefaultCollectionsAsync_ShouldCreateDefaultCollection_WhenUsersDoNotHaveDefaultCollection( + IOrganizationRepository organizationRepository, + IUserRepository userRepository, + IOrganizationUserRepository organizationUserRepository, + ICollectionRepository collectionRepository) + { + // Arrange + var organization = await organizationRepository.CreateTestOrganizationAsync(); + + var resultOrganizationUsers = await Task.WhenAll( + CreateUserForOrgAsync(userRepository, organizationUserRepository, organization), + CreateUserForOrgAsync(userRepository, organizationUserRepository, organization) + ); + + + var affectedOrgUserIds = resultOrganizationUsers.Select(organizationUser => organizationUser.Id); + var defaultCollectionName = $"default-name-{organization.Id}"; + + // Act + await collectionRepository.CreateDefaultCollectionsAsync(organization.Id, affectedOrgUserIds, defaultCollectionName); + + // Assert + await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, resultOrganizationUsers, organization.Id); + + await CleanupAsync(organizationRepository, userRepository, organization, resultOrganizationUsers); + } + + [DatabaseTheory, DatabaseData] + public async Task CreateDefaultCollectionsAsync_ShouldUpsertCreateDefaultCollection_ForUsersWithAndWithoutDefaultCollectionsExist( + IOrganizationRepository organizationRepository, + IUserRepository userRepository, + IOrganizationUserRepository organizationUserRepository, + ICollectionRepository collectionRepository) + { + // Arrange + var organization = await organizationRepository.CreateTestOrganizationAsync(); + + var arrangedOrganizationUsers = await Task.WhenAll( + CreateUserForOrgAsync(userRepository, organizationUserRepository, organization), + CreateUserForOrgAsync(userRepository, organizationUserRepository, organization) + ); + + var arrangedOrgUserIds = arrangedOrganizationUsers.Select(organizationUser => organizationUser.Id); + var defaultCollectionName = $"default-name-{organization.Id}"; + + + await CreateUsersWithExistingDefaultCollectionsAsync(collectionRepository, organization.Id, arrangedOrgUserIds, defaultCollectionName, arrangedOrganizationUsers); + + var newOrganizationUsers = new List() + { + await CreateUserForOrgAsync(userRepository, organizationUserRepository, organization) + }; + + var affectedOrgUsers = newOrganizationUsers.Concat(arrangedOrganizationUsers); + var affectedOrgUserIds = affectedOrgUsers.Select(organizationUser => organizationUser.Id); + + // Act + await collectionRepository.CreateDefaultCollectionsAsync(organization.Id, affectedOrgUserIds, defaultCollectionName); + + // Assert + await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, arrangedOrganizationUsers, organization.Id); + + await CleanupAsync(organizationRepository, userRepository, organization, affectedOrgUsers); + } + + [DatabaseTheory, DatabaseData] + public async Task CreateDefaultCollectionsAsync_ShouldNotCreateDefaultCollection_WhenUsersAlreadyHaveOne( + IOrganizationRepository organizationRepository, + IUserRepository userRepository, + IOrganizationUserRepository organizationUserRepository, + ICollectionRepository collectionRepository) + { + // Arrange + var organization = await organizationRepository.CreateTestOrganizationAsync(); + + var resultOrganizationUsers = await Task.WhenAll( + CreateUserForOrgAsync(userRepository, organizationUserRepository, organization), + CreateUserForOrgAsync(userRepository, organizationUserRepository, organization) + ); + + var affectedOrgUserIds = resultOrganizationUsers.Select(organizationUser => organizationUser.Id); + var defaultCollectionName = $"default-name-{organization.Id}"; + + + await CreateUsersWithExistingDefaultCollectionsAsync(collectionRepository, organization.Id, affectedOrgUserIds, defaultCollectionName, resultOrganizationUsers); + + // Act + await collectionRepository.CreateDefaultCollectionsAsync(organization.Id, affectedOrgUserIds, defaultCollectionName); + + // Assert + await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, resultOrganizationUsers, organization.Id); + + await CleanupAsync(organizationRepository, userRepository, organization, resultOrganizationUsers); + } + + private static async Task CreateUsersWithExistingDefaultCollectionsAsync(ICollectionRepository collectionRepository, + Guid organizationId, IEnumerable affectedOrgUserIds, string defaultCollectionName, + OrganizationUser[] resultOrganizationUsers) + { + await collectionRepository.CreateDefaultCollectionsAsync(organizationId, affectedOrgUserIds, defaultCollectionName); + + await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, resultOrganizationUsers, organizationId); + } + + private static async Task AssertAllUsersHaveOneDefaultCollectionAsync(ICollectionRepository collectionRepository, + IEnumerable organizationUsers, Guid organizationId) + { + foreach (var organizationUser in organizationUsers) + { + var collectionDetails = await collectionRepository.GetManyByUserIdAsync(organizationUser!.UserId.Value); + var defaultCollection = collectionDetails + .SingleOrDefault(collectionDetail => + collectionDetail.OrganizationId == organizationId + && collectionDetail.Type == CollectionType.DefaultUserCollection); + + Assert.NotNull(defaultCollection); + } + } + + private static async Task CreateUserForOrgAsync(IUserRepository userRepository, + IOrganizationUserRepository organizationUserRepository, Organization organization) + { + + var user = await userRepository.CreateTestUserAsync(); + var orgUser = await organizationUserRepository.CreateTestOrganizationUserAsync(organization, user); + + return orgUser; + } + + private static async Task CleanupAsync(IOrganizationRepository organizationRepository, + IUserRepository userRepository, + Organization organization, + IEnumerable organizationUsers) + { + await organizationRepository.DeleteAsync(organization); + + await userRepository.DeleteManyAsync( + organizationUsers + .Where(organizationUser => organizationUser.UserId != null) + .Select(organizationUser => new User() { Id = organizationUser.UserId.Value }) + ); + } +}