1
0
mirror of https://github.com/bitwarden/server synced 2026-01-26 14:23:21 +00:00

Remove redundant OrganizationId column; remove private read method used by bulk insert

This commit is contained in:
Thomas Rittson
2026-01-01 09:30:24 +10:00
parent f668a0ce0a
commit 7ea237f5d5
15 changed files with 79 additions and 138 deletions

View File

@@ -1,8 +1,7 @@
namespace Bit.Core.Entities;
namespace Bit.Core.Entities;
public class DefaultCollectionSemaphore
{
public Guid OrganizationId { get; set; }
public Guid OrganizationUserId { get; set; }
public DateTime CreationDate { get; set; } = DateTime.UtcNow;
}

View File

@@ -82,9 +82,15 @@ public interface ICollectionRepository : IRepository<Collection, Guid>
Task UpsertDefaultCollectionsBulkAsync(Guid organizationId, IEnumerable<Guid> organizationUserIds, string defaultCollectionName);
/// <summary>
/// Gets organization user IDs that have default collection semaphore entries for the specified organization.
/// Gets default collection semaphores for the given organizationUserIds.
/// If an organizationUserId is missing from the result set, they do not have a semaphore set.
/// </summary>
/// <param name="organizationId">The Organization ID.</param>
/// <param name="organizationUserIds">The organization User IDs to check semaphores for.</param>
/// <returns>Collection of organization user IDs that have default collection semaphores.</returns>
Task<IEnumerable<Guid>> GetDefaultCollectionSemaphoresAsync(Guid organizationId);
/// <remarks>
/// The semaphore table is used to ensure that an organizationUser can only have 1 default collection.
/// (That is, a user may only have 1 default collection per organization.)
/// If a semaphore is returned, that user already has a default collection for that organization.
/// </remarks>
Task<HashSet<Guid>> GetDefaultCollectionSemaphoresAsync(IEnumerable<Guid> organizationUserIds);
}

View File

