mirror of
https://github.com/bitwarden/server
synced 2025-12-14 23:33:41 +00:00
Add bulk default collection creation method (#6075)
This commit is contained in:
@@ -62,4 +62,6 @@ public interface ICollectionRepository : IRepository<Collection, Guid>
|
|||||||
Task DeleteManyAsync(IEnumerable<Guid> collectionIds);
|
Task DeleteManyAsync(IEnumerable<Guid> collectionIds);
|
||||||
Task CreateOrUpdateAccessForManyAsync(Guid organizationId, IEnumerable<Guid> collectionIds,
|
Task CreateOrUpdateAccessForManyAsync(Guid organizationId, IEnumerable<Guid> collectionIds,
|
||||||
IEnumerable<CollectionAccessSelection> users, IEnumerable<CollectionAccessSelection> groups);
|
IEnumerable<CollectionAccessSelection> users, IEnumerable<CollectionAccessSelection> groups);
|
||||||
|
|
||||||
|
Task CreateDefaultCollectionsAsync(Guid organizationId, IEnumerable<Guid> affectedOrgUserIds, string defaultCollectionName);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,129 @@
|
|||||||
|
using System.Data;
|
||||||
|
using Bit.Core.Entities;
|
||||||
|
using Microsoft.Data.SqlClient;
|
||||||
|
|
||||||
|
namespace Bit.Infrastructure.Dapper.AdminConsole.Helpers;
|
||||||
|
|
||||||
|
public static class BulkResourceCreationService
|
||||||
|
{
|
||||||
|
private const string _defaultErrorMessage = "Must have at least one record for bulk creation.";
|
||||||
|
public static async Task CreateCollectionsUsersAsync(SqlConnection connection, SqlTransaction transaction, IEnumerable<CollectionUser> collectionUsers, string errorMessage = _defaultErrorMessage)
|
||||||
|
{
|
||||||
|
using var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction);
|
||||||
|
bulkCopy.DestinationTableName = "[dbo].[CollectionUser]";
|
||||||
|
var dataTable = BuildCollectionsUsersTable(bulkCopy, collectionUsers, errorMessage);
|
||||||
|
await bulkCopy.WriteToServerAsync(dataTable);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static DataTable BuildCollectionsUsersTable(SqlBulkCopy bulkCopy, IEnumerable<CollectionUser> collectionUsers, string errorMessage)
|
||||||
|
{
|
||||||
|
var collectionUser = collectionUsers.FirstOrDefault();
|
||||||
|
|
||||||
|
if (collectionUser == null)
|
||||||
|
{
|
||||||
|
throw new ApplicationException(errorMessage);
|
||||||
|
}
|
||||||
|
|
||||||
|
var table = new DataTable("CollectionUserDataTable");
|
||||||
|
|
||||||
|
var collectionIdColumn = new DataColumn(nameof(collectionUser.CollectionId), collectionUser.CollectionId.GetType());
|
||||||
|
table.Columns.Add(collectionIdColumn);
|
||||||
|
var orgUserIdColumn = new DataColumn(nameof(collectionUser.OrganizationUserId), collectionUser.OrganizationUserId.GetType());
|
||||||
|
table.Columns.Add(orgUserIdColumn);
|
||||||
|
var readOnlyColumn = new DataColumn(nameof(collectionUser.ReadOnly), collectionUser.ReadOnly.GetType());
|
||||||
|
table.Columns.Add(readOnlyColumn);
|
||||||
|
var hidePasswordsColumn = new DataColumn(nameof(collectionUser.HidePasswords), collectionUser.HidePasswords.GetType());
|
||||||
|
table.Columns.Add(hidePasswordsColumn);
|
||||||
|
var manageColumn = new DataColumn(nameof(collectionUser.Manage), collectionUser.Manage.GetType());
|
||||||
|
table.Columns.Add(manageColumn);
|
||||||
|
|
||||||
|
foreach (DataColumn col in table.Columns)
|
||||||
|
{
|
||||||
|
bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName);
|
||||||
|
}
|
||||||
|
|
||||||
|
var keys = new DataColumn[2];
|
||||||
|
keys[0] = collectionIdColumn;
|
||||||
|
keys[1] = orgUserIdColumn;
|
||||||
|
table.PrimaryKey = keys;
|
||||||
|
|
||||||
|
foreach (var collectionUserRecord in collectionUsers)
|
||||||
|
{
|
||||||
|
var row = table.NewRow();
|
||||||
|
|
||||||
|
row[collectionIdColumn] = collectionUserRecord.CollectionId;
|
||||||
|
row[orgUserIdColumn] = collectionUserRecord.OrganizationUserId;
|
||||||
|
row[readOnlyColumn] = collectionUserRecord.ReadOnly;
|
||||||
|
row[hidePasswordsColumn] = collectionUserRecord.HidePasswords;
|
||||||
|
row[manageColumn] = collectionUserRecord.Manage;
|
||||||
|
|
||||||
|
table.Rows.Add(row);
|
||||||
|
}
|
||||||
|
|
||||||
|
return table;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static async Task CreateCollectionsAsync(SqlConnection connection, SqlTransaction transaction, IEnumerable<Collection> collections, string errorMessage = _defaultErrorMessage)
|
||||||
|
{
|
||||||
|
using var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity, transaction);
|
||||||
|
bulkCopy.DestinationTableName = "[dbo].[Collection]";
|
||||||
|
var dataTable = BuildCollectionsTable(bulkCopy, collections, errorMessage);
|
||||||
|
await bulkCopy.WriteToServerAsync(dataTable);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static DataTable BuildCollectionsTable(SqlBulkCopy bulkCopy, IEnumerable<Collection> collections, string errorMessage)
|
||||||
|
{
|
||||||
|
var collection = collections.FirstOrDefault();
|
||||||
|
|
||||||
|
if (collection == null)
|
||||||
|
{
|
||||||
|
throw new ApplicationException(errorMessage);
|
||||||
|
}
|
||||||
|
|
||||||
|
var collectionsTable = new DataTable("CollectionDataTable");
|
||||||
|
|
||||||
|
var idColumn = new DataColumn(nameof(collection.Id), collection.Id.GetType());
|
||||||
|
collectionsTable.Columns.Add(idColumn);
|
||||||
|
var organizationIdColumn = new DataColumn(nameof(collection.OrganizationId), collection.OrganizationId.GetType());
|
||||||
|
collectionsTable.Columns.Add(organizationIdColumn);
|
||||||
|
var nameColumn = new DataColumn(nameof(collection.Name), collection.Name.GetType());
|
||||||
|
collectionsTable.Columns.Add(nameColumn);
|
||||||
|
var creationDateColumn = new DataColumn(nameof(collection.CreationDate), collection.CreationDate.GetType());
|
||||||
|
collectionsTable.Columns.Add(creationDateColumn);
|
||||||
|
var revisionDateColumn = new DataColumn(nameof(collection.RevisionDate), collection.RevisionDate.GetType());
|
||||||
|
collectionsTable.Columns.Add(revisionDateColumn);
|
||||||
|
var externalIdColumn = new DataColumn(nameof(collection.ExternalId), typeof(string));
|
||||||
|
collectionsTable.Columns.Add(externalIdColumn);
|
||||||
|
var typeColumn = new DataColumn(nameof(collection.Type), collection.Type.GetType());
|
||||||
|
collectionsTable.Columns.Add(typeColumn);
|
||||||
|
var defaultUserCollectionEmailColumn = new DataColumn(nameof(collection.DefaultUserCollectionEmail), typeof(string));
|
||||||
|
collectionsTable.Columns.Add(defaultUserCollectionEmailColumn);
|
||||||
|
|
||||||
|
foreach (DataColumn col in collectionsTable.Columns)
|
||||||
|
{
|
||||||
|
bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName);
|
||||||
|
}
|
||||||
|
|
||||||
|
var keys = new DataColumn[1];
|
||||||
|
keys[0] = idColumn;
|
||||||
|
collectionsTable.PrimaryKey = keys;
|
||||||
|
|
||||||
|
foreach (var collectionRecord in collections)
|
||||||
|
{
|
||||||
|
var row = collectionsTable.NewRow();
|
||||||
|
|
||||||
|
row[idColumn] = collectionRecord.Id;
|
||||||
|
row[organizationIdColumn] = collectionRecord.OrganizationId;
|
||||||
|
row[nameColumn] = collectionRecord.Name;
|
||||||
|
row[creationDateColumn] = collectionRecord.CreationDate;
|
||||||
|
row[revisionDateColumn] = collectionRecord.RevisionDate;
|
||||||
|
row[externalIdColumn] = collectionRecord.ExternalId;
|
||||||
|
row[typeColumn] = collectionRecord.Type;
|
||||||
|
row[defaultUserCollectionEmailColumn] = collectionRecord.DefaultUserCollectionEmail;
|
||||||
|
|
||||||
|
collectionsTable.Rows.Add(row);
|
||||||
|
}
|
||||||
|
|
||||||
|
return collectionsTable;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,9 +2,11 @@
|
|||||||
using System.Diagnostics.CodeAnalysis;
|
using System.Diagnostics.CodeAnalysis;
|
||||||
using System.Text.Json;
|
using System.Text.Json;
|
||||||
using Bit.Core.Entities;
|
using Bit.Core.Entities;
|
||||||
|
using Bit.Core.Enums;
|
||||||
using Bit.Core.Models.Data;
|
using Bit.Core.Models.Data;
|
||||||
using Bit.Core.Repositories;
|
using Bit.Core.Repositories;
|
||||||
using Bit.Core.Settings;
|
using Bit.Core.Settings;
|
||||||
|
using Bit.Infrastructure.Dapper.AdminConsole.Helpers;
|
||||||
using Dapper;
|
using Dapper;
|
||||||
using Microsoft.Data.SqlClient;
|
using Microsoft.Data.SqlClient;
|
||||||
|
|
||||||
@@ -222,6 +224,8 @@ public class CollectionRepository : Repository<Collection, Guid>, ICollectionRep
|
|||||||
public async Task CreateAsync(Collection obj, IEnumerable<CollectionAccessSelection>? groups, IEnumerable<CollectionAccessSelection>? users)
|
public async Task CreateAsync(Collection obj, IEnumerable<CollectionAccessSelection>? groups, IEnumerable<CollectionAccessSelection>? users)
|
||||||
{
|
{
|
||||||
obj.SetNewId();
|
obj.SetNewId();
|
||||||
|
|
||||||
|
|
||||||
var objWithGroupsAndUsers = JsonSerializer.Deserialize<CollectionWithGroupsAndUsers>(JsonSerializer.Serialize(obj))!;
|
var objWithGroupsAndUsers = JsonSerializer.Deserialize<CollectionWithGroupsAndUsers>(JsonSerializer.Serialize(obj))!;
|
||||||
|
|
||||||
objWithGroupsAndUsers.Groups = groups != null ? groups.ToArrayTVP() : Enumerable.Empty<CollectionAccessSelection>().ToArrayTVP();
|
objWithGroupsAndUsers.Groups = groups != null ? groups.ToArrayTVP() : Enumerable.Empty<CollectionAccessSelection>().ToArrayTVP();
|
||||||
@@ -322,6 +326,100 @@ public class CollectionRepository : Repository<Collection, Guid>, ICollectionRep
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public async Task CreateDefaultCollectionsAsync(Guid organizationId, IEnumerable<Guid> affectedOrgUserIds, string defaultCollectionName)
|
||||||
|
{
|
||||||
|
if (!affectedOrgUserIds.Any())
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
await using var connection = new SqlConnection(ConnectionString);
|
||||||
|
connection.Open();
|
||||||
|
await using var transaction = connection.BeginTransaction();
|
||||||
|
try
|
||||||
|
{
|
||||||
|
var orgUserIdWithDefaultCollection = await GetOrgUserIdsWithDefaultCollectionAsync(connection, transaction, organizationId);
|
||||||
|
|
||||||
|
var missingDefaultCollectionUserIds = affectedOrgUserIds.Except(orgUserIdWithDefaultCollection);
|
||||||
|
|
||||||
|
var (collectionUsers, collections) = BuildDefaultCollectionForUsers(organizationId, missingDefaultCollectionUserIds, defaultCollectionName);
|
||||||
|
|
||||||
|
if (!collectionUsers.Any() || !collections.Any())
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
await BulkResourceCreationService.CreateCollectionsAsync(connection, transaction, collections);
|
||||||
|
await BulkResourceCreationService.CreateCollectionsUsersAsync(connection, transaction, collectionUsers);
|
||||||
|
|
||||||
|
transaction.Commit();
|
||||||
|
}
|
||||||
|
catch
|
||||||
|
{
|
||||||
|
transaction.Rollback();
|
||||||
|
throw;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async Task<HashSet<Guid>> GetOrgUserIdsWithDefaultCollectionAsync(SqlConnection connection, SqlTransaction transaction, Guid organizationId)
|
||||||
|
{
|
||||||
|
const string sql = @"
|
||||||
|
SELECT
|
||||||
|
ou.Id AS OrganizationUserId
|
||||||
|
FROM
|
||||||
|
OrganizationUser ou
|
||||||
|
INNER JOIN
|
||||||
|
CollectionUser cu ON cu.OrganizationUserId = ou.Id
|
||||||
|
INNER JOIN
|
||||||
|
Collection c ON c.Id = cu.CollectionId
|
||||||
|
WHERE
|
||||||
|
ou.OrganizationId = @OrganizationId
|
||||||
|
AND c.Type = @CollectionType;
|
||||||
|
";
|
||||||
|
|
||||||
|
var organizationUserIds = await connection.QueryAsync<Guid>(
|
||||||
|
sql,
|
||||||
|
new { OrganizationId = organizationId, CollectionType = CollectionType.DefaultUserCollection },
|
||||||
|
transaction: transaction
|
||||||
|
);
|
||||||
|
|
||||||
|
return organizationUserIds.ToHashSet();
|
||||||
|
}
|
||||||
|
|
||||||
|
private (List<CollectionUser> collectionUser, List<Collection> collection) BuildDefaultCollectionForUsers(Guid organizationId, IEnumerable<Guid> missingDefaultCollectionUserIds, string defaultCollectionName)
|
||||||
|
{
|
||||||
|
var collectionUsers = new List<CollectionUser>();
|
||||||
|
var collections = new List<Collection>();
|
||||||
|
|
||||||
|
foreach (var orgUserId in missingDefaultCollectionUserIds)
|
||||||
|
{
|
||||||
|
var collectionId = Guid.NewGuid();
|
||||||
|
|
||||||
|
collections.Add(new Collection
|
||||||
|
{
|
||||||
|
Id = collectionId,
|
||||||
|
OrganizationId = organizationId,
|
||||||
|
Name = defaultCollectionName,
|
||||||
|
CreationDate = DateTime.UtcNow,
|
||||||
|
RevisionDate = DateTime.UtcNow,
|
||||||
|
Type = CollectionType.DefaultUserCollection,
|
||||||
|
DefaultUserCollectionEmail = null
|
||||||
|
|
||||||
|
});
|
||||||
|
|
||||||
|
collectionUsers.Add(new CollectionUser
|
||||||
|
{
|
||||||
|
CollectionId = collectionId,
|
||||||
|
OrganizationUserId = orgUserId,
|
||||||
|
ReadOnly = false,
|
||||||
|
HidePasswords = false,
|
||||||
|
Manage = true,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return (collectionUsers, collections);
|
||||||
|
}
|
||||||
|
|
||||||
public class CollectionWithGroupsAndUsers : Collection
|
public class CollectionWithGroupsAndUsers : Collection
|
||||||
{
|
{
|
||||||
[DisallowNull]
|
[DisallowNull]
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ using Bit.Core.Models.Data;
|
|||||||
using Bit.Core.Repositories;
|
using Bit.Core.Repositories;
|
||||||
using Bit.Infrastructure.EntityFramework.Models;
|
using Bit.Infrastructure.EntityFramework.Models;
|
||||||
using Bit.Infrastructure.EntityFramework.Repositories.Queries;
|
using Bit.Infrastructure.EntityFramework.Repositories.Queries;
|
||||||
|
using LinqToDB.EntityFrameworkCore;
|
||||||
using Microsoft.EntityFrameworkCore;
|
using Microsoft.EntityFrameworkCore;
|
||||||
using Microsoft.Extensions.DependencyInjection;
|
using Microsoft.Extensions.DependencyInjection;
|
||||||
|
|
||||||
@@ -256,7 +257,8 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
|
|||||||
c.Name,
|
c.Name,
|
||||||
c.CreationDate,
|
c.CreationDate,
|
||||||
c.RevisionDate,
|
c.RevisionDate,
|
||||||
c.ExternalId
|
c.ExternalId,
|
||||||
|
c.Type
|
||||||
})
|
})
|
||||||
.Select(collectionGroup => new CollectionDetails
|
.Select(collectionGroup => new CollectionDetails
|
||||||
{
|
{
|
||||||
@@ -269,6 +271,7 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
|
|||||||
ReadOnly = Convert.ToBoolean(collectionGroup.Min(c => Convert.ToInt32(c.ReadOnly))),
|
ReadOnly = Convert.ToBoolean(collectionGroup.Min(c => Convert.ToInt32(c.ReadOnly))),
|
||||||
HidePasswords = Convert.ToBoolean(collectionGroup.Min(c => Convert.ToInt32(c.HidePasswords))),
|
HidePasswords = Convert.ToBoolean(collectionGroup.Min(c => Convert.ToInt32(c.HidePasswords))),
|
||||||
Manage = Convert.ToBoolean(collectionGroup.Max(c => Convert.ToInt32(c.Manage))),
|
Manage = Convert.ToBoolean(collectionGroup.Max(c => Convert.ToInt32(c.Manage))),
|
||||||
|
Type = collectionGroup.Key.Type,
|
||||||
})
|
})
|
||||||
.ToList();
|
.ToList();
|
||||||
}
|
}
|
||||||
@@ -281,7 +284,8 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
|
|||||||
c.Name,
|
c.Name,
|
||||||
c.CreationDate,
|
c.CreationDate,
|
||||||
c.RevisionDate,
|
c.RevisionDate,
|
||||||
c.ExternalId
|
c.ExternalId,
|
||||||
|
c.Type
|
||||||
} into collectionGroup
|
} into collectionGroup
|
||||||
select new CollectionDetails
|
select new CollectionDetails
|
||||||
{
|
{
|
||||||
@@ -294,6 +298,7 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
|
|||||||
ReadOnly = Convert.ToBoolean(collectionGroup.Min(c => Convert.ToInt32(c.ReadOnly))),
|
ReadOnly = Convert.ToBoolean(collectionGroup.Min(c => Convert.ToInt32(c.ReadOnly))),
|
||||||
HidePasswords = Convert.ToBoolean(collectionGroup.Min(c => Convert.ToInt32(c.HidePasswords))),
|
HidePasswords = Convert.ToBoolean(collectionGroup.Min(c => Convert.ToInt32(c.HidePasswords))),
|
||||||
Manage = Convert.ToBoolean(collectionGroup.Max(c => Convert.ToInt32(c.Manage))),
|
Manage = Convert.ToBoolean(collectionGroup.Max(c => Convert.ToInt32(c.Manage))),
|
||||||
|
Type = collectionGroup.Key.Type,
|
||||||
}).ToListAsync();
|
}).ToListAsync();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -711,6 +716,7 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private static async Task ReplaceCollectionGroupsAsync(DatabaseContext dbContext, Core.Entities.Collection collection, IEnumerable<CollectionAccessSelection> groups)
|
private static async Task ReplaceCollectionGroupsAsync(DatabaseContext dbContext, Core.Entities.Collection collection, IEnumerable<CollectionAccessSelection> groups)
|
||||||
{
|
{
|
||||||
var existingCollectionGroups = await dbContext.CollectionGroups
|
var existingCollectionGroups = await dbContext.CollectionGroups
|
||||||
@@ -782,4 +788,88 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
|
|||||||
dbContext.CollectionUsers.RemoveRange(toDelete);
|
dbContext.CollectionUsers.RemoveRange(toDelete);
|
||||||
// SaveChangesAsync is expected to be called outside this method
|
// SaveChangesAsync is expected to be called outside this method
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public async Task CreateDefaultCollectionsAsync(Guid organizationId, IEnumerable<Guid> affectedOrgUserIds, string defaultCollectionName)
|
||||||
|
{
|
||||||
|
if (!affectedOrgUserIds.Any())
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
using var scope = ServiceScopeFactory.CreateScope();
|
||||||
|
var dbContext = GetDatabaseContext(scope);
|
||||||
|
|
||||||
|
var orgUserIdWithDefaultCollection = await GetOrgUserIdsWithDefaultCollectionAsync(dbContext, organizationId);
|
||||||
|
|
||||||
|
var missingDefaultCollectionUserIds = affectedOrgUserIds.Except(orgUserIdWithDefaultCollection);
|
||||||
|
|
||||||
|
var (collectionUsers, collections) = BuildDefaultCollectionForUsers(organizationId, missingDefaultCollectionUserIds, defaultCollectionName);
|
||||||
|
|
||||||
|
if (!collectionUsers.Any() || !collections.Any())
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
await dbContext.BulkCopyAsync(collections);
|
||||||
|
await dbContext.BulkCopyAsync(collectionUsers);
|
||||||
|
|
||||||
|
await dbContext.SaveChangesAsync();
|
||||||
|
}
|
||||||
|
|
||||||
|
private async Task<HashSet<Guid>> GetOrgUserIdsWithDefaultCollectionAsync(DatabaseContext dbContext, Guid organizationId)
|
||||||
|
{
|
||||||
|
var results = await dbContext.OrganizationUsers
|
||||||
|
.Where(ou => ou.OrganizationId == organizationId)
|
||||||
|
.Join(
|
||||||
|
dbContext.CollectionUsers,
|
||||||
|
ou => ou.Id,
|
||||||
|
cu => cu.OrganizationUserId,
|
||||||
|
(ou, cu) => new { ou, cu }
|
||||||
|
)
|
||||||
|
.Join(
|
||||||
|
dbContext.Collections,
|
||||||
|
temp => temp.cu.CollectionId,
|
||||||
|
c => c.Id,
|
||||||
|
(temp, c) => new { temp.ou, Collection = c }
|
||||||
|
)
|
||||||
|
.Where(x => x.Collection.Type == CollectionType.DefaultUserCollection)
|
||||||
|
.Select(x => x.ou.Id)
|
||||||
|
.ToListAsync();
|
||||||
|
|
||||||
|
return results.ToHashSet();
|
||||||
|
}
|
||||||
|
|
||||||
|
private (List<CollectionUser> collectionUser, List<Collection> collection) BuildDefaultCollectionForUsers(Guid organizationId, IEnumerable<Guid> missingDefaultCollectionUserIds, string defaultCollectionName)
|
||||||
|
{
|
||||||
|
var collectionUsers = new List<CollectionUser>();
|
||||||
|
var collections = new List<Collection>();
|
||||||
|
|
||||||
|
foreach (var orgUserId in missingDefaultCollectionUserIds)
|
||||||
|
{
|
||||||
|
var collectionId = Guid.NewGuid();
|
||||||
|
|
||||||
|
collections.Add(new Collection
|
||||||
|
{
|
||||||
|
Id = collectionId,
|
||||||
|
OrganizationId = organizationId,
|
||||||
|
Name = defaultCollectionName,
|
||||||
|
CreationDate = DateTime.UtcNow,
|
||||||
|
RevisionDate = DateTime.UtcNow,
|
||||||
|
Type = CollectionType.DefaultUserCollection,
|
||||||
|
DefaultUserCollectionEmail = null
|
||||||
|
|
||||||
|
});
|
||||||
|
|
||||||
|
collectionUsers.Add(new CollectionUser
|
||||||
|
{
|
||||||
|
CollectionId = collectionId,
|
||||||
|
OrganizationUserId = orgUserId,
|
||||||
|
ReadOnly = false,
|
||||||
|
HidePasswords = false,
|
||||||
|
Manage = true,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return (collectionUsers, collections);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -47,17 +47,18 @@ public class UserCollectionDetailsQuery : IQuery<CollectionDetails>
|
|||||||
((cu == null ? (Guid?)null : cu.CollectionId) != null || (cg == null ? (Guid?)null : cg.CollectionId) != null)
|
((cu == null ? (Guid?)null : cu.CollectionId) != null || (cg == null ? (Guid?)null : cg.CollectionId) != null)
|
||||||
select new { c, ou, o, cu, gu, g, cg };
|
select new { c, ou, o, cu, gu, g, cg };
|
||||||
|
|
||||||
return query.Select(x => new CollectionDetails
|
return query.Select(row => new CollectionDetails
|
||||||
{
|
{
|
||||||
Id = x.c.Id,
|
Id = row.c.Id,
|
||||||
OrganizationId = x.c.OrganizationId,
|
OrganizationId = row.c.OrganizationId,
|
||||||
Name = x.c.Name,
|
Name = row.c.Name,
|
||||||
ExternalId = x.c.ExternalId,
|
ExternalId = row.c.ExternalId,
|
||||||
CreationDate = x.c.CreationDate,
|
CreationDate = row.c.CreationDate,
|
||||||
RevisionDate = x.c.RevisionDate,
|
RevisionDate = row.c.RevisionDate,
|
||||||
ReadOnly = (bool?)x.cu.ReadOnly ?? (bool?)x.cg.ReadOnly ?? false,
|
ReadOnly = (bool?)row.cu.ReadOnly ?? (bool?)row.cg.ReadOnly ?? false,
|
||||||
HidePasswords = (bool?)x.cu.HidePasswords ?? (bool?)x.cg.HidePasswords ?? false,
|
HidePasswords = (bool?)row.cu.HidePasswords ?? (bool?)row.cg.HidePasswords ?? false,
|
||||||
Manage = (bool?)x.cu.Manage ?? (bool?)x.cg.Manage ?? false,
|
Manage = (bool?)row.cu.Manage ?? (bool?)row.cg.Manage ?? false,
|
||||||
|
Type = row.c.Type
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,154 @@
|
|||||||
|
using Bit.Core.AdminConsole.Entities;
|
||||||
|
using Bit.Core.Entities;
|
||||||
|
using Bit.Core.Enums;
|
||||||
|
using Bit.Core.Repositories;
|
||||||
|
using Xunit;
|
||||||
|
|
||||||
|
namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.CollectionRepository;
|
||||||
|
|
||||||
|
public class CreateDefaultCollectionsTests
|
||||||
|
{
|
||||||
|
[DatabaseTheory, DatabaseData]
|
||||||
|
public async Task CreateDefaultCollectionsAsync_ShouldCreateDefaultCollection_WhenUsersDoNotHaveDefaultCollection(
|
||||||
|
IOrganizationRepository organizationRepository,
|
||||||
|
IUserRepository userRepository,
|
||||||
|
IOrganizationUserRepository organizationUserRepository,
|
||||||
|
ICollectionRepository collectionRepository)
|
||||||
|
{
|
||||||
|
// Arrange
|
||||||
|
var organization = await organizationRepository.CreateTestOrganizationAsync();
|
||||||
|
|
||||||
|
var resultOrganizationUsers = await Task.WhenAll(
|
||||||
|
CreateUserForOrgAsync(userRepository, organizationUserRepository, organization),
|
||||||
|
CreateUserForOrgAsync(userRepository, organizationUserRepository, organization)
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
var affectedOrgUserIds = resultOrganizationUsers.Select(organizationUser => organizationUser.Id);
|
||||||
|
var defaultCollectionName = $"default-name-{organization.Id}";
|
||||||
|
|
||||||
|
// Act
|
||||||
|
await collectionRepository.CreateDefaultCollectionsAsync(organization.Id, affectedOrgUserIds, defaultCollectionName);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, resultOrganizationUsers, organization.Id);
|
||||||
|
|
||||||
|
await CleanupAsync(organizationRepository, userRepository, organization, resultOrganizationUsers);
|
||||||
|
}
|
||||||
|
|
||||||
|
[DatabaseTheory, DatabaseData]
|
||||||
|
public async Task CreateDefaultCollectionsAsync_ShouldUpsertCreateDefaultCollection_ForUsersWithAndWithoutDefaultCollectionsExist(
|
||||||
|
IOrganizationRepository organizationRepository,
|
||||||
|
IUserRepository userRepository,
|
||||||
|
IOrganizationUserRepository organizationUserRepository,
|
||||||
|
ICollectionRepository collectionRepository)
|
||||||
|
{
|
||||||
|
// Arrange
|
||||||
|
var organization = await organizationRepository.CreateTestOrganizationAsync();
|
||||||
|
|
||||||
|
var arrangedOrganizationUsers = await Task.WhenAll(
|
||||||
|
CreateUserForOrgAsync(userRepository, organizationUserRepository, organization),
|
||||||
|
CreateUserForOrgAsync(userRepository, organizationUserRepository, organization)
|
||||||
|
);
|
||||||
|
|
||||||
|
var arrangedOrgUserIds = arrangedOrganizationUsers.Select(organizationUser => organizationUser.Id);
|
||||||
|
var defaultCollectionName = $"default-name-{organization.Id}";
|
||||||
|
|
||||||
|
|
||||||
|
await CreateUsersWithExistingDefaultCollectionsAsync(collectionRepository, organization.Id, arrangedOrgUserIds, defaultCollectionName, arrangedOrganizationUsers);
|
||||||
|
|
||||||
|
var newOrganizationUsers = new List<OrganizationUser>()
|
||||||
|
{
|
||||||
|
await CreateUserForOrgAsync(userRepository, organizationUserRepository, organization)
|
||||||
|
};
|
||||||
|
|
||||||
|
var affectedOrgUsers = newOrganizationUsers.Concat(arrangedOrganizationUsers);
|
||||||
|
var affectedOrgUserIds = affectedOrgUsers.Select(organizationUser => organizationUser.Id);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
await collectionRepository.CreateDefaultCollectionsAsync(organization.Id, affectedOrgUserIds, defaultCollectionName);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, arrangedOrganizationUsers, organization.Id);
|
||||||
|
|
||||||
|
await CleanupAsync(organizationRepository, userRepository, organization, affectedOrgUsers);
|
||||||
|
}
|
||||||
|
|
||||||
|
[DatabaseTheory, DatabaseData]
|
||||||
|
public async Task CreateDefaultCollectionsAsync_ShouldNotCreateDefaultCollection_WhenUsersAlreadyHaveOne(
|
||||||
|
IOrganizationRepository organizationRepository,
|
||||||
|
IUserRepository userRepository,
|
||||||
|
IOrganizationUserRepository organizationUserRepository,
|
||||||
|
ICollectionRepository collectionRepository)
|
||||||
|
{
|
||||||
|
// Arrange
|
||||||
|
var organization = await organizationRepository.CreateTestOrganizationAsync();
|
||||||
|
|
||||||
|
var resultOrganizationUsers = await Task.WhenAll(
|
||||||
|
CreateUserForOrgAsync(userRepository, organizationUserRepository, organization),
|
||||||
|
CreateUserForOrgAsync(userRepository, organizationUserRepository, organization)
|
||||||
|
);
|
||||||
|
|
||||||
|
var affectedOrgUserIds = resultOrganizationUsers.Select(organizationUser => organizationUser.Id);
|
||||||
|
var defaultCollectionName = $"default-name-{organization.Id}";
|
||||||
|
|
||||||
|
|
||||||
|
await CreateUsersWithExistingDefaultCollectionsAsync(collectionRepository, organization.Id, affectedOrgUserIds, defaultCollectionName, resultOrganizationUsers);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
await collectionRepository.CreateDefaultCollectionsAsync(organization.Id, affectedOrgUserIds, defaultCollectionName);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, resultOrganizationUsers, organization.Id);
|
||||||
|
|
||||||
|
await CleanupAsync(organizationRepository, userRepository, organization, resultOrganizationUsers);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static async Task CreateUsersWithExistingDefaultCollectionsAsync(ICollectionRepository collectionRepository,
|
||||||
|
Guid organizationId, IEnumerable<Guid> affectedOrgUserIds, string defaultCollectionName,
|
||||||
|
OrganizationUser[] resultOrganizationUsers)
|
||||||
|
{
|
||||||
|
await collectionRepository.CreateDefaultCollectionsAsync(organizationId, affectedOrgUserIds, defaultCollectionName);
|
||||||
|
|
||||||
|
await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, resultOrganizationUsers, organizationId);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static async Task AssertAllUsersHaveOneDefaultCollectionAsync(ICollectionRepository collectionRepository,
|
||||||
|
IEnumerable<OrganizationUser> organizationUsers, Guid organizationId)
|
||||||
|
{
|
||||||
|
foreach (var organizationUser in organizationUsers)
|
||||||
|
{
|
||||||
|
var collectionDetails = await collectionRepository.GetManyByUserIdAsync(organizationUser!.UserId.Value);
|
||||||
|
var defaultCollection = collectionDetails
|
||||||
|
.SingleOrDefault(collectionDetail =>
|
||||||
|
collectionDetail.OrganizationId == organizationId
|
||||||
|
&& collectionDetail.Type == CollectionType.DefaultUserCollection);
|
||||||
|
|
||||||
|
Assert.NotNull(defaultCollection);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static async Task<OrganizationUser> CreateUserForOrgAsync(IUserRepository userRepository,
|
||||||
|
IOrganizationUserRepository organizationUserRepository, Organization organization)
|
||||||
|
{
|
||||||
|
|
||||||
|
var user = await userRepository.CreateTestUserAsync();
|
||||||
|
var orgUser = await organizationUserRepository.CreateTestOrganizationUserAsync(organization, user);
|
||||||
|
|
||||||
|
return orgUser;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static async Task CleanupAsync(IOrganizationRepository organizationRepository,
|
||||||
|
IUserRepository userRepository,
|
||||||
|
Organization organization,
|
||||||
|
IEnumerable<OrganizationUser> organizationUsers)
|
||||||
|
{
|
||||||
|
await organizationRepository.DeleteAsync(organization);
|
||||||
|
|
||||||
|
await userRepository.DeleteManyAsync(
|
||||||
|
organizationUsers
|
||||||
|
.Where(organizationUser => organizationUser.UserId != null)
|
||||||
|
.Select(organizationUser => new User() { Id = organizationUser.UserId.Value })
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user