diff --git a/src/Core/Services/Implementations/HandlebarsMailService.cs b/src/Core/Services/Implementations/HandlebarsMailService.cs index 9728c2e727..75e0c78702 100644 --- a/src/Core/Services/Implementations/HandlebarsMailService.cs +++ b/src/Core/Services/Implementations/HandlebarsMailService.cs @@ -26,6 +26,7 @@ using Bit.Core.Vault.Models.Data; using Core.Auth.Enums; using HandlebarsDotNet; using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Logging; namespace Bit.Core.Services; @@ -39,6 +40,7 @@ public class HandlebarsMailService : IMailService private readonly IMailDeliveryService _mailDeliveryService; private readonly IMailEnqueuingService _mailEnqueuingService; private readonly IDistributedCache _distributedCache; + private readonly ILogger _logger; private readonly Dictionary> _templateCache = new(); private bool _registeredHelpersAndPartials = false; @@ -47,12 +49,14 @@ public class HandlebarsMailService : IMailService GlobalSettings globalSettings, IMailDeliveryService mailDeliveryService, IMailEnqueuingService mailEnqueuingService, - IDistributedCache distributedCache) + IDistributedCache distributedCache, + ILogger logger) { _globalSettings = globalSettings; _mailDeliveryService = mailDeliveryService; _mailEnqueuingService = mailEnqueuingService; _distributedCache = distributedCache; + _logger = logger; } public async Task SendVerifyEmailEmailAsync(string email, Guid userId, string token) @@ -708,6 +712,12 @@ public class HandlebarsMailService : IMailService private async Task ReadSourceAsync(string templateName) { + var diskSource = await ReadSourceFromDiskAsync(templateName); + if (!string.IsNullOrWhiteSpace(diskSource)) + { + return diskSource; + } + var assembly = typeof(HandlebarsMailService).GetTypeInfo().Assembly; var fullTemplateName = $"{Namespace}.{templateName}.hbs"; if (!assembly.GetManifestResourceNames().Any(f => f == fullTemplateName)) @@ -721,6 +731,42 @@ public class HandlebarsMailService : IMailService } } + private async Task ReadSourceFromDiskAsync(string templateName) + { + if (!_globalSettings.SelfHosted) + { + return null; + } + try + { + var templateFileSuffix = ".html"; + if (templateName.EndsWith(".txt")) + { + templateFileSuffix = ".txt"; + } + else if (!templateName.EndsWith(".html")) + { + // unexpected suffix + return null; + } + var suffixPosition = templateName.LastIndexOf(templateFileSuffix); + var templateNameNoSuffix = templateName.Substring(0, suffixPosition); + var templatePathNoSuffix = templateNameNoSuffix.Replace(".", "/"); + var diskPath = $"{_globalSettings.MailTemplateDirectory}/{templatePathNoSuffix}{templateFileSuffix}.hbs"; + var directory = Path.GetDirectoryName(diskPath); + if (Directory.Exists(directory) && File.Exists(diskPath)) + { + var fileContents = await File.ReadAllTextAsync(diskPath); + return fileContents; + } + } + catch (Exception e) + { + _logger.LogError(e, "Failed to read mail template from disk."); + } + return null; + } + private async Task RegisterHelpersAndPartialsAsync() { if (_registeredHelpersAndPartials) diff --git a/src/Core/Settings/GlobalSettings.cs b/src/Core/Settings/GlobalSettings.cs index 2a5b5128b2..f045570df5 100644 --- a/src/Core/Settings/GlobalSettings.cs +++ b/src/Core/Settings/GlobalSettings.cs @@ -8,6 +8,7 @@ namespace Bit.Core.Settings; public class GlobalSettings : IGlobalSettings { + private string _mailTemplateDirectory; private string _logDirectory; private string _licenseDirectory; @@ -37,6 +38,11 @@ public class GlobalSettings : IGlobalSettings get => BuildDirectory(_licenseDirectory, "/core/licenses"); set => _licenseDirectory = value; } + public virtual string MailTemplateDirectory + { + get => BuildDirectory(_mailTemplateDirectory, "/mail-templates"); + set => _mailTemplateDirectory = value; + } public string LicenseCertificatePassword { get; set; } public virtual string PushRelayBaseUri { get; set; } public virtual string InternalIdentityKey { get; set; } diff --git a/test/Core.Test/Services/HandlebarsMailServiceTests.cs b/test/Core.Test/Services/HandlebarsMailServiceTests.cs index 242bcc60f3..30eebfb30f 100644 --- a/test/Core.Test/Services/HandlebarsMailServiceTests.cs +++ b/test/Core.Test/Services/HandlebarsMailServiceTests.cs @@ -23,6 +23,7 @@ public class HandlebarsMailServiceTests private readonly IMailDeliveryService _mailDeliveryService; private readonly IMailEnqueuingService _mailEnqueuingService; private readonly IDistributedCache _distributedCache; + private readonly ILogger _logger; public HandlebarsMailServiceTests() { @@ -30,12 +31,14 @@ public class HandlebarsMailServiceTests _mailDeliveryService = Substitute.For(); _mailEnqueuingService = Substitute.For(); _distributedCache = Substitute.For(); + _logger = Substitute.For>(); _sut = new HandlebarsMailService( _globalSettings, _mailDeliveryService, _mailEnqueuingService, - _distributedCache + _distributedCache, + _logger ); } @@ -217,8 +220,9 @@ public class HandlebarsMailServiceTests var mailDeliveryService = new MailKitSmtpMailDeliveryService(globalSettings, Substitute.For>()); var distributedCache = Substitute.For(); + var logger = Substitute.For>(); - var handlebarsService = new HandlebarsMailService(globalSettings, mailDeliveryService, new BlockingMailEnqueuingService(), distributedCache); + var handlebarsService = new HandlebarsMailService(globalSettings, mailDeliveryService, new BlockingMailEnqueuingService(), distributedCache, logger); var sendMethods = typeof(IMailService).GetMethods(BindingFlags.Public | BindingFlags.Instance) .Where(m => m.Name.StartsWith("Send") && m.Name != "SendEnqueuedMailMessageAsync");