@@ -391,15 +391,15 @@ public class CollectionRepository : Repository<Collection, Guid>, ICollectionRep
return;
}
var orgUserIdWithDefaultCollection = await GetDefaultCollectionSemaphoresAsync(organizationUserIds);
var missingDefaultCollectionUserIds = organizationUserIds.Except(orgUserIdWithDefaultCollection);
await using var connection = new SqlConnection(ConnectionString);
connection.Open();
await using var transaction = connection.BeginTransaction();
try
{
var orgUserIdWithDefaultCollection = await GetOrgUserIdsWithDefaultCollectionAsync(connection, transaction, organizationId);
var missingDefaultCollectionUserIds = organizationUserIds.Except(orgUserIdWithDefaultCollection);
var (collectionUsers, collections) = BuildDefaultCollectionForUsers(organizationId, missingDefaultCollectionUserIds, defaultCollectionName);
if (!collectionUsers.Any() || !collections.Any())
@@ -412,7 +412,6 @@ public class CollectionRepository : Repository<Collection, Guid>, ICollectionRep
var now = DateTime.UtcNow;
var semaphores = collectionUsers.Select(c => new DefaultCollectionSemaphore
{
OrganizationId = organizationId,
OrganizationUserId = c.OrganizationUserId,
CreationDate = now
}).ToList();
@@ -430,37 +429,16 @@ public class CollectionRepository : Repository<Collection, Guid>, ICollectionRep
}
}
public async Task<IEnumerable<Guid>> GetDefaultCollectionSemaphoresAsync(Guid organizationId)
public async Task<HashSet<Guid>> GetDefaultCollectionSemaphoresAsync(IEnumerable<Guid> organizationUserIds)
{
using (var connection = new SqlConnection(ConnectionString))
{
var results = await connection.QueryAsync<Guid>(
"[dbo].[DefaultCollectionSemaphore_ReadByOrganizationId]",
new { OrganizationId = organizationId },
commandType: CommandType.StoredProcedure);
await using var connection = new SqlConnection(ConnectionString);
return results.ToList();
}
}
var results = await connection.QueryAsync<Guid>(
"[dbo].[DefaultCollectionSemaphore_ReadByOrganizationUserIds]",
new { OrganizationUserIds = organizationUserIds.ToGuidIdArrayTVP() },
commandType: CommandType.StoredProcedure);
private async Task<HashSet<Guid>> GetOrgUserIdsWithDefaultCollectionAsync(SqlConnection connection, SqlTransaction transaction, Guid organizationId)
{
const string sql = @"
SELECT
OrganizationUserId
FROM
[DefaultCollectionSemaphore] dcs
WHERE
OrganizationId = @OrganizationId
";
var organizationUserIds = await connection.QueryAsync<Guid>(
sql,
new { OrganizationId = organizationId, CollectionType = CollectionType.DefaultUserCollection },
transaction: transaction
);
return organizationUserIds.ToHashSet();
return results.ToHashSet();
}
private (List<CollectionUser> collectionUser, List<Collection> collection) BuildDefaultCollectionForUsers(Guid organizationId, IEnumerable<Guid> missingDefaultCollectionUserIds, string defaultCollectionName)
@@ -506,8 +484,7 @@ public class CollectionRepository : Repository<Collection, Guid>, ICollectionRep
// Sort by composite key to reduce deadlocks
var sortedSemaphores = semaphores
.OrderBy(s => s.OrganizationId)
.ThenBy(s => s.OrganizationUserId)
.OrderBy(s => s.OrganizationUserId)
.ToList();
using var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.KeepIdentity & SqlBulkCopyOptions.CheckConstraints, transaction);
@@ -518,8 +495,6 @@ public class CollectionRepository : Repository<Collection, Guid>, ICollectionRep
var dataTable = new DataTable("DefaultCollectionSemaphoreDataTable");
var organizationIdColumn = new DataColumn(nameof(DefaultCollectionSemaphore.OrganizationId), typeof(Guid));
dataTable.Columns.Add(organizationIdColumn);
var organizationUserIdColumn = new DataColumn(nameof(DefaultCollectionSemaphore.OrganizationUserId), typeof(Guid));
dataTable.Columns.Add(organizationUserIdColumn);
var creationDateColumn = new DataColumn(nameof(DefaultCollectionSemaphore.CreationDate), typeof(DateTime));
@@ -530,15 +505,13 @@ public class CollectionRepository : Repository<Collection, Guid>, ICollectionRep
bulkCopy.ColumnMappings.Add(col.ColumnName, col.ColumnName);
}
var keys = new DataColumn[2];
keys[0] = organizationIdColumn;
keys[1] = organizationUserIdColumn;
var keys = new DataColumn[1];
keys[0] = organizationUserIdColumn;
dataTable.PrimaryKey = keys;
foreach (var semaphore in sortedSemaphores)
{
var row = dataTable.NewRow();
row[organizationIdColumn] = semaphore.OrganizationId;
row[organizationUserIdColumn] = semaphore.OrganizationUserId;
row[creationDateColumn] = semaphore.CreationDate;
dataTable.Rows.Add(row);

View File

@@ -1,4 +1,4 @@
using Bit.Infrastructure.EntityFramework.AdminConsole.Models;
using Bit.Infrastructure.EntityFramework.AdminConsole.Models;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Metadata.Builders;
@@ -9,15 +9,7 @@ public class DefaultCollectionSemaphoreEntityTypeConfiguration : IEntityTypeConf
public void Configure(EntityTypeBuilder<DefaultCollectionSemaphore> builder)
{
builder
.HasKey(dcs => new { dcs.OrganizationId, dcs.OrganizationUserId });
// Cascade behavior: Organization -> OrganizationUser (CASCADE) -> DefaultCollectionSemaphore (CASCADE)
// Organization FK uses NoAction to avoid competing cascade paths
builder
.HasOne(dcs => dcs.Organization)
.WithMany()
.HasForeignKey(dcs => dcs.OrganizationId)
.OnDelete(DeleteBehavior.NoAction);
.HasKey(dcs => new { dcs.OrganizationUserId });
// OrganizationUser FK cascades deletions to ensure automatic cleanup
builder

View File

@@ -1,11 +1,10 @@
using AutoMapper;
using AutoMapper;
using Bit.Infrastructure.EntityFramework.Models;
namespace Bit.Infrastructure.EntityFramework.AdminConsole.Models;
public class DefaultCollectionSemaphore : Core.Entities.DefaultCollectionSemaphore
{
public virtual Organization? Organization { get; set; }
public virtual OrganizationUser? OrganizationUser { get; set; }
}
@@ -14,7 +13,6 @@ public class DefaultCollectionSemaphoreMapperProfile : Profile
public DefaultCollectionSemaphoreMapperProfile()
{
CreateMap<Core.Entities.DefaultCollectionSemaphore, DefaultCollectionSemaphore>()
.ForMember(dcs => dcs.Organization, opt => opt.Ignore())
.ForMember(dcs => dcs.OrganizationUser, opt => opt.Ignore())
.ReverseMap();
}

View File

@@ -803,12 +803,12 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
return;
}
var orgUserIdWithDefaultCollection = await GetDefaultCollectionSemaphoresAsync(organizationUserIds);
var missingDefaultCollectionUserIds = organizationUserIds.Except(orgUserIdWithDefaultCollection);
using var scope = ServiceScopeFactory.CreateScope();
var dbContext = GetDatabaseContext(scope);
var orgUserIdWithDefaultCollection = await GetOrgUserIdsWithDefaultCollectionAsync(dbContext, organizationId);
var missingDefaultCollectionUserIds = organizationUserIds.Except(orgUserIdWithDefaultCollection);
var (collectionUsers, collections) = BuildDefaultCollectionForUsers(organizationId, missingDefaultCollectionUserIds, defaultCollectionName);
if (!collectionUsers.Any() || !collections.Any())
@@ -821,7 +821,6 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
var now = DateTime.UtcNow;
var semaphores = collectionUsers.Select(c => new DefaultCollectionSemaphore
{
OrganizationId = organizationId,
OrganizationUserId = c.OrganizationUserId,
CreationDate = now
}).ToList();
@@ -839,27 +838,19 @@ public class CollectionRepository : Repository<Core.Entities.Collection, Collect
await CreateDefaultCollectionsAsync(organizationId, organizationUserIds, defaultCollectionName);
}
public async Task<IEnumerable<Guid>> GetDefaultCollectionSemaphoresAsync(Guid organizationId)
public async Task<HashSet<Guid>> GetDefaultCollectionSemaphoresAsync(IEnumerable<Guid> organizationUserIds)
{
var organizationUserIdsHashSet = organizationUserIds.ToHashSet();
using var scope = ServiceScopeFactory.CreateScope();
var dbContext = GetDatabaseContext(scope);
var organizationUserIds = await dbContext.DefaultCollectionSemaphores
.Where(s => s.OrganizationId == organizationId)
var result = await dbContext.DefaultCollectionSemaphores
.Where(s => organizationUserIdsHashSet.Contains(s.OrganizationUserId))
.Select(s => s.OrganizationUserId)
.ToListAsync();
return organizationUserIds;
}
private async Task<HashSet<Guid>> GetOrgUserIdsWithDefaultCollectionAsync(DatabaseContext dbContext, Guid organizationId)
{
var results = await dbContext.DefaultCollectionSemaphores
.Where(ou => ou.OrganizationId == organizationId)
.Select(x => x.OrganizationUserId)
.ToListAsync();
return results.ToHashSet();
return result.ToHashSet();
}
private (List<CollectionUser> collectionUser, List<Collection> collection) BuildDefaultCollectionForUsers(Guid organizationId, IEnumerable<Guid> missingDefaultCollectionUserIds, string defaultCollectionName)

