mirror of
https://github.com/bitwarden/server
synced 2025-12-06 00:03:34 +00:00
when ciphers are soft deleted, complete any associated security tasks (#6492)
This commit is contained in:
@@ -35,4 +35,10 @@ public interface ISecurityTaskRepository : IRepository<SecurityTask, Guid>
|
||||
/// <param name="organizationId">The id of the organization</param>
|
||||
/// <returns>A collection of security task metrics</returns>
|
||||
Task<SecurityTaskMetrics> GetTaskMetricsAsync(Guid organizationId);
|
||||
|
||||
/// <summary>
|
||||
/// Marks all tasks associated with the respective ciphers as complete.
|
||||
/// </summary>
|
||||
/// <param name="cipherIds">Collection of cipher IDs</param>
|
||||
Task MarkAsCompleteByCipherIds(IEnumerable<Guid> cipherIds);
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ public class CipherService : ICipherService
|
||||
private readonly IOrganizationRepository _organizationRepository;
|
||||
private readonly IOrganizationUserRepository _organizationUserRepository;
|
||||
private readonly ICollectionCipherRepository _collectionCipherRepository;
|
||||
private readonly ISecurityTaskRepository _securityTaskRepository;
|
||||
private readonly IPushNotificationService _pushService;
|
||||
private readonly IAttachmentStorageService _attachmentStorageService;
|
||||
private readonly IEventService _eventService;
|
||||
@@ -53,6 +54,7 @@ public class CipherService : ICipherService
|
||||
IOrganizationRepository organizationRepository,
|
||||
IOrganizationUserRepository organizationUserRepository,
|
||||
ICollectionCipherRepository collectionCipherRepository,
|
||||
ISecurityTaskRepository securityTaskRepository,
|
||||
IPushNotificationService pushService,
|
||||
IAttachmentStorageService attachmentStorageService,
|
||||
IEventService eventService,
|
||||
@@ -71,6 +73,7 @@ public class CipherService : ICipherService
|
||||
_organizationRepository = organizationRepository;
|
||||
_organizationUserRepository = organizationUserRepository;
|
||||
_collectionCipherRepository = collectionCipherRepository;
|
||||
_securityTaskRepository = securityTaskRepository;
|
||||
_pushService = pushService;
|
||||
_attachmentStorageService = attachmentStorageService;
|
||||
_eventService = eventService;
|
||||
@@ -724,6 +727,7 @@ public class CipherService : ICipherService
|
||||
cipherDetails.ArchivedDate = null;
|
||||
}
|
||||
|
||||
await _securityTaskRepository.MarkAsCompleteByCipherIds([cipherDetails.Id]);
|
||||
await _cipherRepository.UpsertAsync(cipherDetails);
|
||||
await _eventService.LogCipherEventAsync(cipherDetails, EventType.Cipher_SoftDeleted);
|
||||
|
||||
@@ -750,6 +754,8 @@ public class CipherService : ICipherService
|
||||
await _cipherRepository.SoftDeleteAsync(deletingCiphers.Select(c => c.Id), deletingUserId);
|
||||
}
|
||||
|
||||
await _securityTaskRepository.MarkAsCompleteByCipherIds(deletingCiphers.Select(c => c.Id));
|
||||
|
||||
var events = deletingCiphers.Select(c =>
|
||||
new Tuple<Cipher, EventType, DateTime?>(c, EventType.Cipher_SoftDeleted, null));
|
||||
foreach (var eventsBatch in events.Chunk(100))
|
||||
|
||||
@@ -85,4 +85,19 @@ public class SecurityTaskRepository : Repository<SecurityTask, Guid>, ISecurityT
|
||||
|
||||
return tasksList;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task MarkAsCompleteByCipherIds(IEnumerable<Guid> cipherIds)
|
||||
{
|
||||
if (!cipherIds.Any())
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
await using var connection = new SqlConnection(ConnectionString);
|
||||
await connection.ExecuteAsync(
|
||||
$"[{Schema}].[SecurityTask_MarkCompleteByCipherIds]",
|
||||
new { CipherIds = cipherIds.ToGuidIdArrayTVP() },
|
||||
commandType: CommandType.StoredProcedure);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,4 +96,24 @@ public class SecurityTaskRepository : Repository<Core.Vault.Entities.SecurityTas
|
||||
|
||||
return metrics ?? new Core.Vault.Entities.SecurityTaskMetrics(0, 0);
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task MarkAsCompleteByCipherIds(IEnumerable<Guid> cipherIds)
|
||||
{
|
||||
if (!cipherIds.Any())
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
using var scope = ServiceScopeFactory.CreateScope();
|
||||
var dbContext = GetDatabaseContext(scope);
|
||||
|
||||
var cipherIdsList = cipherIds.ToList();
|
||||
|
||||
await dbContext.SecurityTasks
|
||||
.Where(st => st.CipherId.HasValue && cipherIdsList.Contains(st.CipherId.Value) && st.Status != SecurityTaskStatus.Completed)
|
||||
.ExecuteUpdateAsync(st => st
|
||||
.SetProperty(s => s.Status, SecurityTaskStatus.Completed)
|
||||
.SetProperty(s => s.RevisionDate, DateTime.UtcNow));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
CREATE PROCEDURE [dbo].[SecurityTask_MarkCompleteByCipherIds]
|
||||
@CipherIds AS [dbo].[GuidIdArray] READONLY
|
||||
AS
|
||||
BEGIN
|
||||
SET NOCOUNT ON
|
||||
|
||||
UPDATE
|
||||
[dbo].[SecurityTask]
|
||||
SET
|
||||
[Status] = 1, -- completed
|
||||
[RevisionDate] = SYSUTCDATETIME()
|
||||
WHERE
|
||||
[CipherId] IN (SELECT [Id] FROM @CipherIds)
|
||||
AND [Status] <> 1 -- Not already completed
|
||||
END
|
||||
@@ -2286,6 +2286,63 @@ public class CipherServiceTests
|
||||
.PushSyncCiphersAsync(deletingUserId);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task SoftDeleteAsync_CallsMarkAsCompleteByCipherIds(
|
||||
Guid deletingUserId, CipherDetails cipherDetails, SutProvider<CipherService> sutProvider)
|
||||
{
|
||||
cipherDetails.UserId = deletingUserId;
|
||||
cipherDetails.OrganizationId = null;
|
||||
cipherDetails.DeletedDate = null;
|
||||
|
||||
sutProvider.GetDependency<IUserService>()
|
||||
.GetUserByIdAsync(deletingUserId)
|
||||
.Returns(new User
|
||||
{
|
||||
Id = deletingUserId,
|
||||
});
|
||||
|
||||
await sutProvider.Sut.SoftDeleteAsync(cipherDetails, deletingUserId);
|
||||
|
||||
await sutProvider.GetDependency<ISecurityTaskRepository>()
|
||||
.Received(1)
|
||||
.MarkAsCompleteByCipherIds(Arg.Is<IEnumerable<Guid>>(ids =>
|
||||
ids.Count() == 1 && ids.First() == cipherDetails.Id));
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[BitAutoData]
|
||||
public async Task SoftDeleteManyAsync_CallsMarkAsCompleteByCipherIds(
|
||||
Guid deletingUserId, List<CipherDetails> ciphers, SutProvider<CipherService> sutProvider)
|
||||
{
|
||||
var cipherIds = ciphers.Select(c => c.Id).ToArray();
|
||||
|
||||
foreach (var cipher in ciphers)
|
||||
{
|
||||
cipher.UserId = deletingUserId;
|
||||
cipher.OrganizationId = null;
|
||||
cipher.Edit = true;
|
||||
cipher.DeletedDate = null;
|
||||
}
|
||||
|
||||
sutProvider.GetDependency<IUserService>()
|
||||
.GetUserByIdAsync(deletingUserId)
|
||||
.Returns(new User
|
||||
{
|
||||
Id = deletingUserId,
|
||||
});
|
||||
sutProvider.GetDependency<ICipherRepository>()
|
||||
.GetManyByUserIdAsync(deletingUserId)
|
||||
.Returns(ciphers);
|
||||
|
||||
await sutProvider.Sut.SoftDeleteManyAsync(cipherIds, deletingUserId, null, false);
|
||||
|
||||
await sutProvider.GetDependency<ISecurityTaskRepository>()
|
||||
.Received(1)
|
||||
.MarkAsCompleteByCipherIds(Arg.Is<IEnumerable<Guid>>(ids =>
|
||||
ids.Count() == cipherIds.Length && ids.All(id => cipherIds.Contains(id))));
|
||||
}
|
||||
|
||||
private async Task AssertNoActionsAsync(SutProvider<CipherService> sutProvider)
|
||||
{
|
||||
await sutProvider.GetDependency<ICipherRepository>().DidNotReceiveWithAnyArgs().GetManyOrganizationDetailsByOrganizationIdAsync(default);
|
||||
|
||||
@@ -345,4 +345,110 @@ public class SecurityTaskRepositoryTests
|
||||
Assert.Equal(0, metrics.CompletedTasks);
|
||||
Assert.Equal(0, metrics.TotalTasks);
|
||||
}
|
||||
|
||||
[DatabaseTheory, DatabaseData]
|
||||
public async Task MarkAsCompleteByCipherIds_MarksPendingTasksAsCompleted(
|
||||
IOrganizationRepository organizationRepository,
|
||||
ICipherRepository cipherRepository,
|
||||
ISecurityTaskRepository securityTaskRepository)
|
||||
{
|
||||
var organization = await organizationRepository.CreateAsync(new Organization
|
||||
{
|
||||
Name = "Test Org",
|
||||
PlanType = PlanType.EnterpriseAnnually,
|
||||
Plan = "Test Plan",
|
||||
BillingEmail = "billing@email.com"
|
||||
});
|
||||
|
||||
var cipher1 = await cipherRepository.CreateAsync(new Cipher
|
||||
{
|
||||
Type = CipherType.Login,
|
||||
OrganizationId = organization.Id,
|
||||
Data = "",
|
||||
});
|
||||
|
||||
var cipher2 = await cipherRepository.CreateAsync(new Cipher
|
||||
{
|
||||
Type = CipherType.Login,
|
||||
OrganizationId = organization.Id,
|
||||
Data = "",
|
||||
});
|
||||
|
||||
var task1 = await securityTaskRepository.CreateAsync(new SecurityTask
|
||||
{
|
||||
OrganizationId = organization.Id,
|
||||
CipherId = cipher1.Id,
|
||||
Status = SecurityTaskStatus.Pending,
|
||||
Type = SecurityTaskType.UpdateAtRiskCredential,
|
||||
});
|
||||
|
||||
var task2 = await securityTaskRepository.CreateAsync(new SecurityTask
|
||||
{
|
||||
OrganizationId = organization.Id,
|
||||
CipherId = cipher2.Id,
|
||||
Status = SecurityTaskStatus.Pending,
|
||||
Type = SecurityTaskType.UpdateAtRiskCredential,
|
||||
});
|
||||
|
||||
await securityTaskRepository.MarkAsCompleteByCipherIds([cipher1.Id, cipher2.Id]);
|
||||
|
||||
var updatedTask1 = await securityTaskRepository.GetByIdAsync(task1.Id);
|
||||
var updatedTask2 = await securityTaskRepository.GetByIdAsync(task2.Id);
|
||||
|
||||
Assert.Equal(SecurityTaskStatus.Completed, updatedTask1.Status);
|
||||
Assert.Equal(SecurityTaskStatus.Completed, updatedTask2.Status);
|
||||
}
|
||||
|
||||
[DatabaseTheory, DatabaseData]
|
||||
public async Task MarkAsCompleteByCipherIds_OnlyUpdatesSpecifiedCiphers(
|
||||
IOrganizationRepository organizationRepository,
|
||||
ICipherRepository cipherRepository,
|
||||
ISecurityTaskRepository securityTaskRepository)
|
||||
{
|
||||
var organization = await organizationRepository.CreateAsync(new Organization
|
||||
{
|
||||
Name = "Test Org",
|
||||
PlanType = PlanType.EnterpriseAnnually,
|
||||
Plan = "Test Plan",
|
||||
BillingEmail = "billing@email.com"
|
||||
});
|
||||
|
||||
var cipher1 = await cipherRepository.CreateAsync(new Cipher
|
||||
{
|
||||
Type = CipherType.Login,
|
||||
OrganizationId = organization.Id,
|
||||
Data = "",
|
||||
});
|
||||
|
||||
var cipher2 = await cipherRepository.CreateAsync(new Cipher
|
||||
{
|
||||
Type = CipherType.Login,
|
||||
OrganizationId = organization.Id,
|
||||
Data = "",
|
||||
});
|
||||
|
||||
var taskToUpdate = await securityTaskRepository.CreateAsync(new SecurityTask
|
||||
{
|
||||
OrganizationId = organization.Id,
|
||||
CipherId = cipher1.Id,
|
||||
Status = SecurityTaskStatus.Pending,
|
||||
Type = SecurityTaskType.UpdateAtRiskCredential,
|
||||
});
|
||||
|
||||
var taskToKeep = await securityTaskRepository.CreateAsync(new SecurityTask
|
||||
{
|
||||
OrganizationId = organization.Id,
|
||||
CipherId = cipher2.Id,
|
||||
Status = SecurityTaskStatus.Pending,
|
||||
Type = SecurityTaskType.UpdateAtRiskCredential,
|
||||
});
|
||||
|
||||
await securityTaskRepository.MarkAsCompleteByCipherIds([cipher1.Id]);
|
||||
|
||||
var updatedTask = await securityTaskRepository.GetByIdAsync(taskToUpdate.Id);
|
||||
var unchangedTask = await securityTaskRepository.GetByIdAsync(taskToKeep.Id);
|
||||
|
||||
Assert.Equal(SecurityTaskStatus.Completed, updatedTask.Status);
|
||||
Assert.Equal(SecurityTaskStatus.Pending, unchangedTask.Status);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
CREATE OR ALTER PROCEDURE [dbo].[SecurityTask_MarkCompleteByCipherIds]
|
||||
@CipherIds AS [dbo].[GuidIdArray] READONLY
|
||||
AS
|
||||
BEGIN
|
||||
SET NOCOUNT ON
|
||||
|
||||
UPDATE
|
||||
[dbo].[SecurityTask]
|
||||
SET
|
||||
[Status] = 1, -- Completed
|
||||
[RevisionDate] = SYSUTCDATETIME()
|
||||
WHERE
|
||||
[CipherId] IN (SELECT [Id] FROM @CipherIds)
|
||||
AND [Status] <> 1 -- Not already completed
|
||||
END
|
||||
Reference in New Issue
Block a user