From 76de64263c94c8331586aa898e6462afb214e5b8 Mon Sep 17 00:00:00 2001 From: Nik Gilmore Date: Wed, 22 Oct 2025 16:19:43 -0700 Subject: [PATCH 1/3] [PM-22992] Check cipher revision date when handling attachments (#6451) * Add lastKnownRevisionDate to Attachment functions * Add lastKnownRevisionDate to attachment endpoints * Change lastKnownCipherRevisionDate to lastKnownRevisionDate for consistency * Add tests for RevisionDate checks in Attachment endpoints * Improve validation on lastKnownRevisionDate * Harden datetime parsing * Rename ValidateCipherLastKnownRevisionDate - removed 'Async' suffix * Cleanup and address PR feedback --- .../Vault/Controllers/CiphersController.cs | 36 ++- .../Models/Request/AttachmentRequestModel.cs | 5 + src/Core/Vault/Services/ICipherService.cs | 8 +- .../Services/Implementations/CipherService.cs | 20 +- .../Vault/Services/CipherServiceTests.cs | 236 ++++++++++++++++++ 5 files changed, 288 insertions(+), 17 deletions(-) diff --git a/src/Api/Vault/Controllers/CiphersController.cs b/src/Api/Vault/Controllers/CiphersController.cs index fe3069d8c7..46d8332926 100644 --- a/src/Api/Vault/Controllers/CiphersController.cs +++ b/src/Api/Vault/Controllers/CiphersController.cs @@ -1,6 +1,7 @@ // FIXME: Update this file to be null safe and then delete the line below #nullable disable +using System.Globalization; using System.Text.Json; using Azure.Messaging.EventGrid; using Bit.Api.Auth.Models.Request.Accounts; @@ -1366,7 +1367,7 @@ public class CiphersController : Controller } var (attachmentId, uploadUrl) = await _cipherService.CreateAttachmentForDelayedUploadAsync(cipher, - request.Key, request.FileName, request.FileSize, request.AdminRequest, user.Id); + request.Key, request.FileName, request.FileSize, request.AdminRequest, user.Id, request.LastKnownRevisionDate); return new AttachmentUploadDataResponseModel { AttachmentId = attachmentId, @@ -1419,9 +1420,11 @@ public class CiphersController : Controller throw new NotFoundException(); } + // Extract lastKnownRevisionDate from form data if present + DateTime? lastKnownRevisionDate = GetLastKnownRevisionDateFromForm(); await Request.GetFileAsync(async (stream) => { - await _cipherService.UploadFileForExistingAttachmentAsync(stream, cipher, attachmentData); + await _cipherService.UploadFileForExistingAttachmentAsync(stream, cipher, attachmentData, lastKnownRevisionDate); }); } @@ -1440,10 +1443,12 @@ public class CiphersController : Controller throw new NotFoundException(); } + // Extract lastKnownRevisionDate from form data if present + DateTime? lastKnownRevisionDate = GetLastKnownRevisionDateFromForm(); await Request.GetFileAsync(async (stream, fileName, key) => { await _cipherService.CreateAttachmentAsync(cipher, stream, fileName, key, - Request.ContentLength.GetValueOrDefault(0), user.Id); + Request.ContentLength.GetValueOrDefault(0), user.Id, false, lastKnownRevisionDate); }); return new CipherResponseModel( @@ -1469,10 +1474,13 @@ public class CiphersController : Controller throw new NotFoundException(); } + // Extract lastKnownRevisionDate from form data if present + DateTime? lastKnownRevisionDate = GetLastKnownRevisionDateFromForm(); + await Request.GetFileAsync(async (stream, fileName, key) => { await _cipherService.CreateAttachmentAsync(cipher, stream, fileName, key, - Request.ContentLength.GetValueOrDefault(0), userId, true); + Request.ContentLength.GetValueOrDefault(0), userId, true, lastKnownRevisionDate); }); return new CipherMiniResponseModel(cipher, _globalSettings, cipher.OrganizationUseTotp); @@ -1515,10 +1523,13 @@ public class CiphersController : Controller throw new NotFoundException(); } + // Extract lastKnownRevisionDate from form data if present + DateTime? lastKnownRevisionDate = GetLastKnownRevisionDateFromForm(); + await Request.GetFileAsync(async (stream, fileName, key) => { await _cipherService.CreateAttachmentShareAsync(cipher, stream, fileName, key, - Request.ContentLength.GetValueOrDefault(0), attachmentId, organizationId); + Request.ContentLength.GetValueOrDefault(0), attachmentId, organizationId, lastKnownRevisionDate); }); } @@ -1630,4 +1641,19 @@ public class CiphersController : Controller { return await _cipherRepository.GetByIdAsync(cipherId, userId); } + + private DateTime? GetLastKnownRevisionDateFromForm() + { + DateTime? lastKnownRevisionDate = null; + if (Request.Form.TryGetValue("lastKnownRevisionDate", out var dateValue)) + { + if (!DateTime.TryParse(dateValue, CultureInfo.InvariantCulture, DateTimeStyles.RoundtripKind, out var parsedDate)) + { + throw new BadRequestException("Invalid lastKnownRevisionDate format."); + } + lastKnownRevisionDate = parsedDate; + } + + return lastKnownRevisionDate; + } } diff --git a/src/Api/Vault/Models/Request/AttachmentRequestModel.cs b/src/Api/Vault/Models/Request/AttachmentRequestModel.cs index 96c66c6044..eef70bf4e4 100644 --- a/src/Api/Vault/Models/Request/AttachmentRequestModel.cs +++ b/src/Api/Vault/Models/Request/AttachmentRequestModel.cs @@ -9,4 +9,9 @@ public class AttachmentRequestModel public string FileName { get; set; } public long FileSize { get; set; } public bool AdminRequest { get; set; } = false; + + /// + /// The last known revision date of the Cipher that this attachment belongs to. + /// + public DateTime? LastKnownRevisionDate { get; set; } } diff --git a/src/Core/Vault/Services/ICipherService.cs b/src/Core/Vault/Services/ICipherService.cs index ffd79e9381..110d4b6ea4 100644 --- a/src/Core/Vault/Services/ICipherService.cs +++ b/src/Core/Vault/Services/ICipherService.cs @@ -13,11 +13,11 @@ public interface ICipherService Task SaveDetailsAsync(CipherDetails cipher, Guid savingUserId, DateTime? lastKnownRevisionDate, IEnumerable collectionIds = null, bool skipPermissionCheck = false); Task<(string attachmentId, string uploadUrl)> CreateAttachmentForDelayedUploadAsync(Cipher cipher, - string key, string fileName, long fileSize, bool adminRequest, Guid savingUserId); + string key, string fileName, long fileSize, bool adminRequest, Guid savingUserId, DateTime? lastKnownRevisionDate = null); Task CreateAttachmentAsync(Cipher cipher, Stream stream, string fileName, string key, - long requestLength, Guid savingUserId, bool orgAdmin = false); + long requestLength, Guid savingUserId, bool orgAdmin = false, DateTime? lastKnownRevisionDate = null); Task CreateAttachmentShareAsync(Cipher cipher, Stream stream, string fileName, string key, long requestLength, - string attachmentId, Guid organizationShareId); + string attachmentId, Guid organizationShareId, DateTime? lastKnownRevisionDate = null); Task DeleteAsync(CipherDetails cipherDetails, Guid deletingUserId, bool orgAdmin = false); Task DeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false); Task DeleteAttachmentAsync(Cipher cipher, string attachmentId, Guid deletingUserId, bool orgAdmin = false); @@ -34,7 +34,7 @@ public interface ICipherService Task SoftDeleteManyAsync(IEnumerable cipherIds, Guid deletingUserId, Guid? organizationId = null, bool orgAdmin = false); Task RestoreAsync(CipherDetails cipherDetails, Guid restoringUserId, bool orgAdmin = false); Task> RestoreManyAsync(IEnumerable cipherIds, Guid restoringUserId, Guid? organizationId = null, bool orgAdmin = false); - Task UploadFileForExistingAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentId); + Task UploadFileForExistingAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachmentId, DateTime? lastKnownRevisionDate = null); Task GetAttachmentDownloadDataAsync(Cipher cipher, string attachmentId); Task ValidateCipherAttachmentFile(Cipher cipher, CipherAttachment.MetaData attachmentData); Task ValidateBulkCollectionAssignmentAsync(IEnumerable collectionIds, IEnumerable cipherIds, Guid userId); diff --git a/src/Core/Vault/Services/Implementations/CipherService.cs b/src/Core/Vault/Services/Implementations/CipherService.cs index f132588e37..db458a523d 100644 --- a/src/Core/Vault/Services/Implementations/CipherService.cs +++ b/src/Core/Vault/Services/Implementations/CipherService.cs @@ -113,7 +113,7 @@ public class CipherService : ICipherService } else { - ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate); + ValidateCipherLastKnownRevisionDate(cipher, lastKnownRevisionDate); cipher.RevisionDate = DateTime.UtcNow; await _cipherRepository.ReplaceAsync(cipher); await _eventService.LogCipherEventAsync(cipher, Bit.Core.Enums.EventType.Cipher_Updated); @@ -168,7 +168,7 @@ public class CipherService : ICipherService } else { - ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate); + ValidateCipherLastKnownRevisionDate(cipher, lastKnownRevisionDate); cipher.RevisionDate = DateTime.UtcNow; await ValidateChangeInCollectionsAsync(cipher, collectionIds, savingUserId); await ValidateViewPasswordUserAsync(cipher); @@ -180,8 +180,9 @@ public class CipherService : ICipherService } } - public async Task UploadFileForExistingAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachment) + public async Task UploadFileForExistingAttachmentAsync(Stream stream, Cipher cipher, CipherAttachment.MetaData attachment, DateTime? lastKnownRevisionDate = null) { + ValidateCipherLastKnownRevisionDate(cipher, lastKnownRevisionDate); if (attachment == null) { throw new BadRequestException("Cipher attachment does not exist"); @@ -196,8 +197,9 @@ public class CipherService : ICipherService } public async Task<(string attachmentId, string uploadUrl)> CreateAttachmentForDelayedUploadAsync(Cipher cipher, - string key, string fileName, long fileSize, bool adminRequest, Guid savingUserId) + string key, string fileName, long fileSize, bool adminRequest, Guid savingUserId, DateTime? lastKnownRevisionDate = null) { + ValidateCipherLastKnownRevisionDate(cipher, lastKnownRevisionDate); await ValidateCipherEditForAttachmentAsync(cipher, savingUserId, adminRequest, fileSize); var attachmentId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false); @@ -232,8 +234,9 @@ public class CipherService : ICipherService } public async Task CreateAttachmentAsync(Cipher cipher, Stream stream, string fileName, string key, - long requestLength, Guid savingUserId, bool orgAdmin = false) + long requestLength, Guid savingUserId, bool orgAdmin = false, DateTime? lastKnownRevisionDate = null) { + ValidateCipherLastKnownRevisionDate(cipher, lastKnownRevisionDate); await ValidateCipherEditForAttachmentAsync(cipher, savingUserId, orgAdmin, requestLength); var attachmentId = Utilities.CoreHelpers.SecureRandomString(32, upper: false, special: false); @@ -284,10 +287,11 @@ public class CipherService : ICipherService } public async Task CreateAttachmentShareAsync(Cipher cipher, Stream stream, string fileName, string key, - long requestLength, string attachmentId, Guid organizationId) + long requestLength, string attachmentId, Guid organizationId, DateTime? lastKnownRevisionDate = null) { try { + ValidateCipherLastKnownRevisionDate(cipher, lastKnownRevisionDate); if (requestLength < 1) { throw new BadRequestException("No data to attach."); @@ -859,7 +863,7 @@ public class CipherService : ICipherService return NormalCipherPermissions.CanRestore(user, cipher, organizationAbility); } - private void ValidateCipherLastKnownRevisionDateAsync(Cipher cipher, DateTime? lastKnownRevisionDate) + private void ValidateCipherLastKnownRevisionDate(Cipher cipher, DateTime? lastKnownRevisionDate) { if (cipher.Id == default || !lastKnownRevisionDate.HasValue) { @@ -1007,7 +1011,7 @@ public class CipherService : ICipherService throw new BadRequestException("Not enough storage available for this organization."); } - ValidateCipherLastKnownRevisionDateAsync(cipher, lastKnownRevisionDate); + ValidateCipherLastKnownRevisionDate(cipher, lastKnownRevisionDate); } private async Task ValidateViewPasswordUserAsync(Cipher cipher) diff --git a/test/Core.Test/Vault/Services/CipherServiceTests.cs b/test/Core.Test/Vault/Services/CipherServiceTests.cs index 55db5a9143..95391f1f44 100644 --- a/test/Core.Test/Vault/Services/CipherServiceTests.cs +++ b/test/Core.Test/Vault/Services/CipherServiceTests.cs @@ -113,6 +113,242 @@ public class CipherServiceTests await sutProvider.GetDependency().Received(1).ReplaceAsync(cipherDetails); } + [Theory, BitAutoData] + public async Task CreateAttachmentAsync_WrongRevisionDate_Throws(SutProvider sutProvider, Cipher cipher, Guid savingUserId) + { + var lastKnownRevisionDate = cipher.RevisionDate.AddDays(-1); + var stream = new MemoryStream(); + var fileName = "test.txt"; + var key = "test-key"; + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.CreateAttachmentAsync(cipher, stream, fileName, key, 100, savingUserId, false, lastKnownRevisionDate)); + Assert.Contains("out of date", exception.Message); + } + + [Theory] + [BitAutoData("")] + [BitAutoData("Correct Time")] + public async Task CreateAttachmentAsync_CorrectRevisionDate_DoesNotThrow(string revisionDateString, + SutProvider sutProvider, CipherDetails cipher, Guid savingUserId) + { + var lastKnownRevisionDate = string.IsNullOrEmpty(revisionDateString) ? (DateTime?)null : cipher.RevisionDate; + var stream = new MemoryStream(new byte[100]); + var fileName = "test.txt"; + var key = "test-key"; + + // Setup cipher with user ownership + cipher.UserId = savingUserId; + cipher.OrganizationId = null; + + // Mock user storage and premium access + var user = new User { Id = savingUserId, MaxStorageGb = 1 }; + sutProvider.GetDependency() + .GetByIdAsync(savingUserId) + .Returns(user); + + sutProvider.GetDependency() + .CanAccessPremium(user) + .Returns(true); + + sutProvider.GetDependency() + .UploadNewAttachmentAsync(Arg.Any(), cipher, Arg.Any()) + .Returns(Task.CompletedTask); + + sutProvider.GetDependency() + .ValidateFileAsync(cipher, Arg.Any(), Arg.Any()) + .Returns((true, 100L)); + + sutProvider.GetDependency() + .UpdateAttachmentAsync(Arg.Any()) + .Returns(Task.CompletedTask); + + sutProvider.GetDependency() + .ReplaceAsync(Arg.Any()) + .Returns(Task.CompletedTask); + + await sutProvider.Sut.CreateAttachmentAsync(cipher, stream, fileName, key, 100, savingUserId, false, lastKnownRevisionDate); + + await sutProvider.GetDependency().Received(1) + .UploadNewAttachmentAsync(Arg.Any(), cipher, Arg.Any()); + } + + [Theory, BitAutoData] + public async Task CreateAttachmentForDelayedUploadAsync_WrongRevisionDate_Throws(SutProvider sutProvider, Cipher cipher, Guid savingUserId) + { + var lastKnownRevisionDate = cipher.RevisionDate.AddDays(-1); + var key = "test-key"; + var fileName = "test.txt"; + var fileSize = 100L; + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.CreateAttachmentForDelayedUploadAsync(cipher, key, fileName, fileSize, false, savingUserId, lastKnownRevisionDate)); + Assert.Contains("out of date", exception.Message); + } + + [Theory] + [BitAutoData("")] + [BitAutoData("Correct Time")] + public async Task CreateAttachmentForDelayedUploadAsync_CorrectRevisionDate_DoesNotThrow(string revisionDateString, + SutProvider sutProvider, CipherDetails cipher, Guid savingUserId) + { + var lastKnownRevisionDate = string.IsNullOrEmpty(revisionDateString) ? (DateTime?)null : cipher.RevisionDate; + var key = "test-key"; + var fileName = "test.txt"; + var fileSize = 100L; + + // Setup cipher with user ownership + cipher.UserId = savingUserId; + cipher.OrganizationId = null; + + // Mock user storage and premium access + var user = new User { Id = savingUserId, MaxStorageGb = 1 }; + sutProvider.GetDependency() + .GetByIdAsync(savingUserId) + .Returns(user); + + sutProvider.GetDependency() + .CanAccessPremium(user) + .Returns(true); + + sutProvider.GetDependency() + .GetAttachmentUploadUrlAsync(cipher, Arg.Any()) + .Returns("https://example.com/upload"); + + sutProvider.GetDependency() + .UpdateAttachmentAsync(Arg.Any()) + .Returns(Task.CompletedTask); + + var result = await sutProvider.Sut.CreateAttachmentForDelayedUploadAsync(cipher, key, fileName, fileSize, false, savingUserId, lastKnownRevisionDate); + + Assert.NotNull(result.attachmentId); + Assert.NotNull(result.uploadUrl); + } + + [Theory, BitAutoData] + public async Task UploadFileForExistingAttachmentAsync_WrongRevisionDate_Throws(SutProvider sutProvider, + Cipher cipher) + { + var lastKnownRevisionDate = cipher.RevisionDate.AddDays(-1); + var stream = new MemoryStream(); + var attachment = new CipherAttachment.MetaData + { + AttachmentId = "test-attachment-id", + Size = 100, + FileName = "test.txt", + Key = "test-key" + }; + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.UploadFileForExistingAttachmentAsync(stream, cipher, attachment, lastKnownRevisionDate)); + Assert.Contains("out of date", exception.Message); + } + + [Theory] + [BitAutoData("")] + [BitAutoData("Correct Time")] + public async Task UploadFileForExistingAttachmentAsync_CorrectRevisionDate_DoesNotThrow(string revisionDateString, + SutProvider sutProvider, CipherDetails cipher) + { + var lastKnownRevisionDate = string.IsNullOrEmpty(revisionDateString) ? (DateTime?)null : cipher.RevisionDate; + var stream = new MemoryStream(new byte[100]); + var attachmentId = "test-attachment-id"; + var attachment = new CipherAttachment.MetaData + { + AttachmentId = attachmentId, + Size = 100, + FileName = "test.txt", + Key = "test-key" + }; + + // Set the attachment on the cipher so ValidateCipherAttachmentFile can find it + cipher.SetAttachments(new Dictionary + { + [attachmentId] = attachment + }); + + sutProvider.GetDependency() + .UploadNewAttachmentAsync(stream, cipher, attachment) + .Returns(Task.CompletedTask); + + sutProvider.GetDependency() + .ValidateFileAsync(cipher, attachment, Arg.Any()) + .Returns((true, 100L)); + + sutProvider.GetDependency() + .UpdateAttachmentAsync(Arg.Any()) + .Returns(Task.CompletedTask); + + await sutProvider.Sut.UploadFileForExistingAttachmentAsync(stream, cipher, attachment, lastKnownRevisionDate); + + await sutProvider.GetDependency().Received(1) + .UploadNewAttachmentAsync(stream, cipher, attachment); + } + + [Theory, BitAutoData] + public async Task CreateAttachmentShareAsync_WrongRevisionDate_Throws(SutProvider sutProvider, + Cipher cipher, Guid organizationId) + { + var lastKnownRevisionDate = cipher.RevisionDate.AddDays(-1); + var stream = new MemoryStream(); + var fileName = "test.txt"; + var key = "test-key"; + var attachmentId = "attachment-id"; + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.CreateAttachmentShareAsync(cipher, stream, fileName, key, 100, attachmentId, organizationId, lastKnownRevisionDate)); + Assert.Contains("out of date", exception.Message); + } + + [Theory] + [BitAutoData("")] + [BitAutoData("Correct Time")] + public async Task CreateAttachmentShareAsync_CorrectRevisionDate_DoesNotThrow(string revisionDateString, + SutProvider sutProvider, CipherDetails cipher, Guid organizationId) + { + var lastKnownRevisionDate = string.IsNullOrEmpty(revisionDateString) ? (DateTime?)null : cipher.RevisionDate; + var stream = new MemoryStream(new byte[100]); + var fileName = "test.txt"; + var key = "test-key"; + var attachmentId = "attachment-id"; + + // Setup cipher with existing attachment (no TempMetadata) + cipher.OrganizationId = null; + cipher.SetAttachments(new Dictionary + { + [attachmentId] = new CipherAttachment.MetaData + { + AttachmentId = attachmentId, + Size = 100, + FileName = "existing.txt", + Key = "existing-key" + } + }); + + // Mock organization + var organization = new Organization + { + Id = organizationId, + MaxStorageGb = 1 + }; + sutProvider.GetDependency() + .GetByIdAsync(organizationId) + .Returns(organization); + + sutProvider.GetDependency() + .UploadShareAttachmentAsync(stream, cipher.Id, organizationId, Arg.Any()) + .Returns(Task.CompletedTask); + + sutProvider.GetDependency() + .UpdateAttachmentAsync(Arg.Any()) + .Returns(Task.CompletedTask); + + await sutProvider.Sut.CreateAttachmentShareAsync(cipher, stream, fileName, key, 100, attachmentId, organizationId, lastKnownRevisionDate); + + await sutProvider.GetDependency().Received(1) + .UploadShareAttachmentAsync(stream, cipher.Id, organizationId, Arg.Any()); + } + [Theory] [BitAutoData] public async Task SaveDetailsAsync_PersonalVault_WithOrganizationDataOwnershipPolicyEnabled_Throws( From 69f0464e05866bf8425e1aa063be35998decd9ba Mon Sep 17 00:00:00 2001 From: Brant DeBow <125889545+brant-livefront@users.noreply.github.com> Date: Thu, 23 Oct 2025 08:08:09 -0400 Subject: [PATCH 2/3] Refactor Azure Service Bus to use the organization id as a partition key (#6477) * Refactored Azure Service Bus to use the organization id as a partition key * Use null for partition key instead of empty string when organization id is null --- .../EventIntegrations/IIntegrationMessage.cs | 1 + .../EventIntegrations/IntegrationMessage.cs | 1 + .../Services/IEventIntegrationPublisher.cs | 2 +- .../AzureServiceBusService.cs | 11 +++++++---- .../EventIntegrationEventWriteService.cs | 14 ++++++++++---- .../EventIntegrationHandler.cs | 1 + .../EventIntegrations/RabbitMqService.cs | 2 +- .../IntegrationMessageTests.cs | 6 ++++++ .../EventIntegrationEventWriteServiceTests.cs | 8 +++++--- .../Services/EventIntegrationHandlerTests.cs | 18 +++++++++++++----- .../Services/IntegrationHandlerTests.cs | 3 +++ 11 files changed, 49 insertions(+), 18 deletions(-) diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/IIntegrationMessage.cs b/src/Core/AdminConsole/Models/Data/EventIntegrations/IIntegrationMessage.cs index 7a0962d89a..5b6bfe2e53 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/IIntegrationMessage.cs +++ b/src/Core/AdminConsole/Models/Data/EventIntegrations/IIntegrationMessage.cs @@ -6,6 +6,7 @@ public interface IIntegrationMessage { IntegrationType IntegrationType { get; } string MessageId { get; set; } + string? OrganizationId { get; set; } int RetryCount { get; } DateTime? DelayUntilDate { get; } void ApplyRetry(DateTime? handlerDelayUntilDate); diff --git a/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationMessage.cs b/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationMessage.cs index 11a5229f8c..b0fc2161ba 100644 --- a/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationMessage.cs +++ b/src/Core/AdminConsole/Models/Data/EventIntegrations/IntegrationMessage.cs @@ -7,6 +7,7 @@ public class IntegrationMessage : IIntegrationMessage { public IntegrationType IntegrationType { get; set; } public required string MessageId { get; set; } + public string? OrganizationId { get; set; } public required string RenderedTemplate { get; set; } public int RetryCount { get; set; } = 0; public DateTime? DelayUntilDate { get; set; } diff --git a/src/Core/AdminConsole/Services/IEventIntegrationPublisher.cs b/src/Core/AdminConsole/Services/IEventIntegrationPublisher.cs index b80b518223..4d95707e90 100644 --- a/src/Core/AdminConsole/Services/IEventIntegrationPublisher.cs +++ b/src/Core/AdminConsole/Services/IEventIntegrationPublisher.cs @@ -5,5 +5,5 @@ namespace Bit.Core.Services; public interface IEventIntegrationPublisher : IAsyncDisposable { Task PublishAsync(IIntegrationMessage message); - Task PublishEventAsync(string body); + Task PublishEventAsync(string body, string? organizationId); } diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusService.cs b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusService.cs index 4887aa3a7f..953a9bb56e 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusService.cs +++ b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/AzureServiceBusService.cs @@ -30,7 +30,8 @@ public class AzureServiceBusService : IAzureServiceBusService var serviceBusMessage = new ServiceBusMessage(json) { Subject = message.IntegrationType.ToRoutingKey(), - MessageId = message.MessageId + MessageId = message.MessageId, + PartitionKey = message.OrganizationId }; await _integrationSender.SendMessageAsync(serviceBusMessage); @@ -44,18 +45,20 @@ public class AzureServiceBusService : IAzureServiceBusService { Subject = message.IntegrationType.ToRoutingKey(), ScheduledEnqueueTime = message.DelayUntilDate ?? DateTime.UtcNow, - MessageId = message.MessageId + MessageId = message.MessageId, + PartitionKey = message.OrganizationId }; await _integrationSender.SendMessageAsync(serviceBusMessage); } - public async Task PublishEventAsync(string body) + public async Task PublishEventAsync(string body, string? organizationId) { var message = new ServiceBusMessage(body) { ContentType = "application/json", - MessageId = Guid.NewGuid().ToString() + MessageId = Guid.NewGuid().ToString(), + PartitionKey = organizationId }; await _eventSender.SendMessageAsync(message); diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationEventWriteService.cs b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationEventWriteService.cs index 309b4a8409..4ac97df763 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationEventWriteService.cs +++ b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationEventWriteService.cs @@ -14,15 +14,21 @@ public class EventIntegrationEventWriteService : IEventWriteService, IAsyncDispo public async Task CreateAsync(IEvent e) { var body = JsonSerializer.Serialize(e); - await _eventIntegrationPublisher.PublishEventAsync(body: body); + await _eventIntegrationPublisher.PublishEventAsync(body: body, organizationId: e.OrganizationId?.ToString()); } public async Task CreateManyAsync(IEnumerable events) { - var body = JsonSerializer.Serialize(events); - await _eventIntegrationPublisher.PublishEventAsync(body: body); - } + var eventList = events as IList ?? events.ToList(); + if (eventList.Count == 0) + { + return; + } + var organizationId = eventList[0].OrganizationId?.ToString(); + var body = JsonSerializer.Serialize(eventList); + await _eventIntegrationPublisher.PublishEventAsync(body: body, organizationId: organizationId); + } public async ValueTask DisposeAsync() { await _eventIntegrationPublisher.DisposeAsync(); diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationHandler.cs b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationHandler.cs index 0a8ab67554..8423652eb8 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationHandler.cs +++ b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/EventIntegrationHandler.cs @@ -57,6 +57,7 @@ public class EventIntegrationHandler( { IntegrationType = integrationType, MessageId = messageId.ToString(), + OrganizationId = organizationId.ToString(), Configuration = config, RenderedTemplate = renderedTemplate, RetryCount = 0, diff --git a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/RabbitMqService.cs b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/RabbitMqService.cs index 3e20e34200..8976530cf4 100644 --- a/src/Core/AdminConsole/Services/Implementations/EventIntegrations/RabbitMqService.cs +++ b/src/Core/AdminConsole/Services/Implementations/EventIntegrations/RabbitMqService.cs @@ -122,7 +122,7 @@ public class RabbitMqService : IRabbitMqService body: body); } - public async Task PublishEventAsync(string body) + public async Task PublishEventAsync(string body, string? organizationId) { await using var channel = await CreateChannelAsync(); var properties = new BasicProperties diff --git a/test/Core.Test/AdminConsole/Models/Data/EventIntegrations/IntegrationMessageTests.cs b/test/Core.Test/AdminConsole/Models/Data/EventIntegrations/IntegrationMessageTests.cs index edd5cd488f..71f9a15037 100644 --- a/test/Core.Test/AdminConsole/Models/Data/EventIntegrations/IntegrationMessageTests.cs +++ b/test/Core.Test/AdminConsole/Models/Data/EventIntegrations/IntegrationMessageTests.cs @@ -8,6 +8,7 @@ namespace Bit.Core.Test.Models.Data.EventIntegrations; public class IntegrationMessageTests { private const string _messageId = "TestMessageId"; + private const string _organizationId = "TestOrganizationId"; [Fact] public void ApplyRetry_IncrementsRetryCountAndSetsDelayUntilDate() @@ -16,6 +17,7 @@ public class IntegrationMessageTests { Configuration = new WebhookIntegrationConfigurationDetails(new Uri("https://localhost"), "Bearer", "AUTH-TOKEN"), MessageId = _messageId, + OrganizationId = _organizationId, RetryCount = 2, RenderedTemplate = string.Empty, DelayUntilDate = null @@ -36,6 +38,7 @@ public class IntegrationMessageTests { Configuration = new WebhookIntegrationConfigurationDetails(new Uri("https://localhost"), "Bearer", "AUTH-TOKEN"), MessageId = _messageId, + OrganizationId = _organizationId, RenderedTemplate = "This is the message", IntegrationType = IntegrationType.Webhook, RetryCount = 2, @@ -48,6 +51,7 @@ public class IntegrationMessageTests Assert.NotNull(result); Assert.Equal(message.Configuration, result.Configuration); Assert.Equal(message.MessageId, result.MessageId); + Assert.Equal(message.OrganizationId, result.OrganizationId); Assert.Equal(message.RenderedTemplate, result.RenderedTemplate); Assert.Equal(message.IntegrationType, result.IntegrationType); Assert.Equal(message.RetryCount, result.RetryCount); @@ -67,6 +71,7 @@ public class IntegrationMessageTests var message = new IntegrationMessage { MessageId = _messageId, + OrganizationId = _organizationId, RenderedTemplate = "This is the message", IntegrationType = IntegrationType.Webhook, RetryCount = 2, @@ -77,6 +82,7 @@ public class IntegrationMessageTests var result = JsonSerializer.Deserialize(json); Assert.Equal(message.MessageId, result.MessageId); + Assert.Equal(message.OrganizationId, result.OrganizationId); Assert.Equal(message.RenderedTemplate, result.RenderedTemplate); Assert.Equal(message.IntegrationType, result.IntegrationType); Assert.Equal(message.RetryCount, result.RetryCount); diff --git a/test/Core.Test/AdminConsole/Services/EventIntegrationEventWriteServiceTests.cs b/test/Core.Test/AdminConsole/Services/EventIntegrationEventWriteServiceTests.cs index 9369690d86..03f9c7764d 100644 --- a/test/Core.Test/AdminConsole/Services/EventIntegrationEventWriteServiceTests.cs +++ b/test/Core.Test/AdminConsole/Services/EventIntegrationEventWriteServiceTests.cs @@ -22,18 +22,20 @@ public class EventIntegrationEventWriteServiceTests [Theory, BitAutoData] public async Task CreateAsync_EventPublishedToEventQueue(EventMessage eventMessage) { - var expected = JsonSerializer.Serialize(eventMessage); await Subject.CreateAsync(eventMessage); await _eventIntegrationPublisher.Received(1).PublishEventAsync( - Arg.Is(body => AssertJsonStringsMatch(eventMessage, body))); + body: Arg.Is(body => AssertJsonStringsMatch(eventMessage, body)), + organizationId: Arg.Is(orgId => eventMessage.OrganizationId.ToString().Equals(orgId))); } [Theory, BitAutoData] public async Task CreateManyAsync_EventsPublishedToEventQueue(IEnumerable eventMessages) { + var eventMessage = eventMessages.First(); await Subject.CreateManyAsync(eventMessages); await _eventIntegrationPublisher.Received(1).PublishEventAsync( - Arg.Is(body => AssertJsonStringsMatch(eventMessages, body))); + body: Arg.Is(body => AssertJsonStringsMatch(eventMessages, body)), + organizationId: Arg.Is(orgId => eventMessage.OrganizationId.ToString().Equals(orgId))); } private static bool AssertJsonStringsMatch(EventMessage expected, string body) diff --git a/test/Core.Test/AdminConsole/Services/EventIntegrationHandlerTests.cs b/test/Core.Test/AdminConsole/Services/EventIntegrationHandlerTests.cs index f038fe28ef..89207a9d3a 100644 --- a/test/Core.Test/AdminConsole/Services/EventIntegrationHandlerTests.cs +++ b/test/Core.Test/AdminConsole/Services/EventIntegrationHandlerTests.cs @@ -23,6 +23,7 @@ public class EventIntegrationHandlerTests private const string _templateWithOrganization = "Org: #OrganizationName#"; private const string _templateWithUser = "#UserName#, #UserEmail#"; private const string _templateWithActingUser = "#ActingUserName#, #ActingUserEmail#"; + private static readonly Guid _organizationId = Guid.NewGuid(); private static readonly Uri _uri = new Uri("https://localhost"); private static readonly Uri _uri2 = new Uri("https://example.com"); private readonly IEventIntegrationPublisher _eventIntegrationPublisher = Substitute.For(); @@ -50,6 +51,7 @@ public class EventIntegrationHandlerTests { IntegrationType = IntegrationType.Webhook, MessageId = "TestMessageId", + OrganizationId = _organizationId.ToString(), Configuration = new WebhookIntegrationConfigurationDetails(_uri), RenderedTemplate = template, RetryCount = 0, @@ -122,6 +124,7 @@ public class EventIntegrationHandlerTests public async Task HandleEventAsync_BaseTemplateOneConfiguration_PublishesIntegrationMessage(EventMessage eventMessage) { var sutProvider = GetSutProvider(OneConfiguration(_templateBase)); + eventMessage.OrganizationId = _organizationId; await sutProvider.Sut.HandleEventAsync(eventMessage); @@ -140,6 +143,7 @@ public class EventIntegrationHandlerTests public async Task HandleEventAsync_BaseTemplateTwoConfigurations_PublishesIntegrationMessages(EventMessage eventMessage) { var sutProvider = GetSutProvider(TwoConfigurations(_templateBase)); + eventMessage.OrganizationId = _organizationId; await sutProvider.Sut.HandleEventAsync(eventMessage); @@ -164,6 +168,7 @@ public class EventIntegrationHandlerTests var user = Substitute.For(); user.Email = "test@example.com"; user.Name = "Test"; + eventMessage.OrganizationId = _organizationId; sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(user); await sutProvider.Sut.HandleEventAsync(eventMessage); @@ -183,6 +188,7 @@ public class EventIntegrationHandlerTests var sutProvider = GetSutProvider(OneConfiguration(_templateWithOrganization)); var organization = Substitute.For(); organization.Name = "Test"; + eventMessage.OrganizationId = _organizationId; sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(organization); await sutProvider.Sut.HandleEventAsync(eventMessage); @@ -205,6 +211,7 @@ public class EventIntegrationHandlerTests var user = Substitute.For(); user.Email = "test@example.com"; user.Name = "Test"; + eventMessage.OrganizationId = _organizationId; sutProvider.GetDependency().GetByIdAsync(Arg.Any()).Returns(user); await sutProvider.Sut.HandleEventAsync(eventMessage); @@ -235,6 +242,7 @@ public class EventIntegrationHandlerTests var sutProvider = GetSutProvider(ValidFilterConfiguration()); sutProvider.GetDependency().EvaluateFilterGroup( Arg.Any(), Arg.Any()).Returns(true); + eventMessage.OrganizationId = _organizationId; await sutProvider.Sut.HandleEventAsync(eventMessage); @@ -284,7 +292,7 @@ public class EventIntegrationHandlerTests $"Date: {eventMessage.Date}, Type: {eventMessage.Type}, UserId: {eventMessage.UserId}" ); await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( - AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" }))); + AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId", "OrganizationId" }))); } } @@ -301,12 +309,12 @@ public class EventIntegrationHandlerTests var expectedMessage = EventIntegrationHandlerTests.expectedMessage( $"Date: {eventMessage.Date}, Type: {eventMessage.Type}, UserId: {eventMessage.UserId}" ); - await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( - AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" }))); + await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is(AssertHelper.AssertPropertyEqual( + expectedMessage, new[] { "MessageId", "OrganizationId" }))); expectedMessage.Configuration = new WebhookIntegrationConfigurationDetails(_uri2); - await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is( - AssertHelper.AssertPropertyEqual(expectedMessage, new[] { "MessageId" }))); + await _eventIntegrationPublisher.Received(1).PublishAsync(Arg.Is(AssertHelper.AssertPropertyEqual( + expectedMessage, new[] { "MessageId", "OrganizationId" }))); } } } diff --git a/test/Core.Test/AdminConsole/Services/IntegrationHandlerTests.cs b/test/Core.Test/AdminConsole/Services/IntegrationHandlerTests.cs index aa93567538..f6f587cfd7 100644 --- a/test/Core.Test/AdminConsole/Services/IntegrationHandlerTests.cs +++ b/test/Core.Test/AdminConsole/Services/IntegrationHandlerTests.cs @@ -16,6 +16,7 @@ public class IntegrationHandlerTests { Configuration = new WebhookIntegrationConfigurationDetails(new Uri("https://localhost"), "Bearer", "AUTH-TOKEN"), MessageId = "TestMessageId", + OrganizationId = "TestOrganizationId", IntegrationType = IntegrationType.Webhook, RenderedTemplate = "Template", DelayUntilDate = null, @@ -25,6 +26,8 @@ public class IntegrationHandlerTests var result = await sut.HandleAsync(expected.ToJson()); var typedResult = Assert.IsType>(result.Message); + Assert.Equal(expected.MessageId, typedResult.MessageId); + Assert.Equal(expected.OrganizationId, typedResult.OrganizationId); Assert.Equal(expected.Configuration, typedResult.Configuration); Assert.Equal(expected.RenderedTemplate, typedResult.RenderedTemplate); Assert.Equal(expected.IntegrationType, typedResult.IntegrationType); From dd1f0a120a393624b05f4467a196a464f4adb6f7 Mon Sep 17 00:00:00 2001 From: Maciej Zieniuk <167752252+mzieniukbw@users.noreply.github.com> Date: Thu, 23 Oct 2025 14:40:57 +0200 Subject: [PATCH 3/3] Notifications service unit test coverage with small refactor (#6126) --- src/Notifications/AzureQueueHostedService.cs | 44 +-- .../Controllers/SendController.cs | 28 +- src/Notifications/HubHelpers.cs | 124 ++++++--- src/Notifications/Startup.cs | 1 + test/Notifications.Test/HubHelpersTest.cs | 250 ++++++++++++++++++ .../Notifications.Test.csproj | 2 + 6 files changed, 380 insertions(+), 69 deletions(-) create mode 100644 test/Notifications.Test/HubHelpersTest.cs diff --git a/src/Notifications/AzureQueueHostedService.cs b/src/Notifications/AzureQueueHostedService.cs index 94aa14eaf6..40dd8d22d4 100644 --- a/src/Notifications/AzureQueueHostedService.cs +++ b/src/Notifications/AzureQueueHostedService.cs @@ -1,34 +1,26 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using Azure.Storage.Queues; +using Azure.Storage.Queues; using Bit.Core.Settings; using Bit.Core.Utilities; -using Microsoft.AspNetCore.SignalR; namespace Bit.Notifications; public class AzureQueueHostedService : IHostedService, IDisposable { private readonly ILogger _logger; - private readonly IHubContext _hubContext; - private readonly IHubContext _anonymousHubContext; + private readonly HubHelpers _hubHelpers; private readonly GlobalSettings _globalSettings; - private Task _executingTask; - private CancellationTokenSource _cts; - private QueueClient _queueClient; + private Task? _executingTask; + private CancellationTokenSource? _cts; public AzureQueueHostedService( ILogger logger, - IHubContext hubContext, - IHubContext anonymousHubContext, + HubHelpers hubHelpers, GlobalSettings globalSettings) { _logger = logger; - _hubContext = hubContext; + _hubHelpers = hubHelpers; _globalSettings = globalSettings; - _anonymousHubContext = anonymousHubContext; } public Task StartAsync(CancellationToken cancellationToken) @@ -44,32 +36,39 @@ public class AzureQueueHostedService : IHostedService, IDisposable { return; } + _logger.LogWarning("Stopping service."); - _cts.Cancel(); + _cts?.Cancel(); await Task.WhenAny(_executingTask, Task.Delay(-1, cancellationToken)); cancellationToken.ThrowIfCancellationRequested(); } public void Dispose() - { } + { + } private async Task ExecuteAsync(CancellationToken cancellationToken) { - _queueClient = new QueueClient(_globalSettings.Notifications.ConnectionString, "notifications"); + var queueClient = new QueueClient(_globalSettings.Notifications.ConnectionString, "notifications"); while (!cancellationToken.IsCancellationRequested) { try { - var messages = await _queueClient.ReceiveMessagesAsync(32); + var messages = await queueClient.ReceiveMessagesAsync(32, cancellationToken: cancellationToken); if (messages.Value?.Any() ?? false) { foreach (var message in messages.Value) { try { - await HubHelpers.SendNotificationToHubAsync( - message.DecodeMessageText(), _hubContext, _anonymousHubContext, _logger, cancellationToken); - await _queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); + var decodedMessage = message.DecodeMessageText(); + if (!string.IsNullOrWhiteSpace(decodedMessage)) + { + await _hubHelpers.SendNotificationToHubAsync(decodedMessage, cancellationToken); + } + + await queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt, + cancellationToken); } catch (Exception e) { @@ -77,7 +76,8 @@ public class AzureQueueHostedService : IHostedService, IDisposable message.MessageId, message.DequeueCount); if (message.DequeueCount > 2) { - await _queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt); + await queueClient.DeleteMessageAsync(message.MessageId, message.PopReceipt, + cancellationToken); } } } diff --git a/src/Notifications/Controllers/SendController.cs b/src/Notifications/Controllers/SendController.cs index 7debd51df7..c663102b56 100644 --- a/src/Notifications/Controllers/SendController.cs +++ b/src/Notifications/Controllers/SendController.cs @@ -1,36 +1,30 @@ -using System.Text; +#nullable enable +using System.Text; using Bit.Core.Utilities; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; -using Microsoft.AspNetCore.SignalR; -namespace Bit.Notifications; +namespace Bit.Notifications.Controllers; [Authorize("Internal")] public class SendController : Controller { - private readonly IHubContext _hubContext; - private readonly IHubContext _anonymousHubContext; - private readonly ILogger _logger; + private readonly HubHelpers _hubHelpers; - public SendController(IHubContext hubContext, IHubContext anonymousHubContext, ILogger logger) + public SendController(HubHelpers hubHelpers) { - _hubContext = hubContext; - _anonymousHubContext = anonymousHubContext; - _logger = logger; + _hubHelpers = hubHelpers; } [HttpPost("~/send")] [SelfHosted(SelfHostedOnly = true)] - public async Task PostSend() + public async Task PostSendAsync() { - using (var reader = new StreamReader(Request.Body, Encoding.UTF8)) + using var reader = new StreamReader(Request.Body, Encoding.UTF8); + var notificationJson = await reader.ReadToEndAsync(); + if (!string.IsNullOrWhiteSpace(notificationJson)) { - var notificationJson = await reader.ReadToEndAsync(); - if (!string.IsNullOrWhiteSpace(notificationJson)) - { - await HubHelpers.SendNotificationToHubAsync(notificationJson, _hubContext, _anonymousHubContext, _logger); - } + await _hubHelpers.SendNotificationToHubAsync(notificationJson); } } } diff --git a/src/Notifications/HubHelpers.cs b/src/Notifications/HubHelpers.cs index 0fea72edc3..2ef674adfe 100644 --- a/src/Notifications/HubHelpers.cs +++ b/src/Notifications/HubHelpers.cs @@ -1,31 +1,39 @@ -// FIXME: Update this file to be null safe and then delete the line below -#nullable disable - -using System.Text.Json; +using System.Text.Json; using Bit.Core.Enums; using Bit.Core.Models; using Microsoft.AspNetCore.SignalR; namespace Bit.Notifications; -public static class HubHelpers +public class HubHelpers { - private static JsonSerializerOptions _deserializerOptions = - new JsonSerializerOptions { PropertyNameCaseInsensitive = true }; + private static readonly JsonSerializerOptions _deserializerOptions = new() { PropertyNameCaseInsensitive = true }; private static readonly string _receiveMessageMethod = "ReceiveMessage"; - public static async Task SendNotificationToHubAsync( - string notificationJson, - IHubContext hubContext, + private readonly IHubContext _hubContext; + private readonly IHubContext _anonymousHubContext; + private readonly ILogger _logger; + + public HubHelpers(IHubContext hubContext, IHubContext anonymousHubContext, - ILogger logger, - CancellationToken cancellationToken = default(CancellationToken) - ) + ILogger logger) + { + _hubContext = hubContext; + _anonymousHubContext = anonymousHubContext; + _logger = logger; + } + + public async Task SendNotificationToHubAsync(string notificationJson, CancellationToken cancellationToken = default) { var notification = JsonSerializer.Deserialize>(notificationJson, _deserializerOptions); - logger.LogInformation("Sending notification: {NotificationType}", notification.Type); + if (notification is null) + { + return; + } + + _logger.LogInformation("Sending notification: {NotificationType}", notification.Type); switch (notification.Type) { case PushType.SyncCipherUpdate: @@ -35,14 +43,19 @@ public static class HubHelpers var cipherNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); + if (cipherNotification is null) + { + break; + } + if (cipherNotification.Payload.UserId.HasValue) { - await hubContext.Clients.User(cipherNotification.Payload.UserId.ToString()) + await _hubContext.Clients.User(cipherNotification.Payload.UserId.Value.ToString()) .SendAsync(_receiveMessageMethod, cipherNotification, cancellationToken); } else if (cipherNotification.Payload.OrganizationId.HasValue) { - await hubContext.Clients + await _hubContext.Clients .Group(NotificationsHub.GetOrganizationGroup(cipherNotification.Payload.OrganizationId.Value)) .SendAsync(_receiveMessageMethod, cipherNotification, cancellationToken); } @@ -54,7 +67,12 @@ public static class HubHelpers var folderNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.User(folderNotification.Payload.UserId.ToString()) + if (folderNotification is null) + { + break; + } + + await _hubContext.Clients.User(folderNotification.Payload.UserId.ToString()) .SendAsync(_receiveMessageMethod, folderNotification, cancellationToken); break; case PushType.SyncCiphers: @@ -66,7 +84,12 @@ public static class HubHelpers var userNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.User(userNotification.Payload.UserId.ToString()) + if (userNotification is null) + { + break; + } + + await _hubContext.Clients.User(userNotification.Payload.UserId.ToString()) .SendAsync(_receiveMessageMethod, userNotification, cancellationToken); break; case PushType.SyncSendCreate: @@ -75,36 +98,65 @@ public static class HubHelpers var sendNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.User(sendNotification.Payload.UserId.ToString()) + if (sendNotification is null) + { + break; + } + + await _hubContext.Clients.User(sendNotification.Payload.UserId.ToString()) .SendAsync(_receiveMessageMethod, sendNotification, cancellationToken); break; case PushType.AuthRequestResponse: var authRequestResponseNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await anonymousHubContext.Clients.Group(authRequestResponseNotification.Payload.Id.ToString()) + if (authRequestResponseNotification is null) + { + break; + } + + await _anonymousHubContext.Clients.Group(authRequestResponseNotification.Payload.Id.ToString()) .SendAsync("AuthRequestResponseRecieved", authRequestResponseNotification, cancellationToken); break; case PushType.AuthRequest: var authRequestNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.User(authRequestNotification.Payload.UserId.ToString()) + if (authRequestNotification is null) + { + break; + } + + await _hubContext.Clients.User(authRequestNotification.Payload.UserId.ToString()) .SendAsync(_receiveMessageMethod, authRequestNotification, cancellationToken); break; case PushType.SyncOrganizationStatusChanged: var orgStatusNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.Group(NotificationsHub.GetOrganizationGroup(orgStatusNotification.Payload.OrganizationId)) + if (orgStatusNotification is null) + { + break; + } + + await _hubContext.Clients + .Group(NotificationsHub.GetOrganizationGroup(orgStatusNotification.Payload.OrganizationId)) .SendAsync(_receiveMessageMethod, orgStatusNotification, cancellationToken); break; case PushType.SyncOrganizationCollectionSettingChanged: var organizationCollectionSettingsChangedNotification = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); - await hubContext.Clients.Group(NotificationsHub.GetOrganizationGroup(organizationCollectionSettingsChangedNotification.Payload.OrganizationId)) - .SendAsync(_receiveMessageMethod, organizationCollectionSettingsChangedNotification, cancellationToken); + if (organizationCollectionSettingsChangedNotification is null) + { + break; + } + + await _hubContext.Clients + .Group(NotificationsHub.GetOrganizationGroup(organizationCollectionSettingsChangedNotification + .Payload.OrganizationId)) + .SendAsync(_receiveMessageMethod, organizationCollectionSettingsChangedNotification, + cancellationToken); break; case PushType.OrganizationBankAccountVerified: var organizationBankAccountVerifiedNotification = @@ -124,9 +176,14 @@ public static class HubHelpers case PushType.NotificationStatus: var notificationData = JsonSerializer.Deserialize>( notificationJson, _deserializerOptions); + if (notificationData is null) + { + break; + } + if (notificationData.Payload.InstallationId.HasValue) { - await hubContext.Clients.Group(NotificationsHub.GetInstallationGroup( + await _hubContext.Clients.Group(NotificationsHub.GetInstallationGroup( notificationData.Payload.InstallationId.Value, notificationData.Payload.ClientType)) .SendAsync(_receiveMessageMethod, notificationData, cancellationToken); } @@ -134,27 +191,34 @@ public static class HubHelpers { if (notificationData.Payload.ClientType == ClientType.All) { - await hubContext.Clients.User(notificationData.Payload.UserId.ToString()) + await _hubContext.Clients.User(notificationData.Payload.UserId.Value.ToString()) .SendAsync(_receiveMessageMethod, notificationData, cancellationToken); } else { - await hubContext.Clients.Group(NotificationsHub.GetUserGroup( + await _hubContext.Clients.Group(NotificationsHub.GetUserGroup( notificationData.Payload.UserId.Value, notificationData.Payload.ClientType)) .SendAsync(_receiveMessageMethod, notificationData, cancellationToken); } } else if (notificationData.Payload.OrganizationId.HasValue) { - await hubContext.Clients.Group(NotificationsHub.GetOrganizationGroup( + await _hubContext.Clients.Group(NotificationsHub.GetOrganizationGroup( notificationData.Payload.OrganizationId.Value, notificationData.Payload.ClientType)) .SendAsync(_receiveMessageMethod, notificationData, cancellationToken); } break; case PushType.RefreshSecurityTasks: - var pendingTasksData = JsonSerializer.Deserialize>(notificationJson, _deserializerOptions); - await hubContext.Clients.User(pendingTasksData.Payload.UserId.ToString()) + var pendingTasksData = + JsonSerializer.Deserialize>(notificationJson, + _deserializerOptions); + if (pendingTasksData is null) + { + break; + } + + await _hubContext.Clients.User(pendingTasksData.Payload.UserId.ToString()) .SendAsync(_receiveMessageMethod, pendingTasksData, cancellationToken); break; default: diff --git a/src/Notifications/Startup.cs b/src/Notifications/Startup.cs index eb3c3f8682..2889e90d3b 100644 --- a/src/Notifications/Startup.cs +++ b/src/Notifications/Startup.cs @@ -61,6 +61,7 @@ public class Startup } services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); // Mvc services.AddMvc(); diff --git a/test/Notifications.Test/HubHelpersTest.cs b/test/Notifications.Test/HubHelpersTest.cs new file mode 100644 index 0000000000..df4d3c5f85 --- /dev/null +++ b/test/Notifications.Test/HubHelpersTest.cs @@ -0,0 +1,250 @@ +#nullable enable +using System.Text.Json; +using Bit.Core.Enums; +using Bit.Core.Models; +using Bit.Core.Test.NotificationCenter.AutoFixture; +using Bit.Core.Utilities; +using Bit.Notifications; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.AspNetCore.SignalR; +using NSubstitute; + +namespace Notifications.Test; + +[SutProviderCustomize] +[NotificationCustomize(false)] +public class HubHelpersTest +{ + [Theory] + [BitAutoData] + public async Task SendNotificationToHubAsync_NotificationPushNotificationGlobal_NothingSent( + SutProvider sutProvider, + NotificationPushNotification notification, + string contextId, CancellationToken cancellationToke) + { + notification.Global = true; + notification.InstallationId = null; + notification.UserId = null; + notification.OrganizationId = null; + + var json = ToNotificationJson(notification, PushType.Notification, contextId); + await sutProvider.Sut.SendNotificationToHubAsync(json, cancellationToke); + + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + sutProvider.GetDependency>().Clients.Received(0).Group(Arg.Any()); + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + sutProvider.GetDependency>().Clients.Received(0) + .Group(Arg.Any()); + } + + [Theory] + [BitAutoData] + public async Task + SendNotificationToHubAsync_NotificationPushNotificationInstallationIdProvidedClientTypeAll_SentToGroupInstallation( + SutProvider sutProvider, + NotificationPushNotification notification, + string contextId, CancellationToken cancellationToken) + { + notification.UserId = null; + notification.OrganizationId = null; + notification.ClientType = ClientType.All; + + var json = ToNotificationJson(notification, PushType.Notification, contextId); + await sutProvider.Sut.SendNotificationToHubAsync(json, cancellationToken); + + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + await sutProvider.GetDependency>().Clients.Received(1) + .Group($"Installation_{notification.InstallationId!.Value.ToString()}") + .Received(1) + .SendCoreAsync("ReceiveMessage", Arg.Is(objects => + objects.Length == 1 && IsNotificationPushNotificationEqual(notification, objects[0], + PushType.Notification, contextId)), + cancellationToken); + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + sutProvider.GetDependency>().Clients.Received(0) + .Group(Arg.Any()); + } + + [Theory] + [BitAutoData(ClientType.Browser)] + [BitAutoData(ClientType.Desktop)] + [BitAutoData(ClientType.Mobile)] + [BitAutoData(ClientType.Web)] + public async Task + SendNotificationToHubAsync_NotificationPushNotificationInstallationIdProvidedClientTypeNotAll_SentToGroupInstallationClientType( + ClientType clientType, SutProvider sutProvider, + NotificationPushNotification notification, + string contextId, CancellationToken cancellationToken) + { + notification.UserId = null; + notification.OrganizationId = null; + notification.ClientType = clientType; + + var json = ToNotificationJson(notification, PushType.Notification, contextId); + await sutProvider.Sut.SendNotificationToHubAsync(json, cancellationToken); + + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + await sutProvider.GetDependency>().Clients.Received(1) + .Group($"Installation_ClientType_{notification.InstallationId!.Value}_{clientType}") + .Received(1) + .SendCoreAsync("ReceiveMessage", Arg.Is(objects => + objects.Length == 1 && IsNotificationPushNotificationEqual(notification, objects[0], + PushType.Notification, contextId)), + cancellationToken); + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + sutProvider.GetDependency>().Clients.Received(0) + .Group(Arg.Any()); + } + + [Theory] + [BitAutoData(false)] + [BitAutoData(true)] + public async Task SendNotificationToHubAsync_NotificationPushNotificationUserIdProvidedClientTypeAll_SentToUser( + bool organizationIdProvided, SutProvider sutProvider, + NotificationPushNotification notification, + string contextId, CancellationToken cancellationToken) + { + notification.InstallationId = null; + notification.ClientType = ClientType.All; + if (!organizationIdProvided) + { + notification.OrganizationId = null; + } + + var json = ToNotificationJson(notification, PushType.Notification, contextId); + await sutProvider.Sut.SendNotificationToHubAsync(json, cancellationToken); + + await sutProvider.GetDependency>().Clients.Received(1) + .User(notification.UserId!.Value.ToString()) + .Received(1) + .SendCoreAsync("ReceiveMessage", Arg.Is(objects => + objects.Length == 1 && IsNotificationPushNotificationEqual(notification, objects[0], + PushType.Notification, contextId)), + cancellationToken); + sutProvider.GetDependency>().Clients.Received(0).Group(Arg.Any()); + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + sutProvider.GetDependency>().Clients.Received(0) + .Group(Arg.Any()); + } + + [Theory] + [BitAutoData(false, ClientType.Browser)] + [BitAutoData(false, ClientType.Desktop)] + [BitAutoData(false, ClientType.Mobile)] + [BitAutoData(false, ClientType.Web)] + [BitAutoData(true, ClientType.Browser)] + [BitAutoData(true, ClientType.Desktop)] + [BitAutoData(true, ClientType.Mobile)] + [BitAutoData(true, ClientType.Web)] + public async Task + SendNotificationToHubAsync_NotificationPushNotificationUserIdProvidedClientTypeNotAll_SentToGroupUserClientType( + bool organizationIdProvided, ClientType clientType, SutProvider sutProvider, + NotificationPushNotification notification, + string contextId, CancellationToken cancellationToken) + { + notification.InstallationId = null; + notification.ClientType = clientType; + if (!organizationIdProvided) + { + notification.OrganizationId = null; + } + + var json = ToNotificationJson(notification, PushType.Notification, contextId); + await sutProvider.Sut.SendNotificationToHubAsync(json, cancellationToken); + + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + await sutProvider.GetDependency>().Clients.Received(1) + .Group($"UserClientType_{notification.UserId!.Value}_{clientType}") + .Received(1) + .SendCoreAsync("ReceiveMessage", Arg.Is(objects => + objects.Length == 1 && IsNotificationPushNotificationEqual(notification, objects[0], + PushType.Notification, contextId)), + cancellationToken); + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + sutProvider.GetDependency>().Clients.Received(0) + .Group(Arg.Any()); + } + + [Theory] + [BitAutoData] + public async Task + SendNotificationToHubAsync_NotificationPushNotificationOrganizationIdProvidedClientTypeAll_SentToGroupOrganization( + SutProvider sutProvider, string contextId, + NotificationPushNotification notification, + CancellationToken cancellationToken) + { + notification.UserId = null; + notification.InstallationId = null; + notification.ClientType = ClientType.All; + + var json = ToNotificationJson(notification, PushType.Notification, contextId); + await sutProvider.Sut.SendNotificationToHubAsync(json, cancellationToken); + + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + await sutProvider.GetDependency>().Clients.Received(1) + .Group($"Organization_{notification.OrganizationId!.Value}") + .Received(1) + .SendCoreAsync("ReceiveMessage", Arg.Is(objects => + objects.Length == 1 && IsNotificationPushNotificationEqual(notification, objects[0], + PushType.Notification, contextId)), + cancellationToken); + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + sutProvider.GetDependency>().Clients.Received(0) + .Group(Arg.Any()); + } + + [Theory] + [BitAutoData(ClientType.Browser)] + [BitAutoData(ClientType.Desktop)] + [BitAutoData(ClientType.Mobile)] + [BitAutoData(ClientType.Web)] + public async Task + SendNotificationToHubAsync_NotificationPushNotificationOrganizationIdProvidedClientTypeNotAll_SentToGroupOrganizationClientType( + ClientType clientType, SutProvider sutProvider, string contextId, + NotificationPushNotification notification, + CancellationToken cancellationToken) + { + notification.UserId = null; + notification.InstallationId = null; + notification.ClientType = clientType; + + var json = ToNotificationJson(notification, PushType.Notification, contextId); + await sutProvider.Sut.SendNotificationToHubAsync(json, cancellationToken); + + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + await sutProvider.GetDependency>().Clients.Received(1) + .Group($"OrganizationClientType_{notification.OrganizationId!.Value}_{clientType}") + .Received(1) + .SendCoreAsync("ReceiveMessage", Arg.Is(objects => + objects.Length == 1 && IsNotificationPushNotificationEqual(notification, objects[0], + PushType.Notification, contextId)), + cancellationToken); + sutProvider.GetDependency>().Clients.Received(0).User(Arg.Any()); + sutProvider.GetDependency>().Clients.Received(0) + .Group(Arg.Any()); + } + + private static string ToNotificationJson(object payload, PushType type, string contextId) + { + var notification = new PushNotificationData(type, payload, contextId); + return JsonSerializer.Serialize(notification, JsonHelpers.IgnoreWritingNull); + } + + private static bool IsNotificationPushNotificationEqual(NotificationPushNotification expected, object? actual, + PushType type, string contextId) + { + if (actual is not PushNotificationData pushNotificationData) + { + return false; + } + + return pushNotificationData.Type == type && + pushNotificationData.ContextId == contextId && + expected.Id == pushNotificationData.Payload.Id && + expected.UserId == pushNotificationData.Payload.UserId && + expected.OrganizationId == pushNotificationData.Payload.OrganizationId && + expected.ClientType == pushNotificationData.Payload.ClientType && + expected.RevisionDate == pushNotificationData.Payload.RevisionDate; + } +} diff --git a/test/Notifications.Test/Notifications.Test.csproj b/test/Notifications.Test/Notifications.Test.csproj index 4dd37605c2..a4bab9df98 100644 --- a/test/Notifications.Test/Notifications.Test.csproj +++ b/test/Notifications.Test/Notifications.Test.csproj @@ -18,5 +18,7 @@ + +