View File

@@ -35,12 +35,10 @@ BEGIN
-- If this fails due to duplicate key, the entire transaction will be rolled back
INSERT INTO [dbo].[DefaultCollectionSemaphore]
(
[OrganizationId],
[OrganizationUserId],
[CreationDate]
)
SELECT
@OrganizationId,
ou.[OrganizationUserId],
GETUTCDATE()
FROM

View File

@@ -3,10 +3,9 @@
-- OrganizationId FK has NO ACTION to avoid competing cascade paths
CREATE TABLE [dbo].[DefaultCollectionSemaphore]
(
[OrganizationId] UNIQUEIDENTIFIER NOT NULL,
[OrganizationUserId] UNIQUEIDENTIFIER NOT NULL,
[CreationDate] DATETIME2 (7) NOT NULL,
CONSTRAINT [PK_DefaultCollectionSemaphore] PRIMARY KEY CLUSTERED ([OrganizationId] ASC, [OrganizationUserId] ASC),
CONSTRAINT [FK_DefaultCollectionSemaphore_Organization] FOREIGN KEY ([OrganizationId]) REFERENCES [dbo].[Organization] ([Id]), -- NO ACTION to avoid competing cascades
CONSTRAINT [FK_DefaultCollectionSemaphore_OrganizationUser] FOREIGN KEY ([OrganizationUserId]) REFERENCES [dbo].[OrganizationUser] ([Id]) ON DELETE CASCADE -- Cascades from OrganizationUser deletion
CONSTRAINT [PK_DefaultCollectionSemaphore] PRIMARY KEY CLUSTERED ([OrganizationUserId] ASC),
CONSTRAINT [FK_DefaultCollectionSemaphore_OrganizationUser] FOREIGN KEY ([OrganizationUserId])
REFERENCES [dbo].[OrganizationUser] ([Id]) ON DELETE CASCADE
);

View File

@@ -1,13 +0,0 @@
CREATE PROCEDURE [dbo].[DefaultCollectionSemaphore_ReadByOrganizationId]
@OrganizationId UNIQUEIDENTIFIER
AS
BEGIN
SET NOCOUNT ON
SELECT
[OrganizationUserId]
FROM
[dbo].[DefaultCollectionSemaphore]
WHERE
[OrganizationId] = @OrganizationId
END

View File

@@ -0,0 +1,13 @@
CREATE PROCEDURE [dbo].[DefaultCollectionSemaphore_ReadByOrganizationUserIds]
@OrganizationUserIds AS [dbo].[GuidIdArray] READONLY
AS
BEGIN
SET NOCOUNT ON
SELECT
[OrganizationUserId]
FROM
[dbo].[DefaultCollectionSemaphore] DCS
INNER JOIN
@OrganizationUserIds OU ON [OU].[Id] = [DCS].[OrganizationUserId]
END

View File

@@ -1,4 +1,4 @@
using Bit.Core.Enums;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Xunit;
@@ -40,8 +40,8 @@ public class CreateDefaultCollectionsTests
Assert.All(defaultCollections, c => Assert.Equal("My Items", c.Item1.Name));
Assert.All(defaultCollections, c => Assert.Equal(organization.Id, c.Item1.OrganizationId));
var semaphores = await collectionRepository.GetDefaultCollectionSemaphoresAsync(organization.Id);
Assert.Equal([orgUser1.Id, orgUser2.Id], semaphores.ToHashSet());
var semaphores = await collectionRepository.GetDefaultCollectionSemaphoresAsync([orgUser1.Id, orgUser2.Id]);
Assert.Equal([orgUser1.Id, orgUser2.Id], semaphores);
// Verify each user has exactly 1 collection with correct permissions
var orgUser1Collection = Assert.Single(defaultCollections,
@@ -100,7 +100,7 @@ public class CreateDefaultCollectionsTests
Assert.Single(defaultCollections);
var semaphores = await collectionRepository.GetDefaultCollectionSemaphoresAsync(organization.Id);
var semaphores = await collectionRepository.GetDefaultCollectionSemaphoresAsync([orgUser.Id]);
Assert.Equal([orgUser.Id], semaphores);
var access = await collectionRepository.GetManyUsersByIdAsync(defaultCollections.Single().Id);

View File

@@ -1,10 +1,4 @@
using Bit.Core.AdminConsole.Entities;
using Bit.Core.Billing.Enums;
using Bit.Core.Entities;
using Bit.Core.Enums;
using Bit.Core.Repositories;
using Bit.Infrastructure.EntityFramework.Repositories;
using Microsoft.EntityFrameworkCore;
using Bit.Core.Repositories;
using Xunit;
namespace Bit.Infrastructure.IntegrationTest.AdminConsole.Repositories.CollectionRepository;
@@ -32,14 +26,14 @@ public class DefaultCollectionSemaphoreTests
"My Items");
// Verify semaphore exists
var semaphoreBefore = await collectionRepository.GetDefaultCollectionSemaphoresAsync(organization.Id);
var semaphoreBefore = await collectionRepository.GetDefaultCollectionSemaphoresAsync([orgUser.Id]);
Assert.Single(semaphoreBefore, s => s == orgUser.Id);
// Act - Delete organization user
await organizationUserRepository.DeleteAsync(orgUser);
// Assert - Semaphore should be cascade deleted
var semaphoreAfter = await collectionRepository.GetDefaultCollectionSemaphoresAsync(organization.Id);
var semaphoreAfter = await collectionRepository.GetDefaultCollectionSemaphoresAsync([orgUser.Id]);
Assert.Empty(semaphoreAfter);
}
@@ -65,14 +59,14 @@ public class DefaultCollectionSemaphoreTests
"My Items");
// Verify semaphore exists
var semaphoreBefore = await collectionRepository.GetDefaultCollectionSemaphoresAsync(organization.Id);
var semaphoreBefore = await collectionRepository.GetDefaultCollectionSemaphoresAsync([orgUser.Id]);
Assert.Single(semaphoreBefore, s => s == orgUser.Id);
// Act - Delete organization (which cascades to OrganizationUser, which cascades to semaphore)
await organizationRepository.DeleteAsync(organization);
// Assert - Semaphore should be cascade deleted via OrganizationUser
var semaphoreAfter = await collectionRepository.GetDefaultCollectionSemaphoresAsync(organization.Id);
var semaphoreAfter = await collectionRepository.GetDefaultCollectionSemaphoresAsync([orgUser.Id]);
Assert.Empty(semaphoreAfter);
}
}

View File

@@ -31,7 +31,7 @@ public class UpsertDefaultCollectionsBulkTests
// Assert
await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, resultOrganizationUsers, organization.Id);
await AssertSempahoresCreatedAsync(collectionRepository, affectedOrgUserIds, organization.Id);
await AssertSempahoresCreatedAsync(collectionRepository, affectedOrgUserIds);
await CleanupAsync(organizationRepository, userRepository, organization, resultOrganizationUsers);
}
@@ -69,7 +69,7 @@ public class UpsertDefaultCollectionsBulkTests
// Assert
await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, arrangedOrganizationUsers, organization.Id);
await AssertSempahoresCreatedAsync(collectionRepository, affectedOrgUserIds, organization.Id);
await AssertSempahoresCreatedAsync(collectionRepository, affectedOrgUserIds);
await CleanupAsync(organizationRepository, userRepository, organization, affectedOrgUsers);
}
@@ -99,7 +99,7 @@ public class UpsertDefaultCollectionsBulkTests
// Assert
await AssertAllUsersHaveOneDefaultCollectionAsync(collectionRepository, resultOrganizationUsers, organization.Id);
await AssertSempahoresCreatedAsync(collectionRepository, affectedOrgUserIds, organization.Id);
await AssertSempahoresCreatedAsync(collectionRepository, affectedOrgUserIds);
await CleanupAsync(organizationRepository, userRepository, organization, resultOrganizationUsers);
}
@@ -138,10 +138,11 @@ public class UpsertDefaultCollectionsBulkTests
}
private static async Task AssertSempahoresCreatedAsync(ICollectionRepository collectionRepository,
IEnumerable<Guid> organizationUserIds, Guid organizationId)
IEnumerable<Guid> organizationUserIds)
{
var semaphores = await collectionRepository.GetDefaultCollectionSemaphoresAsync(organizationId);
Assert.Equal(organizationUserIds.ToHashSet(), semaphores.ToHashSet());
var organizationUserIdHashSet = organizationUserIds.ToHashSet();
var semaphores = await collectionRepository.GetDefaultCollectionSemaphoresAsync(organizationUserIdHashSet);
Assert.Equal(organizationUserIdHashSet, semaphores);
}
private static async Task CleanupAsync(IOrganizationRepository organizationRepository,

View File

@@ -5,26 +5,18 @@ IF OBJECT_ID('[dbo].[DefaultCollectionSemaphore]') IS NULL
BEGIN
CREATE TABLE [dbo].[DefaultCollectionSemaphore]
(
[OrganizationId] UNIQUEIDENTIFIER NOT NULL,
[OrganizationUserId] UNIQUEIDENTIFIER NOT NULL,
[CreationDate] DATETIME2(7) NOT NULL,
CONSTRAINT [PK_DefaultCollectionSemaphore] PRIMARY KEY CLUSTERED
(
[OrganizationId] ASC,
[OrganizationUserId] ASC
),
CONSTRAINT [FK_DefaultCollectionSemaphore_Organization] FOREIGN KEY ([OrganizationId])
REFERENCES [dbo].[Organization] ([Id]), -- NO ACTION to avoid competing cascades
CONSTRAINT [FK_DefaultCollectionSemaphore_OrganizationUser] FOREIGN KEY ([OrganizationUserId])
REFERENCES [dbo].[OrganizationUser] ([Id])
ON DELETE CASCADE -- Cascades from OrganizationUser deletion
[OrganizationUserId] UNIQUEIDENTIFIER NOT NULL,
[CreationDate] DATETIME2 (7) NOT NULL,
CONSTRAINT [PK_DefaultCollectionSemaphore] PRIMARY KEY CLUSTERED ([OrganizationUserId] ASC),
CONSTRAINT [FK_DefaultCollectionSemaphore_OrganizationUser] FOREIGN KEY ([OrganizationUserId])
REFERENCES [dbo].[OrganizationUser] ([Id]) ON DELETE CASCADE
);
END
GO
-- Create stored procedure to read semaphores by organization
CREATE OR ALTER PROCEDURE [dbo].[DefaultCollectionSemaphore_ReadByOrganizationId]
@OrganizationId UNIQUEIDENTIFIER
-- Create stored procedure to read semaphores by OrganizationUserId
CREATE OR ALTER PROCEDURE [dbo].[DefaultCollectionSemaphore_ReadByOrganizationUserIds]
@OrganizationUserIds AS [dbo].[GuidIdArray] READONLY
AS
BEGIN
SET NOCOUNT ON
@@ -32,8 +24,8 @@ BEGIN
SELECT
[OrganizationUserId]
FROM
[dbo].[DefaultCollectionSemaphore]
WHERE
[OrganizationId] = @OrganizationId
[dbo].[DefaultCollectionSemaphore] DCS
INNER JOIN
@OrganizationUserIds OU ON [OU].[Id] = [DCS].[OrganizationUserId]
END
GO

View File

@@ -31,12 +31,10 @@ BEGIN
-- If this fails due to duplicate key, the entire transaction will be rolled back
INSERT INTO [dbo].[DefaultCollectionSemaphore]
(
[OrganizationId],
[OrganizationUserId],
[CreationDate]
)
SELECT
@OrganizationId,
ou.[OrganizationUserId],
GETUTCDATE()
FROM